fix: reflection agent runner

This commit is contained in:
Richard Tang
2026-04-07 13:07:41 -07:00
parent c9d62139af
commit 3dbd20040a
2 changed files with 125 additions and 37 deletions
+104 -37
View File
@@ -12,6 +12,9 @@ Two reflection types:
Concurrency: an ``asyncio.Lock`` prevents overlapping runs. If a trigger
fires while a reflection is already active, the event is skipped.
All reflections are fire-and-forget (spawned via ``asyncio.create_task``),
so they never block the queen's event loop.
"""
from __future__ import annotations
@@ -217,7 +220,7 @@ async def _reflection_loop(
last_text: str = ""
for _turn in range(max_turns):
logger.debug("reflect: loop turn %d/%d", _turn + 1, max_turns)
logger.info("reflect: loop turn %d/%d (msgs=%d)", _turn + 1, max_turns, len(messages))
try:
resp: LLMResponse = await llm.acomplete(
messages=messages,
@@ -225,13 +228,40 @@ async def _reflection_loop(
tools=_REFLECTION_TOOLS,
max_tokens=2048,
)
except asyncio.CancelledError:
logger.warning("reflect: LLM call cancelled (task cancelled)")
return False, changed_files, last_text
except Exception:
logger.warning("reflect: LLM call failed", exc_info=True)
return False, changed_files, last_text
# Extract tool calls from litellm/OpenAI response object.
tool_calls_raw: list[dict[str, Any]] = []
if resp.raw_response and isinstance(resp.raw_response, dict):
tool_calls_raw = resp.raw_response.get("tool_calls", [])
raw = resp.raw_response
if raw is not None:
# litellm returns a ModelResponse object; tool calls live on
# choices[0].message.tool_calls as a list of ChatCompletionMessageToolCall.
try:
msg_obj = raw.choices[0].message
if hasattr(msg_obj, "tool_calls") and msg_obj.tool_calls:
for tc in msg_obj.tool_calls:
fn = tc.function
try:
args = json.loads(fn.arguments) if fn.arguments else {}
except (json.JSONDecodeError, TypeError):
args = {}
tool_calls_raw.append({
"id": tc.id,
"name": fn.name,
"input": args,
})
except (AttributeError, IndexError):
pass
logger.info(
"reflect: LLM responded, text=%d chars, tool_calls=%d",
len(resp.content or ""), len(tool_calls_raw),
)
turn_text = resp.content or ""
if turn_text:
@@ -353,11 +383,12 @@ async def run_short_reflection(
memory_dir: Path | None = None,
) -> None:
"""Run a short reflection: extract user knowledge from conversation."""
logger.debug("reflect: starting short reflection")
logger.info("reflect: starting short reflection for %s", session_dir)
mem_dir = memory_dir or global_memory_dir()
messages = await _read_conversation_parts(session_dir)
if not messages:
logger.info("reflect: no conversation parts found in %s, skipping", session_dir)
return
transcript_lines: list[str] = []
@@ -372,6 +403,7 @@ async def run_short_reflection(
transcript_lines.append(f"[{label}]: {content}")
if not transcript_lines:
logger.info("reflect: no transcript lines after filtering, skipping")
return
transcript = "\n".join(transcript_lines)
@@ -383,9 +415,9 @@ async def run_short_reflection(
_, changed, reason = await _reflection_loop(llm, _SHORT_REFLECT_SYSTEM, user_msg, mem_dir)
if changed:
logger.debug("reflect: short reflection done, changed files: %s", changed)
logger.info("reflect: short reflection done, changed files: %s", changed)
else:
logger.debug("reflect: short reflection done, no changes — %s", reason or "no reason")
logger.info("reflect: short reflection done, no changes — %s", reason or "no reason")
async def run_long_reflection(
@@ -419,6 +451,28 @@ async def run_long_reflection(
)
async def run_shutdown_reflection(
    session_dir: Path,
    llm: Any,
    memory_dir: Path | None = None,
) -> None:
    """Run a final short reflection on session shutdown.

    Called during session teardown so recent conversation insights are
    persisted before the session is destroyed.

    Args:
        session_dir: Session directory, forwarded to ``run_short_reflection``.
        llm: LLM client, forwarded to ``run_short_reflection``.
        memory_dir: Memory directory to use; falls back to
            ``global_memory_dir()`` when not given.

    Raises:
        asyncio.CancelledError: Re-raised after being logged, so that
            cancelling the surrounding task actually takes effect.
    """
    logger.info("reflect: running shutdown reflection for %s", session_dir)
    mem_dir = memory_dir or global_memory_dir()
    try:
        await run_short_reflection(session_dir, llm, mem_dir)
    except asyncio.CancelledError:
        logger.warning("reflect: shutdown reflection cancelled for %s", session_dir)
        # Swallowing cancellation would make the task look like it finished
        # normally and can break callers that rely on cancel() (timeouts,
        # task groups, loop shutdown). Log, then propagate.
        raise
    except Exception:
        # Best-effort: a failed shutdown reflection must not break teardown.
        logger.warning("reflect: shutdown reflection failed", exc_info=True)
        _write_error("shutdown reflection")
    else:
        logger.info("reflect: shutdown reflection completed for %s", session_dir)
# ---------------------------------------------------------------------------
# Event-bus integration
# ---------------------------------------------------------------------------
@@ -443,32 +497,9 @@ async def subscribe_reflection_triggers(
mem_dir = memory_dir or global_memory_dir()
_lock = asyncio.Lock()
_short_count = 0
_background_tasks: set[asyncio.Task] = set()
async def _on_turn_complete(event: Any) -> None:
nonlocal _short_count
if getattr(event, "stream_id", None) != "queen":
return
_short_count += 1
event_data = getattr(event, "data", {}) or {}
stop_reason = event_data.get("stop_reason", "")
is_tool_turn = stop_reason in ("tool_use", "tool_calls")
is_interval = _short_count % _LONG_REFLECT_INTERVAL == 0
if is_tool_turn and not is_interval:
logger.debug("reflect: skipping tool turn (count=%d)", _short_count)
return
if _lock.locked():
logger.debug("reflect: skipping, already running (count=%d)", _short_count)
return
logger.debug(
"reflect: triggered (count=%d, interval=%s, stop_reason=%s)",
_short_count, is_interval, stop_reason,
)
async def _do_turn_reflect(is_interval: bool, count: int) -> None:
async with _lock:
try:
if is_interval:
@@ -497,6 +528,47 @@ async def subscribe_reflection_triggers(
except Exception:
logger.debug("recall: cache update failed", exc_info=True)
async def _do_compaction_reflect() -> None:
    """Run a long reflection in response to a compaction event.

    Holds ``_lock`` for the whole run. Failures are logged and recorded
    through ``_write_error`` rather than propagated to the caller.
    """
    async with _lock:
        try:
            await run_long_reflection(llm, mem_dir)
        except Exception:
            logger.warning(
                "reflect: compaction-triggered reflection failed",
                exc_info=True,
            )
            _write_error("compaction reflection")
def _fire_and_forget(coro: Any) -> None:
    """Schedule *coro* as a background task without awaiting it.

    The task is held in ``_background_tasks`` until it completes so that it
    is not garbage-collected before it finishes; a done-callback removes it
    from the set afterwards.
    """
    spawned = asyncio.create_task(coro)
    _background_tasks.add(spawned)
    spawned.add_done_callback(_background_tasks.discard)
async def _on_turn_complete(event: Any) -> None:
nonlocal _short_count
if getattr(event, "stream_id", None) != "queen":
return
_short_count += 1
event_data = getattr(event, "data", {}) or {}
stop_reason = event_data.get("stop_reason", "")
is_tool_turn = stop_reason in ("tool_use", "tool_calls")
is_interval = _short_count % _LONG_REFLECT_INTERVAL == 0
if is_tool_turn and not is_interval:
logger.debug("reflect: skipping tool turn (count=%d)", _short_count)
return
if _lock.locked():
logger.debug("reflect: skipping, already running (count=%d)", _short_count)
return
logger.debug(
"reflect: triggered (count=%d, interval=%s, stop_reason=%s)",
_short_count, is_interval, stop_reason,
)
_fire_and_forget(_do_turn_reflect(is_interval, _short_count))
async def _on_compaction(event: Any) -> None:
if getattr(event, "stream_id", None) != "queen":
return
@@ -504,12 +576,7 @@ async def subscribe_reflection_triggers(
logger.debug("reflect: skipping compaction trigger, already running")
return
logger.debug("reflect: compaction triggered long reflection")
async with _lock:
try:
await run_long_reflection(llm, mem_dir)
except Exception:
logger.warning("reflect: compaction-triggered reflection failed", exc_info=True)
_write_error("compaction reflection")
_fire_and_forget(_do_compaction_reflect())
sub_ids: list[str] = []
+21
View File
@@ -65,6 +65,8 @@ class Session:
# directory instead of creating a new one. This lets cold-restores accumulate
# all messages in the original session folder so history is never fragmented.
queen_resume_from: str | None = None
# Queen session directory (set during _start_queen, used for shutdown reflection)
queen_dir: Path | None = None
class SessionManager:
@@ -80,6 +82,9 @@ class SessionManager:
self._model = model
self._credential_store = credential_store
self._lock = asyncio.Lock()
# Strong references for fire-and-forget background tasks (e.g. shutdown
# reflections) so they aren't garbage-collected before completion.
self._background_tasks: set[asyncio.Task] = set()
# ------------------------------------------------------------------
# Session lifecycle
@@ -631,6 +636,21 @@ class SessionManager:
pass
session.memory_reflection_subs.clear()
# Run a final shutdown reflection so recent conversation insights
# are persisted before the session is destroyed (fire-and-forget).
if session.queen_dir is not None:
    try:
        from framework.agents.queen.reflection_agent import run_shutdown_reflection

        # ``asyncio.create_task`` accepts only a coroutine. Wrapping the
        # call in ``asyncio.shield`` hands it a Future, which raises
        # TypeError and leaves the reflection untracked. A standalone task
        # is not cancelled when this caller is, so no shield is needed.
        task = asyncio.create_task(
            run_shutdown_reflection(session.queen_dir, session.llm),
        )
        # Keep a strong reference until the task finishes.
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        logger.info("Session '%s': shutdown reflection spawned", session_id)
    except Exception:
        # Best-effort: teardown must continue even if spawning fails.
        logger.warning("Session '%s': failed to spawn shutdown reflection", session_id, exc_info=True)
if session.queen_task is not None:
session.queen_task.cancel()
session.queen_task = None
@@ -741,6 +761,7 @@ class SessionManager:
storage_session_id = session.queen_resume_from or session.id
queen_dir = hive_home / "queen" / "session" / storage_session_id
queen_dir.mkdir(parents=True, exist_ok=True)
session.queen_dir = queen_dir
# Always write/update session metadata so history sidebar has correct
# agent name, path, and last-active timestamp (important so the original