From 6be026fcb1792f9027dd0ef7811f1f6df679f3aa Mon Sep 17 00:00:00 2001 From: Timothy Date: Fri, 17 Apr 2026 04:06:59 -0700 Subject: [PATCH] fix: partial parts and system nudge --- .claude/settings.json | 4 +- core/framework/agent_loop/agent_loop.py | 329 ++++++++++++++++--- core/framework/agent_loop/conversation.py | 160 +++++++++ core/framework/agent_loop/internals/types.py | 37 ++- core/framework/host/event_bus.py | 97 ++++++ core/framework/storage/conversation_store.py | 46 +++ core/tests/test_event_loop_node.py | 134 ++++++++ core/tests/test_node_conversation.py | 206 ++++++++++++ 8 files changed, 962 insertions(+), 51 deletions(-) diff --git a/.claude/settings.json b/.claude/settings.json index 1b61758d..6b843f12 100644 --- a/.claude/settings.json +++ b/.claude/settings.json @@ -55,7 +55,9 @@ "mcp__gcu-tools__browser_click_coordinate", "mcp__gcu-tools__browser_get_rect", "mcp__gcu-tools__browser_type_focused", - "mcp__gcu-tools__browser_wait" + "mcp__gcu-tools__browser_wait", + "Bash(python3 -c ' *)", + "Bash(python3 scripts/debug_queen_prompt.py independent)" ], "additionalDirectories": [ "/home/timothy/.hive/skills/writing-hive-skills", diff --git a/core/framework/agent_loop/agent_loop.py b/core/framework/agent_loop/agent_loop.py index c31948a1..0d34facd 100644 --- a/core/framework/agent_loop/agent_loop.py +++ b/core/framework/agent_loop/agent_loop.py @@ -2335,6 +2335,11 @@ class AgentLoop(AgentProtocol): execution_id, ) + # Continue-nudge counter: how many times we've re-streamed within this + # _run_single_turn because the idle/TTFT watchdog fired. Caps to avoid + # nudging forever when the endpoint is genuinely dead. + _nudge_count_this_turn = 0 + # Inner tool loop: stream may produce tool calls requiring re-invocation while True: # Pre-send guard: if context is at or over budget, compact before @@ -2423,7 +2428,16 @@ class AgentLoop(AgentProtocol): # Capture loop-scoped variables as defaults to satisfy B023. # _stream_last_event_at is bumped on every event; the watchdog # below uses it to detect silently hung HTTP connections. - _stream_last_event_at = time.monotonic() + _stream_start_at = time.monotonic() + _stream_last_event_at = _stream_start_at + # None until the first event arrives. Before first event, the + # watchdog uses the (much looser) TTFT budget — large-context + # local models legitimately take minutes to first token. Once + # any event has been observed, tight inter-event idle applies. + _first_event_at: float | None = None + # Partial tool_calls accumulated so far, as OpenAI-format dicts + # ready for persistence if the stream is cut short. + _partial_tc_dicts: list[dict[str, Any]] = [] async def _do_stream( _msgs: list = messages, # noqa: B006 @@ -2432,8 +2446,10 @@ class AgentLoop(AgentProtocol): _safe_names: set = _early_safe_names, # noqa: B006,B008 _tasks: dict = _early_tasks, # noqa: B006,B008 _exec_fn=_timed_execute, + _partial_dicts: list[dict[str, Any]] = _partial_tc_dicts, # noqa: B006,B008 ) -> None: nonlocal accumulated_text, _stream_error, _stream_last_event_at + nonlocal _first_event_at _clean_snapshot = "" # visible-only text for the frontend async for event in ctx.llm.stream( @@ -2443,6 +2459,8 @@ class AgentLoop(AgentProtocol): max_tokens=ctx.max_tokens, ): _stream_last_event_at = time.monotonic() + if _first_event_at is None: + _first_event_at = _stream_last_event_at if isinstance(event, TextDeltaEvent): accumulated_text = event.snapshot # Strip internal reasoning tags from the full @@ -2462,9 +2480,46 @@ class AgentLoop(AgentProtocol): iteration=iteration, inner_turn=inner_turn, ) + # Checkpoint partial state so a watchdog cancel or + # crash doesn't discard whatever the model has + # produced so far. Cheap — one atomic file write. + try: + await conversation.checkpoint_partial_assistant( + accumulated_text, + _partial_dicts or None, + ) + except Exception as _cp_err: # noqa: BLE001 + logger.debug( + "[_run_single_turn] partial checkpoint failed: %s", + _cp_err, + ) elif isinstance(event, ToolCallEvent): _tc.append(event) + _partial_dicts.append( + { + "id": event.tool_use_id, + "type": "function", + "function": { + "name": event.tool_name, + "arguments": json.dumps(event.tool_input), + }, + } + ) + # Checkpoint now that a tool call has landed — + # this is the important one: if the stream dies + # right after a tool call but before FinishEvent, + # we still have the intent recorded. + try: + await conversation.checkpoint_partial_assistant( + accumulated_text, + _partial_dicts or None, + ) + except Exception as _cp_err: # noqa: BLE001 + logger.debug( + "[_run_single_turn] partial checkpoint failed: %s", + _cp_err, + ) # Gap 1: start concurrency-safe tools immediately # while the rest of the stream is still arriving, # so read-heavy turns don't stall after the last @@ -2492,55 +2547,99 @@ class AgentLoop(AgentProtocol): _llm_stream_t0 = time.monotonic() self._stream_task = asyncio.create_task(_do_stream()) logger.debug("[_run_single_turn] inner_turn=%d: Stream task created, waiting...", inner_turn) - _inactivity_limit = self._config.llm_stream_inactivity_timeout_seconds + + # Watchdog budgets — see LoopConfig docstring for rationale. + _ttft_limit = self._config.llm_stream_ttft_timeout_seconds + _inter_event_limit = self._config.llm_stream_inter_event_idle_seconds + # Back-compat: if the legacy inactivity knob was overridden to + # a value below the new default, respect it as the inter-event + # budget (historic behaviour) so existing configs don't regress. + _legacy = self._config.llm_stream_inactivity_timeout_seconds + if _legacy and _legacy > 0 and _legacy < _inter_event_limit: + _inter_event_limit = _legacy + _watchdog_active = (_ttft_limit and _ttft_limit > 0) or ( + _inter_event_limit and _inter_event_limit > 0 + ) + # Result of the watchdog: "ok" (stream finished), "ttft" (no first + # event in budget), "inactive" (silence after first event). + _watchdog_verdict: str = "ok" + _watchdog_elapsed: float = 0.0 + _watchdog_limit: float = 0.0 + try: - if _inactivity_limit and _inactivity_limit > 0: - # Heartbeat-aware wait: poll the task and cancel it if - # no stream event has been observed within the window. - # A silently dead HTTP connection otherwise hangs here - # forever — no exception, no delta, no timeout. - # - # Must use asyncio.wait (not wait_for) so we can tell - # "poll interval elapsed" apart from "task raised a - # TimeoutError of its own" — wait_for conflates them. - _check_interval = min(5.0, _inactivity_limit / 2) + if _watchdog_active: + # Poll cheapest-valid interval: at most every 5s, at least + # half the tighter budget. Must use asyncio.wait (not + # wait_for) so "poll interval elapsed" and "task raised + # TimeoutError of its own" stay distinguishable. + _tight = min( + _ttft_limit or float("inf"), + _inter_event_limit or float("inf"), + ) + _check_interval = max(1.0, min(5.0, _tight / 2)) while True: - done, _pending = await asyncio.wait({self._stream_task}, timeout=_check_interval) + done, _pending = await asyncio.wait( + {self._stream_task}, timeout=_check_interval + ) if self._stream_task in done: - # Let any exception the task raised propagate - # naturally via the outer ``await`` below. break - idle = time.monotonic() - _stream_last_event_at - if idle >= _inactivity_limit: - logger.warning( - "[_run_single_turn] inner_turn=%d: " - "stream inactivity %.0fs >= %.0fs — " - "cancelling stream task", - inner_turn, - idle, - _inactivity_limit, - ) - self._bump("stream_inactivity_watchdog") - self._stream_task.cancel() - try: - await self._stream_task - except BaseException: - pass - raise ConnectionError( - f"LLM stream idle for {idle:.0f}s " - f"(inactivity limit {_inactivity_limit:.0f}s) — " - "connection presumed dead" - ) from None + now = time.monotonic() + if _first_event_at is None: + # TTFT phase — stream open but silent. Use the + # looser budget; don't confuse slow models with + # dead connections. + elapsed = now - _stream_start_at + if _ttft_limit and _ttft_limit > 0 and elapsed >= _ttft_limit: + _watchdog_verdict = "ttft" + _watchdog_elapsed = elapsed + _watchdog_limit = _ttft_limit + break + else: + # Post-first-event silence. A stream that produced + # events and then went quiet is a real stall. + idle = now - _stream_last_event_at + if ( + _inter_event_limit + and _inter_event_limit > 0 + and idle >= _inter_event_limit + ): + _watchdog_verdict = "inactive" + _watchdog_elapsed = idle + _watchdog_limit = _inter_event_limit + break # Still active — keep polling. - # Re-raise any exception the stream task stored. When the - # watchdog loop exited via ``break`` the task is done, and - # ``await`` is the cheapest way to surface its exception. - await self._stream_task - logger.debug("[_run_single_turn] inner_turn=%d: Stream task completed normally", inner_turn) + + if _watchdog_verdict != "ok": + logger.warning( + "[_run_single_turn] inner_turn=%d: watchdog=%s %.0fs >= %.0fs — cancelling stream", + inner_turn, + _watchdog_verdict, + _watchdog_elapsed, + _watchdog_limit, + ) + self._bump(f"stream_watchdog_{_watchdog_verdict}") + self._stream_task.cancel() + try: + await self._stream_task + except BaseException: + pass + else: + # Re-raise any exception the stream task stored. When the + # watchdog loop exited via ``break`` the task is done, and + # ``await`` is the cheapest way to surface its exception. + await self._stream_task + logger.debug( + "[_run_single_turn] inner_turn=%d: Stream task completed normally", + inner_turn, + ) except asyncio.CancelledError: logger.debug("[_run_single_turn] inner_turn=%d: Stream task cancelled", inner_turn) - if accumulated_text: - await conversation.add_assistant_message(content=accumulated_text) + if accumulated_text or _partial_tc_dicts: + await conversation.add_assistant_message( + content=accumulated_text, + tool_calls=_partial_tc_dicts or None, + truncated=True, + ) # Gap 1: kill any early-dispatched tool tasks too. # Without this, a safe tool started during streaming # would leak past cancellation and keep running. @@ -2568,6 +2667,100 @@ class AgentLoop(AgentProtocol): raise finally: self._stream_task = None + + # Continue-nudge recovery path. Runs AFTER the stream task is + # cleaned up so all state is consistent. We persist whatever + # partial text + tool-calls the model produced (as a truncated + # message so the model can see its own in-flight work on the + # next turn), cancel early tool tasks, append a terse + # continuation hint, and restart the stream. + if _watchdog_verdict != "ok": + # Kill any safe-tool tasks the stream dispatched early — + # their results would have had nowhere to land anyway + # because the assistant message was incomplete. + for _early in _early_tasks.values(): + if not _early.done(): + _early.cancel() + # Promote whatever we captured into a real truncated + # message. The partial checkpoint for this seq is cleared + # automatically when add_assistant_message persists. + if accumulated_text or _partial_tc_dicts: + await conversation.add_assistant_message( + content=accumulated_text, + tool_calls=_partial_tc_dicts or None, + truncated=True, + ) + + reason_label = ( + "no tokens before TTFT budget" + if _watchdog_verdict == "ttft" + else "stream went silent after producing events" + ) + if self._event_bus: + if _watchdog_verdict == "ttft": + await self._event_bus.emit_stream_ttft_exceeded( + stream_id=stream_id, + node_id=node_id, + ttft_seconds=_watchdog_elapsed, + limit_seconds=_watchdog_limit, + execution_id=execution_id, + ) + else: + await self._event_bus.emit_stream_inactive( + stream_id=stream_id, + node_id=node_id, + idle_seconds=_watchdog_elapsed, + limit_seconds=_watchdog_limit, + execution_id=execution_id, + ) + + nudge_enabled = self._config.continue_nudge_enabled + nudge_cap = self._config.continue_nudge_max_per_turn + if nudge_enabled and _nudge_count_this_turn < nudge_cap: + _nudge_count_this_turn += 1 + nudge_msg = ( + f"[System: the previous stream stalled ({reason_label}, " + f"{_watchdog_elapsed:.0f}s). Continue from the last tool " + "result already in this conversation. Do NOT repeat tool " + "calls whose results are visible above — reuse them and " + "move to the next step.]" + ) + await conversation.add_user_message( + nudge_msg, + is_system_nudge=True, + ) + if self._event_bus: + await self._event_bus.emit_stream_nudge_sent( + stream_id=stream_id, + node_id=node_id, + reason=_watchdog_verdict, + nudge_count=_nudge_count_this_turn, + execution_id=execution_id, + ) + logger.info( + "[%s] continue-nudge sent (count=%d/%d, reason=%s)", + node_id, + _nudge_count_this_turn, + nudge_cap, + _watchdog_verdict, + ) + # Reset the outer _turn_t0 timer so the "LLM done in + # Xms" log line reflects real work not the nudge cycle. + _llm_stream_ms = int((time.monotonic() - _llm_stream_t0) * 1000) + logger.debug( + "[_run_single_turn] inner_turn=%d: nudge restart after %dms", + inner_turn, + _llm_stream_ms, + ) + continue # restart the inner loop, re-fetches messages + # Nudge disabled or cap exhausted — fall back to the + # existing retry path so a truly dead endpoint eventually + # surfaces as an error. + raise ConnectionError( + f"LLM stream {_watchdog_verdict} for {_watchdog_elapsed:.0f}s " + f"(limit {_watchdog_limit:.0f}s) — nudge cap reached" + ) + _llm_stream_ms = int((time.monotonic() - _llm_stream_t0) * 1000) # If a recoverable stream error produced an empty response, @@ -2667,6 +2860,12 @@ class AgentLoop(AgentProtocol): results_by_id: dict[str, ToolResult] = {} timing_by_id: dict[str, dict[str, Any]] = {} # tool_use_id -> {start_timestamp, duration_s} pending_real: list[ToolCallEvent] = [] + # Replay detector: per-turn map from tool_use_id -> steer prefix. + # Populated below when we detect that the model is re-emitting a + # tool call whose (name + canonical args) matches a prior success. + # Applied to the stored tool result content so the model sees the + # nudge on its next turn without losing the real execution output. + replay_prefixes_by_id: dict[str, str] = {} for tc in tool_calls: tool_call_count += 1 @@ -2939,6 +3138,39 @@ class AgentLoop(AgentProtocol): ) results_by_id[tc.tool_use_id] = result else: + # Replay detector: flag re-executions of recent + # successful calls. We still run the tool (some + # are legitimately repeated, e.g. screenshots and + # read-only evaluates) but prepend a terse steer + # onto the stored result so the model sees the + # signal on its next turn. + if self._config.replay_detector_enabled: + prior = conversation.find_completed_tool_call( + tc.tool_name, + tc.tool_input, + within_last_turns=self._config.replay_detector_within_last_turns, + ) + if prior is not None: + logger.warning( + "[%s] replay detected: %s matches prior seq=%d — executing anyway", + node_id, + tc.tool_name, + prior.seq, + ) + self._bump("tool_call_replay_detected") + if self._event_bus: + await self._event_bus.emit_tool_call_replay_detected( + stream_id=stream_id, + node_id=node_id, + tool_name=tc.tool_name, + prior_seq=prior.seq, + execution_id=execution_id, + ) + replay_prefixes_by_id[tc.tool_use_id] = ( + f"[Replay detected: {tc.tool_name} matches " + f"seq={prior.seq}. Result still produced below — " + "consider whether the retry was necessary.]\n" + ) pending_real.append(tc) # Phase 2a: partition real tools by concurrency safety. @@ -3136,9 +3368,18 @@ class AgentLoop(AgentProtocol): ) image_content = None + # Apply replay-detector steer prefix if this call matched a + # recent successful invocation. Only applies to non-error + # results — an error already breaks the replay chain. + stored_content = result.content + if not result.is_error: + _prefix = replay_prefixes_by_id.get(tc.tool_use_id) + if _prefix: + stored_content = f"{_prefix}{stored_content or ''}" + await conversation.add_tool_result( tool_use_id=tc.tool_use_id, - content=result.content, + content=stored_content, is_error=result.is_error, image_content=image_content, is_skill_content=result.is_skill_content, diff --git a/core/framework/agent_loop/conversation.py b/core/framework/agent_loop/conversation.py index e992cd5f..54e02611 100644 --- a/core/framework/agent_loop/conversation.py +++ b/core/framework/agent_loop/conversation.py @@ -48,6 +48,14 @@ class Message: is_skill_content: bool = False # Logical worker run identifier for shared-session persistence run_id: str | None = None + # True when this is a framework-injected continuation hint (continue-nudge + # on stream stall). Stored as a user message for API compatibility, but + # the UI should render it as a compact system notice, not user speech. + is_system_nudge: bool = False + # True when this message is a partial/truncated assistant turn reconstructed + # from a crashed or watchdog-cancelled stream. Signals that the original + # turn never finished — the model may or may not choose to redo it. + truncated: bool = False def to_llm_dict(self) -> dict[str, Any]: """Convert to OpenAI-format message dict.""" @@ -109,6 +117,10 @@ class Message: d["image_content"] = self.image_content if self.run_id is not None: d["run_id"] = self.run_id + if self.is_system_nudge: + d["is_system_nudge"] = self.is_system_nudge + if self.truncated: + d["truncated"] = self.truncated return d @classmethod @@ -126,6 +138,8 @@ class Message: is_client_input=data.get("is_client_input", False), image_content=data.get("image_content"), run_id=data.get("run_id"), + is_system_nudge=data.get("is_system_nudge", False), + truncated=data.get("truncated", False), ) @@ -317,6 +331,14 @@ class ConversationStore(Protocol): async def delete_parts_before(self, seq: int, run_id: str | None = None) -> None: ... + async def write_partial(self, seq: int, data: dict[str, Any]) -> None: ... + + async def read_partial(self, seq: int) -> dict[str, Any] | None: ... + + async def read_all_partials(self) -> list[dict[str, Any]]: ... + + async def clear_partial(self, seq: int) -> None: ... + async def close(self) -> None: ... async def destroy(self) -> None: ... @@ -462,6 +484,7 @@ class NodeConversation: is_transition_marker: bool = False, is_client_input: bool = False, image_content: list[dict[str, Any]] | None = None, + is_system_nudge: bool = False, ) -> Message: msg = Message( seq=self._next_seq, @@ -472,6 +495,7 @@ class NodeConversation: is_transition_marker=is_transition_marker, is_client_input=is_client_input, image_content=image_content, + is_system_nudge=is_system_nudge, ) self._messages.append(msg) self._next_seq += 1 @@ -485,6 +509,8 @@ class NodeConversation: self, content: str, tool_calls: list[dict[str, Any]] | None = None, + *, + truncated: bool = False, ) -> Message: msg = Message( seq=self._next_seq, @@ -493,6 +519,7 @@ class NodeConversation: tool_calls=tool_calls, phase_id=self._current_phase, run_id=self._run_id, + truncated=truncated, ) self._messages.append(msg) self._next_seq += 1 @@ -548,6 +575,59 @@ class NodeConversation: # --- Query ------------------------------------------------------------- + def find_completed_tool_call( + self, + name: str, + tool_input: dict[str, Any], + within_last_turns: int = 3, + ) -> Message | None: + """Return the most recent assistant message that issued a tool call + with the same (name + canonical-json args) AND received a non-error + tool result, within the last ``within_last_turns`` assistant turns. + + Used by the replay detector to flag when the model is about to redo + a successful call — we prepend a steer onto the upcoming result but + still execute, so tools like browser_screenshot that are legitimately + repeated are not silently skipped. + """ + try: + target_canonical = json.dumps(tool_input, sort_keys=True, default=str) + except (TypeError, ValueError): + target_canonical = str(tool_input) + + # Walk backwards over recent assistant messages + assistant_turns_seen = 0 + for idx in range(len(self._messages) - 1, -1, -1): + m = self._messages[idx] + if m.role != "assistant": + continue + assistant_turns_seen += 1 + if assistant_turns_seen > within_last_turns: + break + if not m.tool_calls: + continue + for tc in m.tool_calls: + func = tc.get("function", {}) if isinstance(tc, dict) else {} + tc_name = func.get("name") + if tc_name != name: + continue + args_str = func.get("arguments", "") + try: + parsed = json.loads(args_str) if isinstance(args_str, str) else args_str + canonical = json.dumps(parsed, sort_keys=True, default=str) + except (TypeError, ValueError): + canonical = str(args_str) + if canonical != target_canonical: + continue + # Found a match — now verify its result was not an error. + tc_id = tc.get("id") + for later in self._messages[idx + 1 :]: + if later.role == "tool" and later.tool_use_id == tc_id: + if not later.is_error: + return m + break + return None + def to_llm_messages(self) -> list[dict[str, Any]]: """Return messages as OpenAI-format dicts (system prompt excluded). @@ -1365,6 +1445,45 @@ class NodeConversation: await self._persist_meta() await self._store.write_part(message.seq, message.to_storage_dict()) await self._write_next_seq() + # Any partial checkpoint for this seq is now superseded by the real + # part — clear it so a future restore doesn't resurrect stale text. + try: + await self._store.clear_partial(message.seq) + except AttributeError: + # Older stores may not implement partials; ignore. + pass + + async def checkpoint_partial_assistant( + self, + accumulated_text: str, + tool_calls: list[dict[str, Any]] | None = None, + ) -> None: + """Write an in-flight assistant turn's state to disk under the next seq. + + Called from the stream event loop. Safe to call repeatedly — each call + overwrites the prior checkpoint. Persisted via ``write_partial`` so it + does NOT appear in ``read_parts()`` and cannot be double-loaded. Cleared + automatically when ``add_assistant_message`` for this seq lands. + """ + if self._store is None: + return + if not self._meta_persisted: + await self._persist_meta() + payload: dict[str, Any] = { + "seq": self._next_seq, + "role": "assistant", + "content": accumulated_text, + "phase_id": self._current_phase, + "run_id": self._run_id, + "truncated": True, + } + if tool_calls: + payload["tool_calls"] = tool_calls + try: + await self._store.write_partial(self._next_seq, payload) + except AttributeError: + # Older stores may not implement partials; ignore. + pass async def _persist_meta(self) -> None: """Lazily write conversation metadata to the store (called once). @@ -1461,4 +1580,45 @@ class NodeConversation: elif conv._messages: conv._next_seq = conv._messages[-1].seq + 1 + # Surface any leftover partial checkpoints as truncated messages so + # the next turn sees what the interrupted stream was in the middle + # of producing. Only partials whose seq is >= next_seq are meaningful; + # anything lower was already superseded by a real part. + try: + partials = await store.read_all_partials() + except AttributeError: + partials = [] + for p in partials: + pseq = p.get("seq", -1) + if pseq < conv._next_seq: + # Stale — clean it up. + try: + await store.clear_partial(pseq) + except AttributeError: + pass + continue + # Only resurrect partials relevant to this run / phase. + if run_id and not is_legacy_run_id(run_id) and p.get("run_id") != run_id: + continue + if phase_id and p.get("phase_id") is not None and p.get("phase_id") != phase_id: + continue + # Reconstruct as a truncated assistant message. + msg = Message( + seq=pseq, + role="assistant", + content=p.get("content", "") or "", + tool_calls=p.get("tool_calls"), + phase_id=p.get("phase_id"), + run_id=p.get("run_id"), + truncated=True, + ) + conv._messages.append(msg) + conv._next_seq = max(conv._next_seq, pseq + 1) + logger.info( + "restore: resurrected truncated partial seq=%d (text=%d chars, tool_calls=%d)", + pseq, + len(msg.content), + len(msg.tool_calls or []), + ) + return conv diff --git a/core/framework/agent_loop/internals/types.py b/core/framework/agent_loop/internals/types.py index 26d44ed3..64c5bc0b 100644 --- a/core/framework/agent_loop/internals/types.py +++ b/core/framework/agent_loop/internals/types.py @@ -131,14 +131,39 @@ class LoopConfig: # Per-tool-call timeout. tool_call_timeout_seconds: float = 60.0 - # LLM stream inactivity watchdog. If no stream event (delta, tool call, - # finish) arrives within this many seconds, the stream task is cancelled - # and a transient error is raised so the retry loop can back off and - # reconnect. Prevents agents from hanging forever on a silently dead - # HTTP connection (no provider heartbeat, no exception, just silence). - # Set to 0 to disable. + # LLM stream inactivity watchdog. Split into two budgets so legitimate + # slow TTFT on large contexts doesn't get mistaken for a dead connection. + # - ttft: stream open -> first event. Large-context local models can + # legitimately take minutes before the first token arrives. + # - inter_event: last event -> now, ONLY after the first event. A stream + # that started producing and then went silent is a real stall. + # Whichever fires first cancels the stream. Set to 0 to disable that + # individual budget; set both to 0 to fully disable the watchdog. + llm_stream_ttft_timeout_seconds: float = 600.0 + llm_stream_inter_event_idle_seconds: float = 120.0 + # Deprecated alias — kept so existing configs keep working. If set to a + # non-default value it overrides inter_event_idle (historical behavior). llm_stream_inactivity_timeout_seconds: float = 120.0 + # Continue-nudge recovery. When the idle watchdog fires on a live but + # stuck stream, cancel the stream and append a short continuation + # hint to the conversation instead of raising a ConnectionError and + # re-running the whole turn. Preserves any partial text/tool-calls the + # stream emitted before the stall. + continue_nudge_enabled: bool = True + # Cap so a truly dead endpoint eventually falls back to the error path + # instead of nudging forever. + continue_nudge_max_per_turn: int = 3 + + # Tool-call replay detector. When the model emits a tool call whose + # (name + canonical-args) matches a prior successful call in the last + # K assistant turns, emit telemetry and prepend a short steer onto the + # tool result — but still execute. Weaker models legitimately repeat + # read-only calls (screenshot, evaluate), so silent skipping would + # cause surprising behavior. + replay_detector_enabled: bool = True + replay_detector_within_last_turns: int = 3 + # Subagent delegation timeout (wall-clock max). subagent_timeout_seconds: float = 3600.0 diff --git a/core/framework/host/event_bus.py b/core/framework/host/event_bus.py index d193452a..2d556e2a 100644 --- a/core/framework/host/event_bus.py +++ b/core/framework/host/event_bus.py @@ -111,6 +111,15 @@ class EventType(StrEnum): # Retry tracking NODE_RETRY = "node_retry" + # Stream-health observability. Split from NODE_RETRY so the UI can + # distinguish "slow TTFT on a huge context" (healthy, just slow) from + # "stream went silent mid-generation" (probable stall) from "we nudged + # the model to continue" (recovery), which NODE_RETRY used to conflate. + STREAM_TTFT_EXCEEDED = "stream_ttft_exceeded" + STREAM_INACTIVE = "stream_inactive" + STREAM_NUDGE_SENT = "stream_nudge_sent" + TOOL_CALL_REPLAY_DETECTED = "tool_call_replay_detected" + # Worker agent lifecycle WORKER_COMPLETED = "worker_completed" WORKER_FAILED = "worker_failed" @@ -1061,6 +1070,94 @@ class EventBus: ) ) + async def emit_stream_ttft_exceeded( + self, + stream_id: str, + node_id: str, + ttft_seconds: float, + limit_seconds: float, + execution_id: str | None = None, + ) -> None: + """Emit when a stream stayed silent past the TTFT budget (no first event).""" + await self.publish( + AgentEvent( + type=EventType.STREAM_TTFT_EXCEEDED, + stream_id=stream_id, + node_id=node_id, + execution_id=execution_id, + data={ + "ttft_seconds": ttft_seconds, + "limit_seconds": limit_seconds, + }, + ) + ) + + async def emit_stream_inactive( + self, + stream_id: str, + node_id: str, + idle_seconds: float, + limit_seconds: float, + execution_id: str | None = None, + ) -> None: + """Emit when a stream that had produced events went silent past budget.""" + await self.publish( + AgentEvent( + type=EventType.STREAM_INACTIVE, + stream_id=stream_id, + node_id=node_id, + execution_id=execution_id, + data={ + "idle_seconds": idle_seconds, + "limit_seconds": limit_seconds, + }, + ) + ) + + async def emit_stream_nudge_sent( + self, + stream_id: str, + node_id: str, + reason: str, + nudge_count: int, + execution_id: str | None = None, + ) -> None: + """Emit when the continue-nudge was injected (recovery, not retry).""" + await self.publish( + AgentEvent( + type=EventType.STREAM_NUDGE_SENT, + stream_id=stream_id, + node_id=node_id, + execution_id=execution_id, + data={ + "reason": reason, + "nudge_count": nudge_count, + }, + ) + ) + + async def emit_tool_call_replay_detected( + self, + stream_id: str, + node_id: str, + tool_name: str, + prior_seq: int, + execution_id: str | None = None, + ) -> None: + """Emit when the model is about to re-execute a prior successful call.""" + await self.publish( + AgentEvent( + type=EventType.TOOL_CALL_REPLAY_DETECTED, + stream_id=stream_id, + node_id=node_id, + execution_id=execution_id, + data={ + "tool_name": tool_name, + "prior_seq": prior_seq, + }, + ) + ) + async def emit_worker_completed( self, stream_id: str, diff --git a/core/framework/storage/conversation_store.py b/core/framework/storage/conversation_store.py index ac812b85..55f12edb 100644 --- a/core/framework/storage/conversation_store.py +++ b/core/framework/storage/conversation_store.py @@ -43,6 +43,10 @@ class FileConversationStore: def __init__(self, base_path: str | Path) -> None: self._base = Path(base_path) self._parts_dir = self._base / "parts" + # Partial checkpoints for in-flight assistant turns. Written on every + # stream event, deleted atomically when the final part lands. Kept + # in a sibling dir so the parts/ glob doesn't pick them up. + self._partials_dir = self._base / "partials" # --- sync helpers -------------------------------------------------------- @@ -99,6 +103,44 @@ class FileConversationStore: async def read_cursor(self) -> dict[str, Any] | None: return await self._run(self._read_json, self._base / "cursor.json") + async def write_partial(self, seq: int, data: dict[str, Any]) -> None: + """Checkpoint an in-flight assistant turn. Overwrites any prior partial + for the same seq. Caller is expected to clear_partial() once the real + part is written via write_part(). + """ + path = self._partials_dir / f"{seq:010d}.json" + await self._run(self._write_json, path, data) + + async def read_partial(self, seq: int) -> dict[str, Any] | None: + path = self._partials_dir / f"{seq:010d}.json" + return await self._run(self._read_json, path) + + async def read_all_partials(self) -> list[dict[str, Any]]: + """Return all partial checkpoints, sorted by seq. Used during restore + to surface any in-flight turn that the last process didn't finish. + """ + + def _read_all() -> list[dict[str, Any]]: + if not self._partials_dir.exists(): + return [] + files = sorted(self._partials_dir.glob("*.json")) + partials: list[dict[str, Any]] = [] + for f in files: + data = self._read_json(f) + if data is not None: + partials.append(data) + return partials + + return await self._run(_read_all) + + async def clear_partial(self, seq: int) -> None: + def _clear() -> None: + path = self._partials_dir / f"{seq:010d}.json" + if path.exists(): + path.unlink() + + await self._run(_clear) + async def delete_parts_before(self, seq: int, run_id: str | None = None) -> None: def _delete() -> None: if not self._parts_dir.exists(): @@ -125,6 +167,10 @@ class FileConversationStore: if self._parts_dir.exists(): for f in self._parts_dir.glob("*.json"): f.unlink() + # Clear partial checkpoints + if self._partials_dir.exists(): + for f in self._partials_dir.glob("*.json"): + f.unlink() # Clear cursor cursor_path = self._base / "cursor.json" if cursor_path.exists(): diff --git a/core/tests/test_event_loop_node.py b/core/tests/test_event_loop_node.py index ddf0d9fe..232b9891 100644 --- a/core/tests/test_event_loop_node.py +++ b/core/tests/test_event_loop_node.py @@ -2109,3 +2109,137 @@ class TestToolConcurrencyPartition: # Both tools must have run: soft errors don't cascade. assert executed == ["call_1", "call_2"] + + +# =========================================================================== +# Replay detector (warn + execute) +# =========================================================================== + + +class TestReplayDetector: + @pytest.mark.asyncio + async def test_replay_emits_event_and_prefixes_result( + self, tmp_path, runtime, node_spec, buffer + ): + """Re-emitting a tool call whose prior result succeeded fires the + TOOL_CALL_REPLAY_DETECTED event and prepends a steer onto the stored + result, but still executes the call (warn + execute).""" + node_spec.output_keys = [] + + async def tool_exec(tool_use: ToolUse) -> ToolResult: + return ToolResult( + tool_use_id=tool_use.id, + content=f"fresh result for {tool_use.id}", + is_error=False, + ) + + # Turn 1: model calls browser_setup with id=call_1 + # Turn 2: model calls browser_setup AGAIN with id=call_2 (the replay) + # Turn 3: text stop + llm = MockStreamingLLM( + scenarios=[ + tool_call_scenario("browser_setup", {}, tool_use_id="call_1"), + tool_call_scenario("browser_setup", {}, tool_use_id="call_2"), + text_scenario("done"), + ] + ) + + tools = [Tool(name="browser_setup", description="", parameters={})] + + # Capture events from the bus. + captured: list[Any] = [] + bus = EventBus() + + async def _collect(evt): + captured.append(evt) + + bus.subscribe([EventType.TOOL_CALL_REPLAY_DETECTED], _collect) + + ctx = build_ctx( + runtime, + node_spec, + buffer, + llm, + tools=tools, + is_subagent_mode=True, + ) + store = FileConversationStore(tmp_path / "conv") + node = EventLoopNode( + tool_executor=tool_exec, + conversation_store=store, + event_bus=bus, + config=LoopConfig(max_iterations=5), + ) + await node.execute(ctx) + + # Exactly one replay-detected event fired for the second call. + assert len(captured) == 1 + assert captured[0].data["tool_name"] == "browser_setup" + + # The stored tool result for the replay carries the steer prefix, + # and the real execution output is preserved. + parts = await store.read_parts() + tool_msgs = [ + p for p in parts if p.get("role") == "tool" and p.get("tool_use_id") == "call_2" + ] + assert len(tool_msgs) == 1 + assert tool_msgs[0]["content"].startswith("[Replay detected: browser_setup") + assert "fresh result for call_2" in tool_msgs[0]["content"] + + # The first call's result is untouched. + first = [ + p for p in parts if p.get("role") == "tool" and p.get("tool_use_id") == "call_1" + ] + assert first[0]["content"] == "fresh result for call_1" + + @pytest.mark.asyncio + async def test_replay_with_error_prior_does_not_fire( + self, tmp_path, runtime, node_spec, buffer + ): + """A prior call that errored does not count as a successful completion, + so re-emitting it is legitimate (not a replay).""" + node_spec.output_keys = [] + + async def tool_exec(tool_use: ToolUse) -> ToolResult: + is_err = tool_use.id == "call_1" + return ToolResult( + tool_use_id=tool_use.id, + content=("boom" if is_err else "ok"), + is_error=is_err, + ) + + llm = MockStreamingLLM( + scenarios=[ + tool_call_scenario("flaky", {}, tool_use_id="call_1"), + tool_call_scenario("flaky", {}, tool_use_id="call_2"), + text_scenario("recovered"), + ] + ) + tools = [Tool(name="flaky", description="", parameters={})] + + captured: list[Any] = [] + bus = EventBus() + + async def _collect(evt): + captured.append(evt) + + bus.subscribe([EventType.TOOL_CALL_REPLAY_DETECTED], _collect) + + ctx = build_ctx( + runtime, + node_spec, + buffer, + llm, + tools=tools, + is_subagent_mode=True, + ) + store = FileConversationStore(tmp_path / "conv") + node = EventLoopNode( + tool_executor=tool_exec, + conversation_store=store, + event_bus=bus, + config=LoopConfig(max_iterations=5), + ) + await node.execute(ctx) + + assert captured == [] diff --git a/core/tests/test_node_conversation.py b/core/tests/test_node_conversation.py index f43e80c0..6bce88d3 100644 --- a/core/tests/test_node_conversation.py +++ b/core/tests/test_node_conversation.py @@ -24,6 +24,7 @@ class MockConversationStore: def __init__(self) -> None: self._parts: dict[int, dict] = {} + self._partials: dict[int, dict] = {} self._meta: dict | None = None self._cursor: dict | None = None @@ -48,6 +49,18 @@ class MockConversationStore: async def delete_parts_before(self, seq: int, run_id: str | None = None) -> None: self._parts = {k: v for k, v in self._parts.items() if k >= seq} + async def write_partial(self, seq: int, data: dict[str, Any]) -> None: + self._partials[seq] = data + + async def read_partial(self, seq: int) -> dict[str, Any] | None: + return self._partials.get(seq) + + async def read_all_partials(self) -> list[dict[str, Any]]: + return [self._partials[k] for k in sorted(self._partials)] + + async def clear_partial(self, seq: int) -> None: + self._partials.pop(seq, None) + async def close(self) -> None: pass @@ -750,6 +763,33 @@ class TestFileConversationStore: assert (base / "parts" / "0000000000.json").exists() assert (base / "parts" / "0000000001.json").exists() + @pytest.mark.asyncio + async def test_partials_separate_from_parts(self, tmp_path): + """Partial checkpoints must not pollute read_parts() and vice versa.""" + store = FileConversationStore(tmp_path / "conv") + await store.write_part(0, {"seq": 0, "content": "real"}) + await store.write_partial(1, {"seq": 1, "content": "inflight", "truncated": True}) + parts = await store.read_parts() + assert [p["seq"] for p in parts] == [0] + partials = await store.read_all_partials() + assert [p["seq"] for p in partials] == [1] + assert partials[0]["content"] == "inflight" + assert (await store.read_partial(1))["content"] == "inflight" + assert await store.read_partial(99) is None + await store.clear_partial(1) + assert await store.read_all_partials() == [] + + @pytest.mark.asyncio + async def test_partials_dir_does_not_break_parts_glob(self, tmp_path): + """delete_parts_before parses stems as int — partial files must not trip it.""" + store = FileConversationStore(tmp_path / "conv") + for i in range(3): + await store.write_part(i, {"seq": i}) + await store.write_partial(i + 100, {"seq": i + 100}) + await store.delete_parts_before(2) + assert [p["seq"] for p in await store.read_parts()] == [2] + assert [p["seq"] for p in await store.read_all_partials()] == [100, 101, 102] + # =================================================================== # Integration tests — real FileConversationStore, no mocks @@ -1646,3 +1686,169 @@ class TestRepairOrphanedToolCalls: roles = [m["role"] for m in repaired] assert roles == ["user", "assistant", "tool", "user"] assert repaired[2]["tool_call_id"] == "tc_2" + + +# =================================================================== +# Continue-nudge + replay-detector helpers (DS-14) +# =================================================================== + + +def _mk_assistant_with_tool_call(seq: int, tc_id: str, name: str, args: dict) -> Message: + return Message( + seq=seq, + role="assistant", + content="", + tool_calls=[ + { + "id": tc_id, + "type": "function", + "function": {"name": name, "arguments": json.dumps(args)}, + } + ], + ) + + +class TestFindCompletedToolCall: + def test_returns_match_when_prior_non_error_result_exists(self): + conv = NodeConversation(system_prompt="s") + conv._messages = [ + Message(seq=0, role="user", content="go"), + _mk_assistant_with_tool_call(1, "tc_a", "browser_setup", {}), + Message(seq=2, role="tool", content="ok", tool_use_id="tc_a"), + ] + match = conv.find_completed_tool_call("browser_setup", {}) + assert match is not None + assert match.seq == 1 + + def test_ignores_error_result(self): + conv = NodeConversation(system_prompt="s") + conv._messages = [ + Message(seq=0, role="user", content="go"), + _mk_assistant_with_tool_call(1, "tc_a", "browser_navigate", {"url": "x"}), + Message(seq=2, role="tool", content="boom", tool_use_id="tc_a", is_error=True), + ] + assert conv.find_completed_tool_call("browser_navigate", {"url": "x"}) is None + + def test_canonicalizes_json_args_regardless_of_key_order(self): + conv = NodeConversation(system_prompt="s") + # Prior args written in one order, new call re-emits in different order. + conv._messages = [ + Message(seq=0, role="user", content="go"), + _mk_assistant_with_tool_call(1, "tc_a", "fetch", {"b": 2, "a": 1}), + Message(seq=2, role="tool", content="ok", tool_use_id="tc_a"), + ] + assert conv.find_completed_tool_call("fetch", {"a": 1, "b": 2}) is not None + # Different args should NOT match. + assert conv.find_completed_tool_call("fetch", {"a": 1, "b": 3}) is None + + def test_respects_within_last_turns_window(self): + conv = NodeConversation(system_prompt="s") + # Prior successful call, then 4 newer assistant turns of noise. + conv._messages = [ + Message(seq=0, role="user", content="go"), + _mk_assistant_with_tool_call(1, "tc_a", "browser_setup", {}), + Message(seq=2, role="tool", content="ok", tool_use_id="tc_a"), + ] + # 4 newer assistant turns (no tool calls that match) + for i in range(3, 7): + conv._messages.append( + Message(seq=i, role="assistant", content=f"noise {i}") + ) + # Window=3 → prior assistant with browser_setup is at turn index 5 + # backwards (noise, noise, noise, noise, setup) — skipped. + assert ( + conv.find_completed_tool_call("browser_setup", {}, within_last_turns=3) + is None + ) + # Window=10 → found. + assert ( + conv.find_completed_tool_call("browser_setup", {}, within_last_turns=10) + is not None + ) + + +class TestPartialCheckpoint: + @pytest.mark.asyncio + async def test_checkpoint_is_cleared_when_real_part_lands(self, tmp_path): + """A partial for seq N is wiped once add_assistant_message(seq=N) persists.""" + store = FileConversationStore(tmp_path / "c") + conv = NodeConversation(system_prompt="s", store=store) + await conv.add_user_message("hi") + # Seed a partial for the would-be next assistant seq. + await conv.checkpoint_partial_assistant("half-written...") + partials = await store.read_all_partials() + assert len(partials) == 1 + assert partials[0]["content"] == "half-written..." + # Commit the real assistant turn — partial should be swept. + await conv.add_assistant_message("fully written") + assert await store.read_all_partials() == [] + + @pytest.mark.asyncio + async def test_restore_surfaces_partial_as_truncated_message(self, tmp_path): + """A partial left behind by a crashed stream is resurrected on restore.""" + store = FileConversationStore(tmp_path / "c") + conv = NodeConversation(system_prompt="s", store=store) + await conv.add_user_message("hi") + # Simulate a stream that produced some text + a tool call, then died + # before finishing. The checkpoint captures both. + await conv.checkpoint_partial_assistant( + "I was working on this when the stream died", + tool_calls=[ + { + "id": "tc_x", + "type": "function", + "function": {"name": "browser_click", "arguments": "{}"}, + } + ], + ) + # Fresh process — restore from disk. + fresh = await NodeConversation.restore(store) + assert fresh is not None + # The user message is there, plus the truncated assistant resurrected + # from the partial. + roles = [m.role for m in fresh.messages] + assert roles == ["user", "assistant"] + last = fresh.messages[-1] + assert last.truncated is True + assert last.content == "I was working on this when the stream died" + assert last.tool_calls and last.tool_calls[0]["function"]["name"] == "browser_click" + + @pytest.mark.asyncio + async def test_restore_cleans_stale_partials(self, tmp_path): + """A partial whose seq was already committed as a real part is discarded.""" + store = FileConversationStore(tmp_path / "c") + conv = NodeConversation(system_prompt="s", store=store) + await conv.add_user_message("hi") + await conv.add_assistant_message("real") # seq=1 + # Manually plant a stale partial at seq=1 (already committed). + await store.write_partial( + 1, {"seq": 1, "role": "assistant", "content": "stale", "truncated": True} + ) + fresh = await NodeConversation.restore(store) + assert fresh is not None + assert [m.content for m in fresh.messages] == ["hi", "real"] + # Stale partial swept by restore. + assert await store.read_all_partials() == [] + + +class TestMessageFlags: + def test_is_system_nudge_roundtrip(self): + m = Message(seq=0, role="user", content="nudge", is_system_nudge=True) + d = m.to_storage_dict() + assert d.get("is_system_nudge") is True + r = Message.from_storage_dict(d) + assert r.is_system_nudge is True + assert r.role == "user" + + def test_truncated_roundtrip(self): + m = Message(seq=0, role="assistant", content="half", truncated=True) + d = m.to_storage_dict() + assert d.get("truncated") is True + r = Message.from_storage_dict(d) + assert r.truncated is True + + def test_defaults_omit_flags_from_storage(self): + m = Message(seq=0, role="user", content="plain") + d = m.to_storage_dict() + assert "is_system_nudge" not in d + assert "truncated" not in d