fix: reflection agent runner
This commit is contained in:
@@ -12,6 +12,9 @@ Two reflection types:
|
||||
|
||||
Concurrency: an ``asyncio.Lock`` prevents overlapping runs. If a trigger
|
||||
fires while a reflection is already active the event is skipped.
|
||||
|
||||
All reflections are fire-and-forget (spawned via ``asyncio.create_task``)
|
||||
so they never block the queen's event loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -217,7 +220,7 @@ async def _reflection_loop(
|
||||
last_text: str = ""
|
||||
|
||||
for _turn in range(max_turns):
|
||||
logger.debug("reflect: loop turn %d/%d", _turn + 1, max_turns)
|
||||
logger.info("reflect: loop turn %d/%d (msgs=%d)", _turn + 1, max_turns, len(messages))
|
||||
try:
|
||||
resp: LLMResponse = await llm.acomplete(
|
||||
messages=messages,
|
||||
@@ -225,13 +228,40 @@ async def _reflection_loop(
|
||||
tools=_REFLECTION_TOOLS,
|
||||
max_tokens=2048,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("reflect: LLM call cancelled (task cancelled)")
|
||||
return False, changed_files, last_text
|
||||
except Exception:
|
||||
logger.warning("reflect: LLM call failed", exc_info=True)
|
||||
return False, changed_files, last_text
|
||||
|
||||
# Extract tool calls from litellm/OpenAI response object.
|
||||
tool_calls_raw: list[dict[str, Any]] = []
|
||||
if resp.raw_response and isinstance(resp.raw_response, dict):
|
||||
tool_calls_raw = resp.raw_response.get("tool_calls", [])
|
||||
raw = resp.raw_response
|
||||
if raw is not None:
|
||||
# litellm returns a ModelResponse object; tool calls live on
|
||||
# choices[0].message.tool_calls as a list of ChatCompletionMessageToolCall.
|
||||
try:
|
||||
msg_obj = raw.choices[0].message
|
||||
if hasattr(msg_obj, "tool_calls") and msg_obj.tool_calls:
|
||||
for tc in msg_obj.tool_calls:
|
||||
fn = tc.function
|
||||
try:
|
||||
args = json.loads(fn.arguments) if fn.arguments else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
tool_calls_raw.append({
|
||||
"id": tc.id,
|
||||
"name": fn.name,
|
||||
"input": args,
|
||||
})
|
||||
except (AttributeError, IndexError):
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
"reflect: LLM responded, text=%d chars, tool_calls=%d",
|
||||
len(resp.content or ""), len(tool_calls_raw),
|
||||
)
|
||||
|
||||
turn_text = resp.content or ""
|
||||
if turn_text:
|
||||
@@ -353,11 +383,12 @@ async def run_short_reflection(
|
||||
memory_dir: Path | None = None,
|
||||
) -> None:
|
||||
"""Run a short reflection: extract user knowledge from conversation."""
|
||||
logger.debug("reflect: starting short reflection")
|
||||
logger.info("reflect: starting short reflection for %s", session_dir)
|
||||
mem_dir = memory_dir or global_memory_dir()
|
||||
|
||||
messages = await _read_conversation_parts(session_dir)
|
||||
if not messages:
|
||||
logger.info("reflect: no conversation parts found in %s, skipping", session_dir)
|
||||
return
|
||||
|
||||
transcript_lines: list[str] = []
|
||||
@@ -372,6 +403,7 @@ async def run_short_reflection(
|
||||
transcript_lines.append(f"[{label}]: {content}")
|
||||
|
||||
if not transcript_lines:
|
||||
logger.info("reflect: no transcript lines after filtering, skipping")
|
||||
return
|
||||
|
||||
transcript = "\n".join(transcript_lines)
|
||||
@@ -383,9 +415,9 @@ async def run_short_reflection(
|
||||
|
||||
_, changed, reason = await _reflection_loop(llm, _SHORT_REFLECT_SYSTEM, user_msg, mem_dir)
|
||||
if changed:
|
||||
logger.debug("reflect: short reflection done, changed files: %s", changed)
|
||||
logger.info("reflect: short reflection done, changed files: %s", changed)
|
||||
else:
|
||||
logger.debug("reflect: short reflection done, no changes — %s", reason or "no reason")
|
||||
logger.info("reflect: short reflection done, no changes — %s", reason or "no reason")
|
||||
|
||||
|
||||
async def run_long_reflection(
|
||||
@@ -419,6 +451,28 @@ async def run_long_reflection(
|
||||
)
|
||||
|
||||
|
||||
async def run_shutdown_reflection(
|
||||
session_dir: Path,
|
||||
llm: Any,
|
||||
memory_dir: Path | None = None,
|
||||
) -> None:
|
||||
"""Run a final short reflection on session shutdown.
|
||||
|
||||
Called during session teardown so recent conversation insights are
|
||||
persisted before the session is destroyed.
|
||||
"""
|
||||
logger.info("reflect: running shutdown reflection for %s", session_dir)
|
||||
mem_dir = memory_dir or global_memory_dir()
|
||||
try:
|
||||
await run_short_reflection(session_dir, llm, mem_dir)
|
||||
logger.info("reflect: shutdown reflection completed for %s", session_dir)
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("reflect: shutdown reflection cancelled for %s", session_dir)
|
||||
except Exception:
|
||||
logger.warning("reflect: shutdown reflection failed", exc_info=True)
|
||||
_write_error("shutdown reflection")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event-bus integration
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -443,32 +497,9 @@ async def subscribe_reflection_triggers(
|
||||
mem_dir = memory_dir or global_memory_dir()
|
||||
_lock = asyncio.Lock()
|
||||
_short_count = 0
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def _on_turn_complete(event: Any) -> None:
|
||||
nonlocal _short_count
|
||||
|
||||
if getattr(event, "stream_id", None) != "queen":
|
||||
return
|
||||
|
||||
_short_count += 1
|
||||
|
||||
event_data = getattr(event, "data", {}) or {}
|
||||
stop_reason = event_data.get("stop_reason", "")
|
||||
is_tool_turn = stop_reason in ("tool_use", "tool_calls")
|
||||
is_interval = _short_count % _LONG_REFLECT_INTERVAL == 0
|
||||
|
||||
if is_tool_turn and not is_interval:
|
||||
logger.debug("reflect: skipping tool turn (count=%d)", _short_count)
|
||||
return
|
||||
|
||||
if _lock.locked():
|
||||
logger.debug("reflect: skipping, already running (count=%d)", _short_count)
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"reflect: triggered (count=%d, interval=%s, stop_reason=%s)",
|
||||
_short_count, is_interval, stop_reason,
|
||||
)
|
||||
async def _do_turn_reflect(is_interval: bool, count: int) -> None:
|
||||
async with _lock:
|
||||
try:
|
||||
if is_interval:
|
||||
@@ -497,6 +528,47 @@ async def subscribe_reflection_triggers(
|
||||
except Exception:
|
||||
logger.debug("recall: cache update failed", exc_info=True)
|
||||
|
||||
async def _do_compaction_reflect() -> None:
|
||||
async with _lock:
|
||||
try:
|
||||
await run_long_reflection(llm, mem_dir)
|
||||
except Exception:
|
||||
logger.warning("reflect: compaction-triggered reflection failed", exc_info=True)
|
||||
_write_error("compaction reflection")
|
||||
|
||||
def _fire_and_forget(coro: Any) -> None:
|
||||
"""Spawn a background task and prevent GC before it finishes."""
|
||||
task = asyncio.create_task(coro)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
async def _on_turn_complete(event: Any) -> None:
|
||||
nonlocal _short_count
|
||||
|
||||
if getattr(event, "stream_id", None) != "queen":
|
||||
return
|
||||
|
||||
_short_count += 1
|
||||
|
||||
event_data = getattr(event, "data", {}) or {}
|
||||
stop_reason = event_data.get("stop_reason", "")
|
||||
is_tool_turn = stop_reason in ("tool_use", "tool_calls")
|
||||
is_interval = _short_count % _LONG_REFLECT_INTERVAL == 0
|
||||
|
||||
if is_tool_turn and not is_interval:
|
||||
logger.debug("reflect: skipping tool turn (count=%d)", _short_count)
|
||||
return
|
||||
|
||||
if _lock.locked():
|
||||
logger.debug("reflect: skipping, already running (count=%d)", _short_count)
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"reflect: triggered (count=%d, interval=%s, stop_reason=%s)",
|
||||
_short_count, is_interval, stop_reason,
|
||||
)
|
||||
_fire_and_forget(_do_turn_reflect(is_interval, _short_count))
|
||||
|
||||
async def _on_compaction(event: Any) -> None:
|
||||
if getattr(event, "stream_id", None) != "queen":
|
||||
return
|
||||
@@ -504,12 +576,7 @@ async def subscribe_reflection_triggers(
|
||||
logger.debug("reflect: skipping compaction trigger, already running")
|
||||
return
|
||||
logger.debug("reflect: compaction triggered long reflection")
|
||||
async with _lock:
|
||||
try:
|
||||
await run_long_reflection(llm, mem_dir)
|
||||
except Exception:
|
||||
logger.warning("reflect: compaction-triggered reflection failed", exc_info=True)
|
||||
_write_error("compaction reflection")
|
||||
_fire_and_forget(_do_compaction_reflect())
|
||||
|
||||
sub_ids: list[str] = []
|
||||
|
||||
|
||||
@@ -65,6 +65,8 @@ class Session:
|
||||
# directory instead of creating a new one. This lets cold-restores accumulate
|
||||
# all messages in the original session folder so history is never fragmented.
|
||||
queen_resume_from: str | None = None
|
||||
# Queen session directory (set during _start_queen, used for shutdown reflection)
|
||||
queen_dir: Path | None = None
|
||||
|
||||
|
||||
class SessionManager:
|
||||
@@ -80,6 +82,9 @@ class SessionManager:
|
||||
self._model = model
|
||||
self._credential_store = credential_store
|
||||
self._lock = asyncio.Lock()
|
||||
# Strong references for fire-and-forget background tasks (e.g. shutdown
|
||||
# reflections) so they aren't garbage-collected before completion.
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session lifecycle
|
||||
@@ -631,6 +636,21 @@ class SessionManager:
|
||||
pass
|
||||
session.memory_reflection_subs.clear()
|
||||
|
||||
# Run a final shutdown reflection so recent conversation insights
|
||||
# are persisted before the session is destroyed (fire-and-forget).
|
||||
if session.queen_dir is not None:
|
||||
try:
|
||||
from framework.agents.queen.reflection_agent import run_shutdown_reflection
|
||||
|
||||
task = asyncio.create_task(
|
||||
asyncio.shield(run_shutdown_reflection(session.queen_dir, session.llm)),
|
||||
)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
logger.info("Session '%s': shutdown reflection spawned", session_id)
|
||||
except Exception:
|
||||
logger.warning("Session '%s': failed to spawn shutdown reflection", session_id, exc_info=True)
|
||||
|
||||
if session.queen_task is not None:
|
||||
session.queen_task.cancel()
|
||||
session.queen_task = None
|
||||
@@ -741,6 +761,7 @@ class SessionManager:
|
||||
storage_session_id = session.queen_resume_from or session.id
|
||||
queen_dir = hive_home / "queen" / "session" / storage_session_id
|
||||
queen_dir.mkdir(parents=True, exist_ok=True)
|
||||
session.queen_dir = queen_dir
|
||||
|
||||
# Always write/update session metadata so history sidebar has correct
|
||||
# agent name, path, and last-active timestamp (important so the original
|
||||
|
||||
Reference in New Issue
Block a user