feat: allow worker to wait for queen input before judge evaluation

This commit is contained in:
Richard Tang
2026-03-05 12:33:27 -08:00
parent 8c9892f9f6
commit 44d609b719
3 changed files with 390 additions and 22 deletions
+158 -21
View File
@@ -626,6 +626,7 @@ class EventLoopNode(NodeProtocol):
user_input_requested,
ask_user_prompt,
ask_user_options,
queen_input_requested,
request_system_prompt,
request_messages,
) = await self._run_single_turn(
@@ -825,6 +826,7 @@ class EventLoopNode(NodeProtocol):
and not real_tool_results
and not outputs_set
and not user_input_requested
and not queen_input_requested
)
if truly_empty and accumulator is not None:
missing = self._get_missing_output_keys(
@@ -1258,25 +1260,105 @@ class EventLoopNode(NodeProtocol):
if not _outputs_complete:
_cf_text_only_streak = 0
_continue_count += 1
if ctx.runtime_logger:
iter_latency_ms = int((time.time() - iter_start) * 1000)
ctx.runtime_logger.log_step(
node_id=node_id,
node_type="event_loop",
step_index=iteration,
verdict="CONTINUE",
verdict_feedback=("Blocked for ask_user input (skip judge)"),
tool_calls=logged_tool_calls,
llm_text=assistant_text,
input_tokens=turn_tokens.get("input", 0),
output_tokens=turn_tokens.get("output", 0),
latency_ms=iter_latency_ms,
)
self._log_skip_judge(
ctx, node_id, iteration,
"Blocked for ask_user input (skip judge)",
logged_tool_calls, assistant_text, turn_tokens, iter_start,
)
continue
# All outputs set -- fall through to judge
# Auto-block beyond grace -- fall through to judge (6i)
# 6h''. Worker wait for queen guidance
# If a worker escalates with wait_for_response=true, pause here and
# skip judge evaluation until queen injects guidance.
if queen_input_requested:
if self._shutdown:
await self._publish_loop_completed(
stream_id, node_id, iteration + 1, execution_id
)
latency_ms = int((time.time() - start_time) * 1000)
_continue_count += 1
self._log_skip_judge(
ctx, node_id, iteration,
"Shutdown signaled (waiting for queen input)",
logged_tool_calls, assistant_text, turn_tokens, iter_start,
)
if ctx.runtime_logger:
ctx.runtime_logger.log_node_complete(
node_id=node_id,
node_name=ctx.node_spec.name,
node_type="event_loop",
success=True,
total_steps=iteration + 1,
tokens_used=total_input_tokens + total_output_tokens,
input_tokens=total_input_tokens,
output_tokens=total_output_tokens,
latency_ms=latency_ms,
exit_status="success",
accept_count=_accept_count,
retry_count=_retry_count,
escalate_count=_escalate_count,
continue_count=_continue_count,
)
return NodeResult(
success=True,
output=accumulator.to_dict(),
tokens_used=total_input_tokens + total_output_tokens,
latency_ms=latency_ms,
conversation=conversation if _is_continuous else None,
)
logger.info("[%s] iter=%d: waiting for queen input...", node_id, iteration)
got_input = await self._await_user_input(ctx, prompt="", emit_client_request=False)
logger.info("[%s] iter=%d: queen wait unblocked, got_input=%s", node_id, iteration, got_input)
if not got_input:
await self._publish_loop_completed(
stream_id, node_id, iteration + 1, execution_id
)
latency_ms = int((time.time() - start_time) * 1000)
_continue_count += 1
self._log_skip_judge(
ctx, node_id, iteration,
"No queen input received (shutdown during wait)",
logged_tool_calls, assistant_text, turn_tokens, iter_start,
)
if ctx.runtime_logger:
ctx.runtime_logger.log_node_complete(
node_id=node_id,
node_name=ctx.node_spec.name,
node_type="event_loop",
success=True,
total_steps=iteration + 1,
tokens_used=total_input_tokens + total_output_tokens,
input_tokens=total_input_tokens,
output_tokens=total_output_tokens,
latency_ms=latency_ms,
exit_status="success",
accept_count=_accept_count,
retry_count=_retry_count,
escalate_count=_escalate_count,
continue_count=_continue_count,
)
return NodeResult(
success=True,
output=accumulator.to_dict(),
tokens_used=total_input_tokens + total_output_tokens,
latency_ms=latency_ms,
conversation=conversation if _is_continuous else None,
)
recent_responses.clear()
_cf_text_only_streak = 0
_continue_count += 1
self._log_skip_judge(
ctx, node_id, iteration,
"Blocked for queen input (skip judge)",
logged_tool_calls, assistant_text, turn_tokens, iter_start,
)
continue
# 6i. Judge evaluation
should_judge = (
ctx.is_subagent_mode # Always evaluate subagents
@@ -1557,6 +1639,7 @@ class EventLoopNode(NodeProtocol):
prompt: str = "",
*,
options: list[str] | None = None,
emit_client_request: bool = True,
) -> bool:
"""Block until user input arrives or shutdown is signaled.
@@ -1570,6 +1653,9 @@ class EventLoopNode(NodeProtocol):
options: Optional predefined choices for the user (from ask_user).
Passed through to the CLIENT_INPUT_REQUESTED event so the
frontend can render a QuestionWidget with buttons.
emit_client_request: When False, wait silently without publishing
CLIENT_INPUT_REQUESTED. Used for worker waits where input is
expected from the queen via inject_worker_message().
Returns True if input arrived, False if shutdown was signaled.
"""
@@ -1584,7 +1670,7 @@ class EventLoopNode(NodeProtocol):
# without injecting, so the wait still blocks until the user types.
self._input_ready.clear()
if self._event_bus:
if emit_client_request and self._event_bus:
await self._event_bus.emit_client_input_requested(
stream_id=ctx.stream_id or ctx.node_id,
node_id=ctx.node_id,
@@ -1620,13 +1706,15 @@ class EventLoopNode(NodeProtocol):
bool,
str,
list[str] | None,
bool,
str,
list[dict[str, Any]],
]:
"""Run a single LLM turn with streaming and tool execution.
Returns (assistant_text, real_tool_results, outputs_set, token_counts, logged_tool_calls,
user_input_requested, ask_user_prompt, ask_user_options, system_prompt, messages).
user_input_requested, ask_user_prompt, ask_user_options, queen_input_requested,
system_prompt, messages).
``real_tool_results`` contains only results from actual tools (web_search,
etc.), NOT from synthetic framework tools such as ``set_output``,
@@ -1635,6 +1723,9 @@ class EventLoopNode(NodeProtocol):
this turn. ``user_input_requested`` is True if the LLM called
``ask_user`` during this turn. This separation lets the caller treat
synthetic tools as framework concerns rather than tool-execution concerns.
``queen_input_requested`` is True when the worker called
``escalate_to_coder(wait_for_response=true)`` and should wait for
queen guidance before judge evaluation.
``logged_tool_calls`` accumulates ALL tool calls across inner iterations
(real tools, set_output, and discarded calls) for L3 logging. Unlike
@@ -1654,6 +1745,7 @@ class EventLoopNode(NodeProtocol):
user_input_requested = False
ask_user_prompt = ""
ask_user_options: list[str] | None = None
queen_input_requested = False
# Accumulate ALL tool calls across inner iterations for L3 logging.
# Unlike real_tool_results (reset each inner iteration), this persists.
logged_tool_calls: list[dict] = []
@@ -1804,6 +1896,7 @@ class EventLoopNode(NodeProtocol):
user_input_requested,
ask_user_prompt,
ask_user_options,
queen_input_requested,
final_system_prompt,
final_messages,
)
@@ -1953,6 +2046,7 @@ class EventLoopNode(NodeProtocol):
# --- Framework-level escalate_to_coder handling ---
reason = str(tc.tool_input.get("reason", "")).strip()
context = str(tc.tool_input.get("context", "")).strip()
wait_for_response = bool(tc.tool_input.get("wait_for_response", True))
if stream_id in ("queen", "judge"):
result = ToolResult(
@@ -1984,10 +2078,16 @@ class EventLoopNode(NodeProtocol):
context=context,
execution_id=execution_id,
)
if wait_for_response:
queen_input_requested = True
result = ToolResult(
tool_use_id=tc.tool_use_id,
content="Escalation requested to hive_coder (queen).",
content=(
"Escalation requested to hive_coder (queen); waiting for guidance."
if wait_for_response
else "Escalation requested to hive_coder (queen)."
),
is_error=False,
)
results_by_id[tc.tool_use_id] = result
@@ -2278,6 +2378,7 @@ class EventLoopNode(NodeProtocol):
user_input_requested,
ask_user_prompt,
ask_user_options,
queen_input_requested,
final_system_prompt,
final_messages,
)
@@ -2296,9 +2397,9 @@ class EventLoopNode(NodeProtocol):
conversation.usage_ratio() * 100,
)
# If ask_user was called, return immediately so the outer loop
# can block for user input instead of re-invoking the LLM.
if user_input_requested:
# If the turn requested external input (ask_user or queen handoff),
# return immediately so the outer loop can block before judge eval.
if user_input_requested or queen_input_requested:
return (
final_text,
real_tool_results,
@@ -2308,6 +2409,7 @@ class EventLoopNode(NodeProtocol):
user_input_requested,
ask_user_prompt,
ask_user_options,
queen_input_requested,
final_system_prompt,
final_messages,
)
@@ -2407,7 +2509,8 @@ class EventLoopNode(NodeProtocol):
description=(
"Escalate to the Hive Coder queen when blocked by errors, missing "
"credentials, or ambiguous constraints that require supervisor "
"guidance. Include a concise reason and optional context."
"guidance. Include a concise reason and optional context. Set "
"wait_for_response=true to pause until the queen injects guidance."
),
parameters={
"type": "object",
@@ -2422,6 +2525,14 @@ class EventLoopNode(NodeProtocol):
"type": "string",
"description": "Optional diagnostic details for the queen.",
},
"wait_for_response": {
"type": "boolean",
"description": (
"When true (default), block this node until queen guidance "
"arrives via injected input."
),
"default": True,
},
},
"required": ["reason"],
},
@@ -3739,6 +3850,32 @@ class EventLoopNode(NodeProtocol):
iteration=iteration,
)
def _log_skip_judge(
    self,
    ctx: NodeContext,
    node_id: str,
    iteration: int,
    feedback: str,
    tool_calls: list[dict],
    llm_text: str,
    turn_tokens: dict[str, int],
    iter_start: float,
) -> None:
    """Record a CONTINUE step that bypassed judge evaluation (e.g. blocked on input).

    No-op when the context carries no runtime logger.
    """
    runtime_logger = ctx.runtime_logger
    if not runtime_logger:
        return
    # Latency covers the whole iteration, measured from iter_start.
    elapsed_ms = int((time.time() - iter_start) * 1000)
    runtime_logger.log_step(
        node_id=node_id,
        node_type="event_loop",
        step_index=iteration,
        verdict="CONTINUE",
        verdict_feedback=feedback,
        tool_calls=tool_calls,
        llm_text=llm_text,
        input_tokens=turn_tokens.get("input", 0),
        output_tokens=turn_tokens.get("output", 0),
        latency_ms=elapsed_ms,
    )
async def _publish_loop_completed(
self, stream_id: str, node_id: str, iterations: int, execution_id: str = ""
) -> None:
+102 -1
View File
@@ -31,6 +31,7 @@ from framework.llm.stream_events import (
)
from framework.runtime.core import Runtime
from framework.runtime.event_bus import EventBus, EventType
from framework.server.session_manager import Session, SessionManager
from framework.storage.conversation_store import FileConversationStore
# ---------------------------------------------------------------------------
@@ -768,7 +769,11 @@ class TestEscalateToCoder:
scenarios=[
tool_call_scenario(
"escalate_to_coder",
{"reason": "tool failure", "context": "HTTP 401 from upstream"},
{
"reason": "tool failure",
"context": "HTTP 401 from upstream",
"wait_for_response": False,
},
tool_use_id="escalate_1",
),
text_scenario("Escalated to queen."),
@@ -792,6 +797,102 @@ class TestEscalateToCoder:
assert received[0].data["reason"] == "tool failure"
assert "HTTP 401" in received[0].data["context"]
@pytest.mark.asyncio
async def test_escalate_to_coder_handoff_reaches_queen(self, runtime, node_spec, memory):
    """Worker escalation should be routed to queen via SessionManager handoff sub."""
    node_spec.output_keys = []
    # wait_for_response=False: fire-and-forget escalation; the loop finishes normally.
    llm = MockStreamingLLM(
        scenarios=[
            tool_call_scenario(
                "escalate_to_coder",
                {
                    "reason": "blocked",
                    "context": "dependency missing",
                    "wait_for_response": False,
                },
                tool_use_id="escalate_1",
            ),
            text_scenario("Escalation sent."),
        ]
    )
    bus = EventBus()
    manager = SessionManager()
    session = Session(id="handoff_test", event_bus=bus, llm=object(), loaded_at=0.0)
    # Fake queen: only inject_event and node_registry are exercised by the handoff path.
    queen_node = MagicMock()
    queen_node.inject_event = AsyncMock()
    queen_executor = MagicMock()
    queen_executor.node_registry = {"queen": queen_node}
    manager._subscribe_worker_handoffs(session, queen_executor)
    ctx = build_ctx(runtime, node_spec, memory, llm, stream_id="worker")
    node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
    result = await node.execute(ctx)
    assert result.success is True
    # The escalation arrives at the queen as a formatted, non-client-input injection.
    queen_node.inject_event.assert_awaited_once()
    injected = queen_node.inject_event.await_args.args[0]
    kwargs = queen_node.inject_event.await_args.kwargs
    assert "[WORKER_ESCALATION_REQUEST]" in injected
    assert "stream_id: worker" in injected
    assert "node_id: test_loop" in injected
    assert "reason: blocked" in injected
    assert "dependency missing" in injected
    assert kwargs["is_client_input"] is False
@pytest.mark.asyncio
async def test_escalate_waits_for_queen_input_and_skips_judge(
    self, runtime, node_spec, memory
):
    """wait_for_response=true should block for queen input before judge evaluation."""
    node_spec.output_keys = ["result"]
    llm = MockStreamingLLM(
        scenarios=[
            # Turn 1: escalate with wait_for_response=True -> node pauses for the queen.
            tool_call_scenario(
                "escalate_to_coder",
                {
                    "reason": "need direction",
                    "context": "conflicting constraints",
                    "wait_for_response": True,
                },
                tool_use_id="escalate_1",
            ),
            # Turn 2 (after guidance is injected): produce the required output.
            tool_call_scenario(
                "set_output",
                {"key": "result", "value": "resolved after queen guidance"},
                tool_use_id="set_1",
            ),
            text_scenario("Completed."),
        ]
    )
    bus = EventBus()
    # Capture CLIENT_INPUT_REQUESTED events: the queen wait must not prompt the client.
    client_input_events = []
    async def capture_input(event):
        client_input_events.append(event)
    bus.subscribe(event_types=[EventType.CLIENT_INPUT_REQUESTED], handler=capture_input)
    judge = AsyncMock(spec=JudgeProtocol)
    judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))
    ctx = build_ctx(runtime, node_spec, memory, llm, stream_id="worker")
    node = EventLoopNode(judge=judge, event_bus=bus, config=LoopConfig(max_iterations=5))
    async def queen_reply():
        await asyncio.sleep(0.05)
        # While the node is blocked waiting, the judge must not have been consulted.
        assert judge.evaluate.await_count == 0
        await node.inject_event("Use fallback mode and continue.")
    task = asyncio.create_task(queen_reply())
    result = await node.execute(ctx)
    await task
    assert result.success is True
    assert result.output["result"] == "resolved after queen guidance"
    # After guidance arrives, judge evaluation resumes; no client prompt was emitted.
    assert judge.evaluate.await_count >= 1
    assert len(client_input_events) == 0
# ===========================================================================
# Client-facing: _cf_expecting_work state machine
@@ -0,0 +1,130 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from framework.runtime.event_bus import EventBus
from framework.server.session_manager import Session, SessionManager
def _make_session(event_bus: EventBus, session_id: str = "session_handoff") -> Session:
    """Build a minimal Session bound to *event_bus* for handoff tests."""
    placeholder_llm = object()
    return Session(
        id=session_id,
        event_bus=event_bus,
        llm=placeholder_llm,
        loaded_at=0.0,
    )
def _make_executor(queen_node) -> SimpleNamespace:
node_registry = {}
if queen_node is not None:
node_registry["queen"] = queen_node
return SimpleNamespace(node_registry=node_registry)
@pytest.mark.asyncio
async def test_worker_handoff_injects_formatted_request_into_queen() -> None:
    """A worker escalation event is formatted and injected into the queen node."""
    event_bus = EventBus()
    queen = SimpleNamespace(inject_event=AsyncMock())
    manager = SessionManager()
    manager._subscribe_worker_handoffs(_make_session(event_bus), _make_executor(queen))
    await event_bus.emit_escalation_requested(
        stream_id="worker_a",
        node_id="research_node",
        reason="Credential wall",
        context="HTTP 401 while calling external API",
        execution_id="exec_123",
    )
    queen.inject_event.assert_awaited_once()
    call = queen.inject_event.await_args
    message = call.args[0]
    # The injected text must carry the escalation header plus every metadata line.
    expected_fragments = (
        "[WORKER_ESCALATION_REQUEST]",
        "stream_id: worker_a",
        "node_id: research_node",
        "reason: Credential wall",
        "context:\nHTTP 401 while calling external API",
    )
    for fragment in expected_fragments:
        assert fragment in message
    # Escalations are framework traffic, never client input.
    assert call.kwargs["is_client_input"] is False
@pytest.mark.asyncio
async def test_worker_handoff_ignores_queen_and_judge_streams() -> None:
    """Escalations originating from the queen or judge streams are never re-injected."""
    event_bus = EventBus()
    queen = SimpleNamespace(inject_event=AsyncMock())
    manager = SessionManager()
    manager._subscribe_worker_handoffs(_make_session(event_bus), _make_executor(queen))
    # Neither privileged stream should loop back into the queen.
    for stream in ("queen", "judge"):
        await event_bus.emit_escalation_requested(
            stream_id=stream,
            node_id=stream,
            reason="should be ignored",
        )
    assert queen.inject_event.await_count == 0
@pytest.mark.asyncio
async def test_worker_handoff_resubscribe_replaces_previous_subscription() -> None:
    """Re-subscribing installs a new handoff subscription and detaches the old one."""
    bus = EventBus()
    manager = SessionManager()
    session = _make_session(bus)
    old_queen_node = SimpleNamespace(inject_event=AsyncMock())
    manager._subscribe_worker_handoffs(session, _make_executor(old_queen_node))
    first_sub = session.worker_handoff_sub
    assert first_sub is not None
    # Subscribe again with a different executor/queen node.
    new_queen_node = SimpleNamespace(inject_event=AsyncMock())
    manager._subscribe_worker_handoffs(session, _make_executor(new_queen_node))
    second_sub = session.worker_handoff_sub
    assert second_sub is not None
    assert second_sub != first_sub
    # The stale subscription must be removed from the bus entirely.
    assert first_sub not in bus._subscriptions
    await bus.emit_escalation_requested(
        stream_id="worker_b",
        node_id="planner",
        reason="stuck",
    )
    # Only the replacement queen node receives the escalation.
    assert old_queen_node.inject_event.await_count == 0
    new_queen_node.inject_event.assert_awaited_once()
@pytest.mark.asyncio
async def test_stop_session_unsubscribes_worker_handoff() -> None:
    """stop_session() tears down the handoff subscription; later escalations are dropped."""
    bus = EventBus()
    manager = SessionManager()
    session = _make_session(bus, session_id="session_stop")
    queen_node = SimpleNamespace(inject_event=AsyncMock())
    manager._subscribe_worker_handoffs(session, _make_executor(queen_node))
    # Register the session so stop_session() can find it by id.
    manager._sessions[session.id] = session
    # Sanity check: escalation is delivered while the session is live.
    await bus.emit_escalation_requested(
        stream_id="worker_main",
        node_id="node_1",
        reason="before stop",
    )
    assert queen_node.inject_event.await_count == 1
    stopped = await manager.stop_session(session.id)
    assert stopped is True
    assert session.worker_handoff_sub is None
    # After stopping, the same escalation no longer reaches the queen node.
    await bus.emit_escalation_requested(
        stream_id="worker_main",
        node_id="node_1",
        reason="after stop",
    )
    assert queen_node.inject_event.await_count == 1