feat: escalation implementation

This commit is contained in:
Richard Tang
2026-03-04 19:59:02 -08:00
parent b9a3c67fea
commit 6ade844722
7 changed files with 424 additions and 86 deletions
+150 -1
View File
@@ -137,7 +137,16 @@ def memory():
return SharedMemory()
def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal_context=""):
def build_ctx(
runtime,
node_spec,
memory,
llm,
tools=None,
input_data=None,
goal_context="",
stream_id=None,
):
"""Build a NodeContext for testing."""
return NodeContext(
runtime=runtime,
@@ -148,6 +157,7 @@ def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal
llm=llm,
available_tools=tools or [],
goal_context=goal_context,
stream_id=stream_id,
)
@@ -708,6 +718,80 @@ class TestClientFacingBlocking:
tool_names = [t.name for t in (call["tools"] or [])]
assert "ask_user" not in tool_names
@pytest.mark.asyncio
async def test_escalate_to_coder_available_for_worker_stream(self, runtime, memory):
    """Workers should receive escalate_to_coder synthetic tool."""
    worker_spec = NodeSpec(
        id="internal",
        name="Internal",
        description="internal node",
        node_type="event_loop",
        output_keys=[],
    )
    mock_llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
    loop_node = EventLoopNode(config=LoopConfig(max_iterations=2))
    context = build_ctx(runtime, worker_spec, memory, mock_llm, stream_id="worker")

    await loop_node.execute(context)

    # The LLM must have been called at least once before we inspect
    # which tools were offered on the first call.
    assert mock_llm._call_index >= 1
    offered_names = [tool.name for tool in (mock_llm.stream_calls[0]["tools"] or [])]
    assert "escalate_to_coder" in offered_names
@pytest.mark.asyncio
async def test_escalate_to_coder_not_available_for_queen_stream(self, runtime, memory):
    """Queen stream should not receive escalate_to_coder tool."""
    queen_spec = NodeSpec(
        id="queen",
        name="Queen",
        description="queen node",
        node_type="event_loop",
        output_keys=[],
    )
    mock_llm = MockStreamingLLM(scenarios=[text_scenario("monitoring...")])
    loop_node = EventLoopNode(config=LoopConfig(max_iterations=2))
    context = build_ctx(runtime, queen_spec, memory, mock_llm, stream_id="queen")

    await loop_node.execute(context)

    # Inspect the tool list offered on the very first LLM call: the
    # escalation tool must be absent for the queen stream.
    assert mock_llm._call_index >= 1
    offered_names = [tool.name for tool in (mock_llm.stream_calls[0]["tools"] or [])]
    assert "escalate_to_coder" not in offered_names
class TestEscalateToCoder:
    """Behavior of the escalate_to_coder synthetic tool."""

    @pytest.mark.asyncio
    async def test_escalate_to_coder_emits_event(self, runtime, node_spec, memory):
        """escalate_to_coder() should publish ESCALATION_REQUESTED."""
        node_spec.output_keys = []
        # Scripted LLM: one escalation tool call, then a closing text turn.
        scripted_llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario(
                    "escalate_to_coder",
                    {"reason": "tool failure", "context": "HTTP 401 from upstream"},
                    tool_use_id="escalate_1",
                ),
                text_scenario("Escalated to queen."),
            ]
        )

        event_bus = EventBus()
        captured: list = []

        async def on_escalation(event):
            captured.append(event)

        event_bus.subscribe(
            event_types=[EventType.ESCALATION_REQUESTED], handler=on_escalation
        )

        context = build_ctx(runtime, node_spec, memory, scripted_llm, stream_id="worker")
        loop_node = EventLoopNode(event_bus=event_bus, config=LoopConfig(max_iterations=5))
        outcome = await loop_node.execute(context)

        assert outcome.success is True
        # Exactly one escalation event, carrying the reason/context the
        # LLM supplied in the tool call arguments.
        assert len(captured) == 1
        escalation = captured[0]
        assert escalation.type == EventType.ESCALATION_REQUESTED
        assert escalation.data["reason"] == "tool failure"
        assert "HTTP 401" in escalation.data["context"]
# ===========================================================================
# Client-facing: _cf_expecting_work state machine
@@ -1765,6 +1849,71 @@ class TestToolDoomLoopIntegration:
assert len(doom_events) == 1
assert "search" in doom_events[0].data["description"]
@pytest.mark.asyncio
async def test_client_facing_worker_doom_loop_escalates_to_queen(
    self,
    runtime,
    memory,
):
    """Client-facing worker doom loops should escalate instead of blocking for user input."""
    worker_spec = NodeSpec(
        id="worker",
        name="Worker",
        description="worker node",
        node_type="event_loop",
        output_keys=[],
        client_facing=True,
    )

    # Judge rejects the first three evaluations so the loop keeps
    # iterating long enough for the doom-loop detector to trip.
    judge = AsyncMock(spec=JudgeProtocol)
    evaluations = 0

    async def judge_eval(*args, **kwargs):
        nonlocal evaluations
        evaluations += 1
        return JudgeVerdict(action="ACCEPT" if evaluations >= 4 else "RETRY")

    judge.evaluate = judge_eval

    # LLM that repeats the identical "search" tool call every turn.
    repeating_llm = ToolRepeatLLM("search", {"q": "hello"}, tool_turns=3)

    event_bus = EventBus()
    escalations: list = []
    # NOTE(review): a plain callable handler is used here, while other tests
    # register async handlers — assumes EventBus accepts both; confirm.
    event_bus.subscribe(
        event_types=[EventType.ESCALATION_REQUESTED],
        handler=escalations.append,
    )

    def run_tool(tool_use: ToolUse) -> ToolResult:
        # Every tool call succeeds with the same content.
        return ToolResult(tool_use_id=tool_use.id, content="result", is_error=False)

    context = build_ctx(
        runtime,
        worker_spec,
        memory,
        repeating_llm,
        tools=[Tool(name="search", description="s", parameters={})],
        stream_id="worker",
    )
    loop_node = EventLoopNode(
        judge=judge,
        tool_executor=run_tool,
        event_bus=event_bus,
        config=LoopConfig(max_iterations=10, tool_doom_loop_threshold=3),
    )
    outcome = await loop_node.execute(context)

    assert outcome.success is True
    assert len(escalations) >= 1
    assert escalations[0].data["reason"] == "Tool doom loop detected"
@pytest.mark.asyncio
async def test_doom_loop_disabled(
self,