feat: simplify worker reflection
This commit is contained in:
@@ -623,12 +623,16 @@ async def subscribe_worker_memory_triggers(
|
||||
colony_memory_dir: Path,
|
||||
recall_cache: dict[str, str],
|
||||
) -> list[str]:
|
||||
"""Subscribe shared colony memory reflection/recall for top-level worker runs."""
|
||||
from framework.agents.queen.recall_selector import update_recall_cache
|
||||
"""Subscribe colony memory lifecycle events for worker runs.
|
||||
|
||||
Short reflection is now handled synchronously at node handoff in
|
||||
``WorkerAgent._reflect_colony_memory()``. This function only manages:
|
||||
- Recall cache initialisation on execution start
|
||||
- Final long reflection + cleanup on execution end
|
||||
"""
|
||||
from framework.runtime.event_bus import EventType
|
||||
|
||||
_lock = asyncio.Lock()
|
||||
_short_counts: dict[str, int] = {}
|
||||
_terminal_lock = asyncio.Lock()
|
||||
|
||||
def _is_worker_event(event: Any) -> bool:
|
||||
return bool(
|
||||
@@ -636,57 +640,6 @@ async def subscribe_worker_memory_triggers(
|
||||
and getattr(event, "stream_id", None) not in ("queen", "judge")
|
||||
)
|
||||
|
||||
async def _update_cache(execution_id: str) -> None:
    """Refresh the colony recall block cached for ``execution_id``.

    Calls ``update_recall_cache`` on the execution's session directory and
    stores the resulting block in ``recall_cache`` under the execution id.
    """

    def _store(block, execution_id=execution_id):
        # The id is bound as a default so the setter keeps this call's key.
        recall_cache[execution_id] = block

    await update_recall_cache(
        worker_sessions_dir / execution_id,
        llm,
        memory_dir=colony_memory_dir,
        cache_setter=_store,
        heading="Colony Memories",
    )
|
||||
|
||||
async def _on_turn_complete(event: Any) -> None:
    """Run a short colony reflection after a worker LLM turn.

    Events that are not worker events, or that carry no execution id, are
    ignored. If the lock is already held the turn is skipped, not queued.
    Every ``_LONG_REFLECT_INTERVAL``-th short reflection of an execution is
    followed by a long reflection, and the recall cache is refreshed after.
    Any failure is logged and recorded through ``_write_error``.
    """
    if not _is_worker_event(event):
        return
    if _lock.locked():
        logger.debug("reflect: worker colony reflection skipped — lock busy")
        return

    run_id = event.execution_id
    if run_id is None:
        return
    transcript_dir = worker_sessions_dir / run_id

    async with _lock:
        try:
            _short_counts[run_id] = 1 + _short_counts.get(run_id, 0)
            await run_short_reflection(transcript_dir, llm, colony_memory_dir)
            # Re-read the counter after the await, exactly as the check needs it.
            long_reflection_due = _short_counts[run_id] % _LONG_REFLECT_INTERVAL == 0
            if long_reflection_due:
                await run_long_reflection(llm, colony_memory_dir)
            await _update_cache(run_id)
        except Exception:
            logger.warning("reflect: worker colony reflection failed", exc_info=True)
            _write_error("worker colony reflection")
|
||||
|
||||
async def _on_compaction(event: Any) -> None:
    """Run a long colony reflection when a worker's context is compacted.

    Skips silently for non-worker events, while the lock is held, or when
    the event has no execution id. After the reflection the recall cache for
    the execution is refreshed. Any failure is logged and recorded through
    ``_write_error``.
    """
    if not _is_worker_event(event) or _lock.locked():
        return

    run_id = event.execution_id
    if run_id is None:
        return

    async with _lock:
        try:
            await run_long_reflection(llm, colony_memory_dir)
            await _update_cache(run_id)
        except Exception:
            logger.warning("reflect: worker compaction reflection failed", exc_info=True)
            _write_error("worker compaction reflection")
|
||||
|
||||
async def _on_execution_started(event: Any) -> None:
|
||||
if not _is_worker_event(event):
|
||||
return
|
||||
@@ -699,7 +652,7 @@ async def subscribe_worker_memory_triggers(
|
||||
execution_id = event.execution_id
|
||||
if execution_id is None:
|
||||
return
|
||||
async with _lock:
|
||||
async with _terminal_lock:
|
||||
try:
|
||||
await run_long_reflection(llm, colony_memory_dir)
|
||||
except Exception:
|
||||
@@ -707,21 +660,12 @@ async def subscribe_worker_memory_triggers(
|
||||
_write_error("worker final reflection")
|
||||
finally:
|
||||
recall_cache.pop(execution_id, None)
|
||||
_short_counts.pop(execution_id, None)
|
||||
|
||||
return [
|
||||
event_bus.subscribe(
|
||||
event_types=[EventType.EXECUTION_STARTED],
|
||||
handler=_on_execution_started,
|
||||
),
|
||||
event_bus.subscribe(
|
||||
event_types=[EventType.LLM_TURN_COMPLETE],
|
||||
handler=_on_turn_complete,
|
||||
),
|
||||
event_bus.subscribe(
|
||||
event_types=[EventType.CONTEXT_COMPACTED],
|
||||
handler=_on_compaction,
|
||||
),
|
||||
event_bus.subscribe(
|
||||
event_types=[EventType.EXECUTION_COMPLETED, EventType.EXECUTION_FAILED],
|
||||
handler=_on_execution_terminal,
|
||||
|
||||
@@ -67,6 +67,12 @@ class GraphContext:
|
||||
# Retry tracking: worker_id → retry_count (for execution quality assessment)
|
||||
retry_counts: dict[str, int] = field(default_factory=dict)
|
||||
nodes_with_retries: set[str] = field(default_factory=set)
|
||||
# Colony memory reflection at node handoff
|
||||
colony_memory_dir: Any = None # Path | None
|
||||
worker_sessions_dir: Any = None # Path | None
|
||||
colony_recall_cache: dict[str, str] = field(default_factory=dict)
|
||||
colony_reflect_llm: Any = None # LLMProvider for reflection
|
||||
_colony_reflect_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
|
||||
|
||||
def build_scoped_buffer(buffer: DataBuffer, node_spec: NodeSpec) -> DataBuffer:
|
||||
|
||||
@@ -160,6 +160,10 @@ class GraphExecutor:
|
||||
skill_dirs: list[str] | None = None,
|
||||
context_warn_ratio: float | None = None,
|
||||
batch_init_nudge: str | None = None,
|
||||
colony_memory_dir: Any = None,
|
||||
colony_worker_sessions_dir: Any = None,
|
||||
colony_recall_cache: dict[str, str] | None = None,
|
||||
colony_reflect_llm: Any = None,
|
||||
):
|
||||
"""
|
||||
Initialize the executor.
|
||||
@@ -221,6 +225,10 @@ class GraphExecutor:
|
||||
self.skill_dirs: list[str] = skill_dirs or []
|
||||
self.context_warn_ratio: float | None = context_warn_ratio
|
||||
self.batch_init_nudge: str | None = batch_init_nudge
|
||||
self.colony_memory_dir = colony_memory_dir
|
||||
self.colony_worker_sessions_dir = colony_worker_sessions_dir
|
||||
# Keep the caller's dict even when it is empty: ``or {}`` would replace an
# empty shared cache with a private one, and recall blocks written by workers
# would never reach the dict the caller still holds.
self.colony_recall_cache = colony_recall_cache if colony_recall_cache is not None else {}
|
||||
self.colony_reflect_llm = colony_reflect_llm
|
||||
|
||||
if protocols_prompt:
|
||||
self.logger.info(
|
||||
@@ -1318,6 +1326,10 @@ class GraphExecutor:
|
||||
iteration_metadata_provider=self.iteration_metadata_provider,
|
||||
loop_config=self._loop_config,
|
||||
node_visit_counts=dict(node_visit_counts),
|
||||
colony_memory_dir=self.colony_memory_dir,
|
||||
worker_sessions_dir=self.colony_worker_sessions_dir,
|
||||
colony_recall_cache=self.colony_recall_cache,
|
||||
colony_reflect_llm=self.colony_reflect_llm,
|
||||
)
|
||||
|
||||
# Create one WorkerAgent per node
|
||||
|
||||
@@ -318,6 +318,8 @@ class WorkerAgent:
|
||||
self.lifecycle = WorkerLifecycle.COMPLETED
|
||||
self._last_result = result
|
||||
self._last_activations = activations
|
||||
# Colony memory reflection — runs before downstream activation
|
||||
await self._reflect_colony_memory()
|
||||
completion = WorkerCompletion(
|
||||
worker_id=node_spec.id,
|
||||
success=True,
|
||||
@@ -338,6 +340,8 @@ class WorkerAgent:
|
||||
self.lifecycle = WorkerLifecycle.FAILED
|
||||
self._last_result = result
|
||||
self._last_activations = activations
|
||||
# Colony memory reflection — capture learnings even on failure
|
||||
await self._reflect_colony_memory()
|
||||
await self._publish_failure(result.error or "Unknown error")
|
||||
except Exception as exc:
|
||||
error = str(exc) or type(exc).__name__
|
||||
@@ -649,6 +653,58 @@ class WorkerAgent:
|
||||
pause_event=self._pause_requested,
|
||||
)
|
||||
|
||||
async def _reflect_colony_memory(self) -> None:
    """Run colony memory reflection at node handoff.

    Runs a short reflection over this execution's session directory, then
    refreshes the execution's entry in the colony recall cache.

    Does nothing when the graph context has no colony memory directory,
    reflection LLM, or worker sessions directory, when it carries no
    execution id, or when the session directory does not exist.

    The short reflection awaits the shared colony lock so parallel workers
    queue (never skip). A failure inside either step is logged as a warning
    and is not raised to the caller.
    """
    gc = self._gc
    memory_dir = gc.colony_memory_dir
    reflect_llm = gc.colony_reflect_llm
    sessions_root = gc.worker_sessions_dir
    if memory_dir is None or reflect_llm is None or sessions_root is None:
        return

    # Read the id once so the session path and the cache key cannot diverge.
    execution_id = gc.execution_id
    if not execution_id:
        # Joining a path with ``None`` raises TypeError before any of the
        # try blocks below, which would surface as a worker error. An empty
        # id would resolve to the shared sessions root rather than to one
        # execution's session.
        return

    from pathlib import Path

    session_dir = Path(sessions_root) / execution_id
    if not session_dir.exists():
        return

    # Await lock — serializes reflection but never skips
    async with gc._colony_reflect_lock:
        try:
            from framework.agents.queen.reflection_agent import run_short_reflection

            await run_short_reflection(session_dir, reflect_llm, memory_dir)
        except Exception:
            logger.warning(
                "Worker %s: colony reflection failed",
                self.node_spec.id,
                exc_info=True,
            )

    def _store_block(block):
        gc.colony_recall_cache[execution_id] = block

    # Update recall cache outside lock (per-execution key, no write races)
    try:
        from framework.agents.queen.recall_selector import update_recall_cache

        await update_recall_cache(
            session_dir,
            reflect_llm,
            memory_dir=memory_dir,
            cache_setter=_store_block,
            heading="Colony Memories",
        )
    except Exception:
        logger.warning(
            "Worker %s: recall cache update failed",
            self.node_spec.id,
            exc_info=True,
        )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event publishing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -239,6 +239,11 @@ class AgentRuntime:
|
||||
self._tool_executor = tool_executor
|
||||
self._accounts_prompt = accounts_prompt
|
||||
self._dynamic_memory_provider_factory: Callable[[str], Callable[[], str] | None] | None = None
|
||||
# Colony memory config for reflection-at-handoff (set by session_manager)
|
||||
self._colony_memory_dir: Any = None
|
||||
self._colony_worker_sessions_dir: Any = None
|
||||
self._colony_recall_cache: dict[str, str] | None = None
|
||||
self._colony_reflect_llm: Any = None
|
||||
self._accounts_data = accounts_data
|
||||
self._tool_provider_map = tool_provider_map
|
||||
|
||||
@@ -362,6 +367,10 @@ class AgentRuntime:
|
||||
context_warn_ratio=self.context_warn_ratio,
|
||||
batch_init_nudge=self.batch_init_nudge,
|
||||
dynamic_memory_provider_factory=self._dynamic_memory_provider_factory,
|
||||
colony_memory_dir=self._colony_memory_dir,
|
||||
colony_worker_sessions_dir=self._colony_worker_sessions_dir,
|
||||
colony_recall_cache=self._colony_recall_cache,
|
||||
colony_reflect_llm=self._colony_reflect_llm,
|
||||
)
|
||||
await stream.start()
|
||||
self._streams[ep_id] = stream
|
||||
|
||||
@@ -192,6 +192,10 @@ class ExecutionStream:
|
||||
context_warn_ratio: float | None = None,
|
||||
batch_init_nudge: str | None = None,
|
||||
dynamic_memory_provider_factory: Callable[[str], Callable[[], str] | None] | None = None,
|
||||
colony_memory_dir: Any = None,
|
||||
colony_worker_sessions_dir: Any = None,
|
||||
colony_recall_cache: dict[str, str] | None = None,
|
||||
colony_reflect_llm: Any = None,
|
||||
):
|
||||
"""
|
||||
Initialize execution stream.
|
||||
@@ -247,6 +251,10 @@ class ExecutionStream:
|
||||
self._context_warn_ratio: float | None = context_warn_ratio
|
||||
self._batch_init_nudge: str | None = batch_init_nudge
|
||||
self._dynamic_memory_provider_factory = dynamic_memory_provider_factory
|
||||
self._colony_memory_dir = colony_memory_dir
|
||||
self._colony_worker_sessions_dir = colony_worker_sessions_dir
|
||||
self._colony_recall_cache = colony_recall_cache
|
||||
self._colony_reflect_llm = colony_reflect_llm
|
||||
|
||||
_es_logger = logging.getLogger(__name__)
|
||||
if protocols_prompt:
|
||||
@@ -727,6 +735,10 @@ class ExecutionStream:
|
||||
if self._dynamic_memory_provider_factory is not None
|
||||
else None
|
||||
),
|
||||
colony_memory_dir=self._colony_memory_dir,
|
||||
colony_worker_sessions_dir=self._colony_worker_sessions_dir,
|
||||
colony_recall_cache=self._colony_recall_cache,
|
||||
colony_reflect_llm=self._colony_reflect_llm,
|
||||
)
|
||||
# Track executor so inject_input() can reach EventLoopNode instances
|
||||
self._active_executors[execution_id] = executor
|
||||
|
||||
@@ -789,6 +789,12 @@ class SessionManager:
|
||||
)
|
||||
)
|
||||
|
||||
# Colony memory config for reflection-at-handoff
|
||||
runtime._colony_memory_dir = colony_dir
|
||||
runtime._colony_worker_sessions_dir = worker_sessions_dir
|
||||
runtime._colony_recall_cache = session.worker_colony_recall_blocks
|
||||
runtime._colony_reflect_llm = session.llm
|
||||
|
||||
session.worker_memory_subs = await subscribe_worker_memory_triggers(
|
||||
session.event_bus,
|
||||
session.llm,
|
||||
|
||||
@@ -540,7 +540,13 @@ def test_queen_phase_state_appends_colony_and_global_memory_blocks():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_colony_reflection_updates_shared_memory_and_cache(tmp_path: Path):
|
||||
async def test_worker_colony_reflection_at_handoff(tmp_path: Path):
|
||||
"""Colony reflection runs via WorkerAgent._reflect_colony_memory at node handoff."""
|
||||
import asyncio
|
||||
|
||||
from framework.graph.context import GraphContext
|
||||
from framework.graph.worker_agent import WorkerAgent
|
||||
|
||||
worker_sessions_dir = tmp_path / "worker-sessions"
|
||||
execution_id = "exec-1"
|
||||
session_dir = worker_sessions_dir / execution_id / "conversations" / "parts"
|
||||
@@ -556,11 +562,11 @@ async def test_worker_colony_reflection_updates_shared_memory_and_cache(tmp_path
|
||||
|
||||
colony_dir = tmp_path / "colony"
|
||||
colony_dir.mkdir()
|
||||
recall_cache: dict[str, str] = {}
|
||||
bus = EventBus()
|
||||
recall_cache: dict[str, str] = {execution_id: ""}
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.acomplete.side_effect = [
|
||||
reflect_llm = AsyncMock()
|
||||
reflect_llm.acomplete.side_effect = [
|
||||
# Short reflection: write a memory file
|
||||
MagicMock(
|
||||
content="",
|
||||
raw_response={
|
||||
@@ -583,38 +589,64 @@ async def test_worker_colony_reflection_updates_shared_memory_and_cache(tmp_path
|
||||
]
|
||||
},
|
||||
),
|
||||
# Short reflection done
|
||||
MagicMock(content="done", raw_response={}),
|
||||
# Recall selector picks the new memory
|
||||
MagicMock(content=json.dumps({"selected_memories": ["user-prefers-terse-summaries.md"]})),
|
||||
]
|
||||
|
||||
subs = await subscribe_worker_memory_triggers(
|
||||
bus,
|
||||
llm,
|
||||
worker_sessions_dir=worker_sessions_dir,
|
||||
colony_memory_dir=colony_dir,
|
||||
recall_cache=recall_cache,
|
||||
)
|
||||
try:
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id="default",
|
||||
execution_id=execution_id,
|
||||
)
|
||||
)
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TURN_COMPLETE,
|
||||
stream_id="default",
|
||||
execution_id=execution_id,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
for sub_id in subs:
|
||||
bus.unsubscribe(sub_id)
|
||||
# Build a minimal GraphContext with colony memory fields
|
||||
gc = MagicMock(spec=GraphContext)
|
||||
gc.colony_memory_dir = colony_dir
|
||||
gc.worker_sessions_dir = worker_sessions_dir
|
||||
gc.colony_recall_cache = recall_cache
|
||||
gc.colony_reflect_llm = reflect_llm
|
||||
gc.execution_id = execution_id
|
||||
gc._colony_reflect_lock = asyncio.Lock()
|
||||
|
||||
node_spec = SimpleNamespace(id="test-node")
|
||||
worker = WorkerAgent.__new__(WorkerAgent)
|
||||
worker._gc = gc
|
||||
worker.node_spec = node_spec
|
||||
|
||||
await worker._reflect_colony_memory()
|
||||
|
||||
assert (colony_dir / "user-prefers-terse-summaries.md").exists()
|
||||
assert "Colony Memories" in recall_cache[execution_id]
|
||||
assert "terse summaries" in recall_cache[execution_id]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subscribe_worker_triggers_only_lifecycle_events(tmp_path: Path):
    """Worker memory triggers register two handlers, and a start event seeds the cache."""
    memory_dir = tmp_path / "colony"
    memory_dir.mkdir()
    cache: dict[str, str] = {}
    event_bus = EventBus()
    reflect_llm = AsyncMock()

    subscription_ids = await subscribe_worker_memory_triggers(
        event_bus,
        reflect_llm,
        worker_sessions_dir=tmp_path / "sessions",
        colony_memory_dir=memory_dir,
        recall_cache=cache,
    )
    try:
        # One subscription for execution start, one for the terminal events.
        assert len(subscription_ids) == 2

        started = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="default",
            execution_id="exec-1",
        )
        await event_bus.publish(started)
        # Handling the start event leaves an empty recall block for the execution.
        assert cache.get("exec-1") == ""
    finally:
        for subscription_id in subscription_ids:
            event_bus.unsubscribe(subscription_id)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user