feat: simplify worker reflection

This commit is contained in:
Richard Tang
2026-04-03 13:03:47 -07:00
parent 4f588b3010
commit 771efd5ce4
8 changed files with 172 additions and 95 deletions
@@ -623,12 +623,16 @@ async def subscribe_worker_memory_triggers(
colony_memory_dir: Path,
recall_cache: dict[str, str],
) -> list[str]:
"""Subscribe shared colony memory reflection/recall for top-level worker runs."""
from framework.agents.queen.recall_selector import update_recall_cache
"""Subscribe colony memory lifecycle events for worker runs.
Short reflection is now handled synchronously at node handoff in
``WorkerAgent._reflect_colony_memory()``. This function only manages:
- Recall cache initialisation on execution start
- Final long reflection + cleanup on execution end
"""
from framework.runtime.event_bus import EventType
_lock = asyncio.Lock()
_short_counts: dict[str, int] = {}
_terminal_lock = asyncio.Lock()
def _is_worker_event(event: Any) -> bool:
return bool(
@@ -636,57 +640,6 @@ async def subscribe_worker_memory_triggers(
and getattr(event, "stream_id", None) not in ("queen", "judge")
)
async def _update_cache(execution_id: str) -> None:
    """Refresh the cached colony-recall block for *execution_id*."""
    await update_recall_cache(
        worker_sessions_dir / execution_id,
        llm,
        memory_dir=colony_memory_dir,
        # Bind the id as a default so the setter targets the right entry.
        cache_setter=lambda block, eid=execution_id: recall_cache.__setitem__(eid, block),
        heading="Colony Memories",
    )
async def _on_turn_complete(event: Any) -> None:
    """Run a short colony reflection after each worker LLM turn (best-effort).

    When a reflection is already in flight the turn is skipped outright,
    so turns never queue up behind the shared lock.
    """
    if not _is_worker_event(event):
        return
    if _lock.locked():
        logger.debug("reflect: worker colony reflection skipped — lock busy")
        return
    execution_id = event.execution_id
    if execution_id is None:
        return
    session_dir = worker_sessions_dir / execution_id
    async with _lock:
        try:
            count = _short_counts.get(execution_id, 0) + 1
            _short_counts[execution_id] = count
            await run_short_reflection(session_dir, llm, colony_memory_dir)
            # Every Nth short reflection also consolidates long-term memory.
            if count % _LONG_REFLECT_INTERVAL == 0:
                await run_long_reflection(llm, colony_memory_dir)
            await _update_cache(execution_id)
        except Exception:
            logger.warning("reflect: worker colony reflection failed", exc_info=True)
            _write_error("worker colony reflection")
async def _on_compaction(event: Any) -> None:
    """Consolidate long-term colony memory after a context compaction.

    Skipped when a reflection is already running; failures are logged,
    never raised.
    """
    if not _is_worker_event(event) or _lock.locked():
        return
    execution_id = event.execution_id
    if execution_id is None:
        return
    async with _lock:
        try:
            await run_long_reflection(llm, colony_memory_dir)
            await _update_cache(execution_id)
        except Exception:
            logger.warning("reflect: worker compaction reflection failed", exc_info=True)
            _write_error("worker compaction reflection")
async def _on_execution_started(event: Any) -> None:
if not _is_worker_event(event):
return
@@ -699,7 +652,7 @@ async def subscribe_worker_memory_triggers(
execution_id = event.execution_id
if execution_id is None:
return
async with _lock:
async with _terminal_lock:
try:
await run_long_reflection(llm, colony_memory_dir)
except Exception:
@@ -707,21 +660,12 @@ async def subscribe_worker_memory_triggers(
_write_error("worker final reflection")
finally:
recall_cache.pop(execution_id, None)
_short_counts.pop(execution_id, None)
return [
event_bus.subscribe(
event_types=[EventType.EXECUTION_STARTED],
handler=_on_execution_started,
),
event_bus.subscribe(
event_types=[EventType.LLM_TURN_COMPLETE],
handler=_on_turn_complete,
),
event_bus.subscribe(
event_types=[EventType.CONTEXT_COMPACTED],
handler=_on_compaction,
),
event_bus.subscribe(
event_types=[EventType.EXECUTION_COMPLETED, EventType.EXECUTION_FAILED],
handler=_on_execution_terminal,
+6
View File
@@ -67,6 +67,12 @@ class GraphContext:
# Retry tracking: worker_id → retry_count (for execution quality assessment)
retry_counts: dict[str, int] = field(default_factory=dict)
nodes_with_retries: set[str] = field(default_factory=set)
# Colony memory reflection at node handoff
colony_memory_dir: Any = None # Path | None
worker_sessions_dir: Any = None # Path | None
colony_recall_cache: dict[str, str] = field(default_factory=dict)
colony_reflect_llm: Any = None # LLMProvider for reflection
_colony_reflect_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
def build_scoped_buffer(buffer: DataBuffer, node_spec: NodeSpec) -> DataBuffer:
+12
View File
@@ -160,6 +160,10 @@ class GraphExecutor:
skill_dirs: list[str] | None = None,
context_warn_ratio: float | None = None,
batch_init_nudge: str | None = None,
colony_memory_dir: Any = None,
colony_worker_sessions_dir: Any = None,
colony_recall_cache: dict[str, str] | None = None,
colony_reflect_llm: Any = None,
):
"""
Initialize the executor.
@@ -221,6 +225,10 @@ class GraphExecutor:
self.skill_dirs: list[str] = skill_dirs or []
self.context_warn_ratio: float | None = context_warn_ratio
self.batch_init_nudge: str | None = batch_init_nudge
self.colony_memory_dir = colony_memory_dir
self.colony_worker_sessions_dir = colony_worker_sessions_dir
self.colony_recall_cache = colony_recall_cache or {}
self.colony_reflect_llm = colony_reflect_llm
if protocols_prompt:
self.logger.info(
@@ -1318,6 +1326,10 @@ class GraphExecutor:
iteration_metadata_provider=self.iteration_metadata_provider,
loop_config=self._loop_config,
node_visit_counts=dict(node_visit_counts),
colony_memory_dir=self.colony_memory_dir,
worker_sessions_dir=self.colony_worker_sessions_dir,
colony_recall_cache=self.colony_recall_cache,
colony_reflect_llm=self.colony_reflect_llm,
)
# Create one WorkerAgent per node
+56
View File
@@ -318,6 +318,8 @@ class WorkerAgent:
self.lifecycle = WorkerLifecycle.COMPLETED
self._last_result = result
self._last_activations = activations
# Colony memory reflection — runs before downstream activation
await self._reflect_colony_memory()
completion = WorkerCompletion(
worker_id=node_spec.id,
success=True,
@@ -338,6 +340,8 @@ class WorkerAgent:
self.lifecycle = WorkerLifecycle.FAILED
self._last_result = result
self._last_activations = activations
# Colony memory reflection — capture learnings even on failure
await self._reflect_colony_memory()
await self._publish_failure(result.error or "Unknown error")
except Exception as exc:
error = str(exc) or type(exc).__name__
@@ -649,6 +653,58 @@ class WorkerAgent:
pause_event=self._pause_requested,
)
async def _reflect_colony_memory(self) -> None:
    """Run colony memory reflection at node handoff.

    Called after a node finishes (success or failure) and before any
    downstream activation.  Awaits the shared colony lock so parallel
    workers queue (never skip).  Both phases are best-effort: failures
    are logged and swallowed so a reflection error can never block
    graph progress.
    """
    gc = self._gc
    # Reflection is opt-in: bail out unless all colony fields were configured
    # (memory dir, reflection LLM, and the per-worker sessions dir).
    if gc.colony_memory_dir is None or gc.colony_reflect_llm is None:
        return
    if gc.worker_sessions_dir is None:
        return
    from pathlib import Path

    # Nothing to reflect on until this execution's session dir exists.
    session_dir = Path(gc.worker_sessions_dir) / gc.execution_id
    if not session_dir.exists():
        return
    # Await lock — serializes reflection but never skips
    async with gc._colony_reflect_lock:
        try:
            from framework.agents.queen.reflection_agent import run_short_reflection

            await run_short_reflection(
                session_dir, gc.colony_reflect_llm, gc.colony_memory_dir
            )
        except Exception:
            logger.warning(
                "Worker %s: colony reflection failed",
                self.node_spec.id,
                exc_info=True,
            )
    # Update recall cache outside lock (per-execution key, no write races)
    try:
        from framework.agents.queen.recall_selector import update_recall_cache

        await update_recall_cache(
            session_dir,
            gc.colony_reflect_llm,
            memory_dir=gc.colony_memory_dir,
            # Each execution writes only its own key in the shared cache.
            cache_setter=lambda block: gc.colony_recall_cache.__setitem__(
                gc.execution_id, block
            ),
            heading="Colony Memories",
        )
    except Exception:
        logger.warning(
            "Worker %s: recall cache update failed",
            self.node_spec.id,
            exc_info=True,
        )
# ------------------------------------------------------------------
# Event publishing
# ------------------------------------------------------------------
+9
View File
@@ -239,6 +239,11 @@ class AgentRuntime:
self._tool_executor = tool_executor
self._accounts_prompt = accounts_prompt
self._dynamic_memory_provider_factory: Callable[[str], Callable[[], str] | None] | None = None
# Colony memory config for reflection-at-handoff (set by session_manager)
self._colony_memory_dir: Any = None
self._colony_worker_sessions_dir: Any = None
self._colony_recall_cache: dict[str, str] | None = None
self._colony_reflect_llm: Any = None
self._accounts_data = accounts_data
self._tool_provider_map = tool_provider_map
@@ -362,6 +367,10 @@ class AgentRuntime:
context_warn_ratio=self.context_warn_ratio,
batch_init_nudge=self.batch_init_nudge,
dynamic_memory_provider_factory=self._dynamic_memory_provider_factory,
colony_memory_dir=self._colony_memory_dir,
colony_worker_sessions_dir=self._colony_worker_sessions_dir,
colony_recall_cache=self._colony_recall_cache,
colony_reflect_llm=self._colony_reflect_llm,
)
await stream.start()
self._streams[ep_id] = stream
@@ -192,6 +192,10 @@ class ExecutionStream:
context_warn_ratio: float | None = None,
batch_init_nudge: str | None = None,
dynamic_memory_provider_factory: Callable[[str], Callable[[], str] | None] | None = None,
colony_memory_dir: Any = None,
colony_worker_sessions_dir: Any = None,
colony_recall_cache: dict[str, str] | None = None,
colony_reflect_llm: Any = None,
):
"""
Initialize execution stream.
@@ -247,6 +251,10 @@ class ExecutionStream:
self._context_warn_ratio: float | None = context_warn_ratio
self._batch_init_nudge: str | None = batch_init_nudge
self._dynamic_memory_provider_factory = dynamic_memory_provider_factory
self._colony_memory_dir = colony_memory_dir
self._colony_worker_sessions_dir = colony_worker_sessions_dir
self._colony_recall_cache = colony_recall_cache
self._colony_reflect_llm = colony_reflect_llm
_es_logger = logging.getLogger(__name__)
if protocols_prompt:
@@ -727,6 +735,10 @@ class ExecutionStream:
if self._dynamic_memory_provider_factory is not None
else None
),
colony_memory_dir=self._colony_memory_dir,
colony_worker_sessions_dir=self._colony_worker_sessions_dir,
colony_recall_cache=self._colony_recall_cache,
colony_reflect_llm=self._colony_reflect_llm,
)
# Track executor so inject_input() can reach EventLoopNode instances
self._active_executors[execution_id] = executor
+6
View File
@@ -789,6 +789,12 @@ class SessionManager:
)
)
# Colony memory config for reflection-at-handoff
runtime._colony_memory_dir = colony_dir
runtime._colony_worker_sessions_dir = worker_sessions_dir
runtime._colony_recall_cache = session.worker_colony_recall_blocks
runtime._colony_reflect_llm = session.llm
session.worker_memory_subs = await subscribe_worker_memory_triggers(
session.event_bus,
session.llm,
+62 -30
View File
@@ -540,7 +540,13 @@ def test_queen_phase_state_appends_colony_and_global_memory_blocks():
@pytest.mark.asyncio
async def test_worker_colony_reflection_updates_shared_memory_and_cache(tmp_path: Path):
async def test_worker_colony_reflection_at_handoff(tmp_path: Path):
"""Colony reflection runs via WorkerAgent._reflect_colony_memory at node handoff."""
import asyncio
from framework.graph.context import GraphContext
from framework.graph.worker_agent import WorkerAgent
worker_sessions_dir = tmp_path / "worker-sessions"
execution_id = "exec-1"
session_dir = worker_sessions_dir / execution_id / "conversations" / "parts"
@@ -556,11 +562,11 @@ async def test_worker_colony_reflection_updates_shared_memory_and_cache(tmp_path
colony_dir = tmp_path / "colony"
colony_dir.mkdir()
recall_cache: dict[str, str] = {}
bus = EventBus()
recall_cache: dict[str, str] = {execution_id: ""}
llm = AsyncMock()
llm.acomplete.side_effect = [
reflect_llm = AsyncMock()
reflect_llm.acomplete.side_effect = [
# Short reflection: write a memory file
MagicMock(
content="",
raw_response={
@@ -583,38 +589,64 @@ async def test_worker_colony_reflection_updates_shared_memory_and_cache(tmp_path
]
},
),
# Short reflection done
MagicMock(content="done", raw_response={}),
# Recall selector picks the new memory
MagicMock(content=json.dumps({"selected_memories": ["user-prefers-terse-summaries.md"]})),
]
subs = await subscribe_worker_memory_triggers(
bus,
llm,
worker_sessions_dir=worker_sessions_dir,
colony_memory_dir=colony_dir,
recall_cache=recall_cache,
)
try:
await bus.publish(
AgentEvent(
type=EventType.EXECUTION_STARTED,
stream_id="default",
execution_id=execution_id,
)
)
await bus.publish(
AgentEvent(
type=EventType.LLM_TURN_COMPLETE,
stream_id="default",
execution_id=execution_id,
)
)
finally:
for sub_id in subs:
bus.unsubscribe(sub_id)
# Build a minimal GraphContext with colony memory fields
gc = MagicMock(spec=GraphContext)
gc.colony_memory_dir = colony_dir
gc.worker_sessions_dir = worker_sessions_dir
gc.colony_recall_cache = recall_cache
gc.colony_reflect_llm = reflect_llm
gc.execution_id = execution_id
gc._colony_reflect_lock = asyncio.Lock()
node_spec = SimpleNamespace(id="test-node")
worker = WorkerAgent.__new__(WorkerAgent)
worker._gc = gc
worker.node_spec = node_spec
await worker._reflect_colony_memory()
assert (colony_dir / "user-prefers-terse-summaries.md").exists()
assert "Colony Memories" in recall_cache[execution_id]
assert "terse summaries" in recall_cache[execution_id]
@pytest.mark.asyncio
async def test_subscribe_worker_triggers_only_lifecycle_events(tmp_path: Path):
    """Simplified worker triggers subscribe only to start and terminal events."""
    memory_dir = tmp_path / "colony"
    memory_dir.mkdir()
    cache: dict[str, str] = {}
    event_bus = EventBus()
    fake_llm = AsyncMock()
    subscription_ids = await subscribe_worker_memory_triggers(
        event_bus,
        fake_llm,
        worker_sessions_dir=tmp_path / "sessions",
        colony_memory_dir=memory_dir,
        recall_cache=cache,
    )
    try:
        # Exactly two subscriptions remain: execution start + terminal.
        assert len(subscription_ids) == 2
        # The start event seeds an empty recall-cache entry for the execution.
        start_event = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="default",
            execution_id="exec-1",
        )
        await event_bus.publish(start_event)
        assert cache.get("exec-1") == ""
    finally:
        for subscription_id in subscription_ids:
            event_bus.unsubscribe(subscription_id)