feat: escalation implementation
This commit is contained in:
@@ -137,7 +137,16 @@ def memory():
|
||||
return SharedMemory()
|
||||
|
||||
|
||||
def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal_context=""):
|
||||
def build_ctx(
|
||||
runtime,
|
||||
node_spec,
|
||||
memory,
|
||||
llm,
|
||||
tools=None,
|
||||
input_data=None,
|
||||
goal_context="",
|
||||
stream_id=None,
|
||||
):
|
||||
"""Build a NodeContext for testing."""
|
||||
return NodeContext(
|
||||
runtime=runtime,
|
||||
@@ -148,6 +157,7 @@ def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal
|
||||
llm=llm,
|
||||
available_tools=tools or [],
|
||||
goal_context=goal_context,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -708,6 +718,80 @@ class TestClientFacingBlocking:
|
||||
tool_names = [t.name for t in (call["tools"] or [])]
|
||||
assert "ask_user" not in tool_names
|
||||
|
||||
@pytest.mark.asyncio
async def test_escalate_to_coder_available_for_worker_stream(self, runtime, memory):
    """Workers should receive escalate_to_coder synthetic tool."""
    worker_spec = NodeSpec(
        id="internal",
        name="Internal",
        description="internal node",
        node_type="event_loop",
        output_keys=[],
    )
    mock_llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
    loop_node = EventLoopNode(config=LoopConfig(max_iterations=2))
    context = build_ctx(runtime, worker_spec, memory, mock_llm, stream_id="worker")

    await loop_node.execute(context)

    # The first streaming call should already carry the synthetic tool.
    assert mock_llm._call_index >= 1
    first_call = mock_llm.stream_calls[0]
    offered_names = [tool.name for tool in (first_call["tools"] or [])]
    assert "escalate_to_coder" in offered_names
|
||||
|
||||
@pytest.mark.asyncio
async def test_escalate_to_coder_not_available_for_queen_stream(self, runtime, memory):
    """Queen stream should not receive escalate_to_coder tool."""
    queen_spec = NodeSpec(
        id="queen",
        name="Queen",
        description="queen node",
        node_type="event_loop",
        output_keys=[],
    )
    mock_llm = MockStreamingLLM(scenarios=[text_scenario("monitoring...")])
    loop_node = EventLoopNode(config=LoopConfig(max_iterations=2))
    context = build_ctx(runtime, queen_spec, memory, mock_llm, stream_id="queen")

    await loop_node.execute(context)

    # Escalation is worker-only; the queen's first call must not offer it.
    assert mock_llm._call_index >= 1
    first_call = mock_llm.stream_calls[0]
    offered_names = [tool.name for tool in (first_call["tools"] or [])]
    assert "escalate_to_coder" not in offered_names
|
||||
|
||||
|
||||
class TestEscalateToCoder:
    @pytest.mark.asyncio
    async def test_escalate_to_coder_emits_event(self, runtime, node_spec, memory):
        """escalate_to_coder() should publish ESCALATION_REQUESTED."""
        node_spec.output_keys = []
        scripted_llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: the model escalates; turn 2: it wraps up with text.
                tool_call_scenario(
                    "escalate_to_coder",
                    {"reason": "tool failure", "context": "HTTP 401 from upstream"},
                    tool_use_id="escalate_1",
                ),
                text_scenario("Escalated to queen."),
            ]
        )
        event_bus = EventBus()
        captured: list = []

        async def on_escalation(event):
            captured.append(event)

        event_bus.subscribe(
            event_types=[EventType.ESCALATION_REQUESTED], handler=on_escalation
        )

        context = build_ctx(runtime, node_spec, memory, scripted_llm, stream_id="worker")
        loop_node = EventLoopNode(event_bus=event_bus, config=LoopConfig(max_iterations=5))
        outcome = await loop_node.execute(context)

        assert outcome.success is True
        # Exactly one escalation event, carrying the tool-call arguments through.
        assert len(captured) == 1
        escalation = captured[0]
        assert escalation.type == EventType.ESCALATION_REQUESTED
        assert escalation.data["reason"] == "tool failure"
        assert "HTTP 401" in escalation.data["context"]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Client-facing: _cf_expecting_work state machine
|
||||
@@ -1765,6 +1849,71 @@ class TestToolDoomLoopIntegration:
|
||||
assert len(doom_events) == 1
|
||||
assert "search" in doom_events[0].data["description"]
|
||||
|
||||
@pytest.mark.asyncio
async def test_client_facing_worker_doom_loop_escalates_to_queen(
    self,
    runtime,
    memory,
):
    """Client-facing worker doom loops should escalate instead of blocking for user input."""
    worker_spec = NodeSpec(
        id="worker",
        name="Worker",
        description="worker node",
        node_type="event_loop",
        output_keys=[],
        client_facing=True,
    )

    # Judge rejects the first three attempts so the repeated tool call
    # crosses the doom-loop threshold, then accepts to end the loop.
    judge = AsyncMock(spec=JudgeProtocol)
    attempts = 0

    async def judge_eval(*args, **kwargs):
        nonlocal attempts
        attempts += 1
        return JudgeVerdict(action="ACCEPT" if attempts >= 4 else "RETRY")

    judge.evaluate = judge_eval

    repeat_llm = ToolRepeatLLM("search", {"q": "hello"}, tool_turns=3)
    event_bus = EventBus()
    escalations: list = []
    event_bus.subscribe(
        event_types=[EventType.ESCALATION_REQUESTED],
        handler=lambda event: escalations.append(event),
    )

    def tool_exec(tool_use: ToolUse) -> ToolResult:
        # Every execution succeeds with the same payload, feeding the loop.
        return ToolResult(tool_use_id=tool_use.id, content="result", is_error=False)

    context = build_ctx(
        runtime,
        worker_spec,
        memory,
        repeat_llm,
        tools=[Tool(name="search", description="s", parameters={})],
        stream_id="worker",
    )
    loop_node = EventLoopNode(
        judge=judge,
        tool_executor=tool_exec,
        event_bus=event_bus,
        config=LoopConfig(max_iterations=10, tool_doom_loop_threshold=3),
    )
    outcome = await loop_node.execute(context)

    assert outcome.success is True
    # A client-facing worker must escalate rather than block on ask_user.
    assert len(escalations) >= 1
    assert escalations[0].data["reason"] == "Tool doom loop detected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doom_loop_disabled(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user