fix: partial parts and system nudge

This commit is contained in:
Timothy
2026-04-17 04:06:59 -07:00
parent 3c2161aad5
commit 6be026fcb1
8 changed files with 962 additions and 51 deletions
+3 -1
View File
@@ -55,7 +55,9 @@
"mcp__gcu-tools__browser_click_coordinate",
"mcp__gcu-tools__browser_get_rect",
"mcp__gcu-tools__browser_type_focused",
"mcp__gcu-tools__browser_wait"
"mcp__gcu-tools__browser_wait",
"Bash(python3 -c ' *)",
"Bash(python3 scripts/debug_queen_prompt.py independent)"
],
"additionalDirectories": [
"/home/timothy/.hive/skills/writing-hive-skills",
+285 -44
View File
@@ -2335,6 +2335,11 @@ class AgentLoop(AgentProtocol):
execution_id,
)
# Continue-nudge counter: how many times we've re-streamed within this
# _run_single_turn because the idle/TTFT watchdog fired. Capped so we
# don't nudge forever when the endpoint is genuinely dead.
_nudge_count_this_turn = 0
# Inner tool loop: stream may produce tool calls requiring re-invocation
while True:
# Pre-send guard: if context is at or over budget, compact before
@@ -2423,7 +2428,16 @@ class AgentLoop(AgentProtocol):
# Capture loop-scoped variables as defaults to satisfy B023.
# _stream_last_event_at is bumped on every event; the watchdog
# below uses it to detect silently hung HTTP connections.
_stream_last_event_at = time.monotonic()
_stream_start_at = time.monotonic()
_stream_last_event_at = _stream_start_at
# None until the first event arrives. Before the first event, the
# watchdog uses the (much looser) TTFT budget — large-context
# local models legitimately take minutes to first token. Once
# any event has been observed, the tight inter-event idle budget applies.
_first_event_at: float | None = None
# Partial tool_calls accumulated so far, as OpenAI-format dicts
# ready for persistence if the stream is cut short.
_partial_tc_dicts: list[dict[str, Any]] = []
async def _do_stream(
_msgs: list = messages, # noqa: B006
@@ -2432,8 +2446,10 @@ class AgentLoop(AgentProtocol):
_safe_names: set = _early_safe_names, # noqa: B006,B008
_tasks: dict = _early_tasks, # noqa: B006,B008
_exec_fn=_timed_execute,
_partial_dicts: list[dict[str, Any]] = _partial_tc_dicts, # noqa: B006,B008
) -> None:
nonlocal accumulated_text, _stream_error, _stream_last_event_at
nonlocal _first_event_at
_clean_snapshot = "" # visible-only text for the frontend
async for event in ctx.llm.stream(
@@ -2443,6 +2459,8 @@ class AgentLoop(AgentProtocol):
max_tokens=ctx.max_tokens,
):
_stream_last_event_at = time.monotonic()
if _first_event_at is None:
_first_event_at = _stream_last_event_at
if isinstance(event, TextDeltaEvent):
accumulated_text = event.snapshot
# Strip internal reasoning tags from the full
@@ -2462,9 +2480,46 @@ class AgentLoop(AgentProtocol):
iteration=iteration,
inner_turn=inner_turn,
)
# Checkpoint partial state so a watchdog cancel or
# crash doesn't discard whatever the model has
# produced so far. Cheap — one atomic file write.
try:
await conversation.checkpoint_partial_assistant(
accumulated_text,
_partial_dicts or None,
)
except Exception as _cp_err: # noqa: BLE001
logger.debug(
"[_run_single_turn] partial checkpoint failed: %s",
_cp_err,
)
elif isinstance(event, ToolCallEvent):
_tc.append(event)
_partial_dicts.append(
{
"id": event.tool_use_id,
"type": "function",
"function": {
"name": event.tool_name,
"arguments": json.dumps(event.tool_input),
},
}
)
# Checkpoint now that a tool call has landed —
# this is the important one: if the stream dies
# right after a tool call but before FinishEvent,
# we still have the intent recorded.
try:
await conversation.checkpoint_partial_assistant(
accumulated_text,
_partial_dicts or None,
)
except Exception as _cp_err: # noqa: BLE001
logger.debug(
"[_run_single_turn] partial checkpoint failed: %s",
_cp_err,
)
# Gap 1: start concurrency-safe tools immediately
# while the rest of the stream is still arriving,
# so read-heavy turns don't stall after the last
@@ -2492,55 +2547,99 @@ class AgentLoop(AgentProtocol):
_llm_stream_t0 = time.monotonic()
self._stream_task = asyncio.create_task(_do_stream())
logger.debug("[_run_single_turn] inner_turn=%d: Stream task created, waiting...", inner_turn)
_inactivity_limit = self._config.llm_stream_inactivity_timeout_seconds
# Watchdog budgets — see LoopConfig docstring for rationale.
_ttft_limit = self._config.llm_stream_ttft_timeout_seconds
_inter_event_limit = self._config.llm_stream_inter_event_idle_seconds
# Back-compat: if the legacy inactivity knob was overridden to
# a value below the new default, respect it as the inter-event
# budget (historic behaviour) so existing configs don't regress.
_legacy = self._config.llm_stream_inactivity_timeout_seconds
if _legacy and _legacy > 0 and _legacy < _inter_event_limit:
_inter_event_limit = _legacy
_watchdog_active = (_ttft_limit and _ttft_limit > 0) or (
_inter_event_limit and _inter_event_limit > 0
)
# Result of the watchdog: "ok" (stream finished), "ttft" (no first
# event in budget), "inactive" (silence after first event).
_watchdog_verdict: str = "ok"
_watchdog_elapsed: float = 0.0
_watchdog_limit: float = 0.0
try:
if _inactivity_limit and _inactivity_limit > 0:
# Heartbeat-aware wait: poll the task and cancel it if
# no stream event has been observed within the window.
# A silently dead HTTP connection otherwise hangs here
# forever — no exception, no delta, no timeout.
#
# Must use asyncio.wait (not wait_for) so we can tell
# "poll interval elapsed" apart from "task raised a
# TimeoutError of its own" — wait_for conflates them.
_check_interval = min(5.0, _inactivity_limit / 2)
if _watchdog_active:
# Poll interval: half the tighter budget, clamped to [1s, 5s].
# Must use asyncio.wait (not wait_for) so "poll interval
# elapsed" and "task raised TimeoutError of its own" stay
# distinguishable.
_tight = min(
_ttft_limit or float("inf"),
_inter_event_limit or float("inf"),
)
_check_interval = max(1.0, min(5.0, _tight / 2))
while True:
done, _pending = await asyncio.wait({self._stream_task}, timeout=_check_interval)
done, _pending = await asyncio.wait(
{self._stream_task}, timeout=_check_interval
)
if self._stream_task in done:
# Let any exception the task raised propagate
# naturally via the outer ``await`` below.
break
idle = time.monotonic() - _stream_last_event_at
if idle >= _inactivity_limit:
logger.warning(
"[_run_single_turn] inner_turn=%d: "
"stream inactivity %.0fs >= %.0fs — "
"cancelling stream task",
inner_turn,
idle,
_inactivity_limit,
)
self._bump("stream_inactivity_watchdog")
self._stream_task.cancel()
try:
await self._stream_task
except BaseException:
pass
raise ConnectionError(
f"LLM stream idle for {idle:.0f}s "
f"(inactivity limit {_inactivity_limit:.0f}s) — "
"connection presumed dead"
) from None
now = time.monotonic()
if _first_event_at is None:
# TTFT phase — stream open but silent. Use the
# looser budget; don't confuse slow models with
# dead connections.
elapsed = now - _stream_start_at
if _ttft_limit and _ttft_limit > 0 and elapsed >= _ttft_limit:
_watchdog_verdict = "ttft"
_watchdog_elapsed = elapsed
_watchdog_limit = _ttft_limit
break
else:
# Post-first-event silence. A stream that produced
# events and then went quiet is a real stall.
idle = now - _stream_last_event_at
if (
_inter_event_limit
and _inter_event_limit > 0
and idle >= _inter_event_limit
):
_watchdog_verdict = "inactive"
_watchdog_elapsed = idle
_watchdog_limit = _inter_event_limit
break
# Still active — keep polling.
# Re-raise any exception the stream task stored. When the
# watchdog loop exited via ``break`` the task is done, and
# ``await`` is the cheapest way to surface its exception.
await self._stream_task
logger.debug("[_run_single_turn] inner_turn=%d: Stream task completed normally", inner_turn)
if _watchdog_verdict != "ok":
logger.warning(
"[_run_single_turn] inner_turn=%d: watchdog=%s %.0fs >= %.0fs — cancelling stream",
inner_turn,
_watchdog_verdict,
_watchdog_elapsed,
_watchdog_limit,
)
self._bump(f"stream_watchdog_{_watchdog_verdict}")
self._stream_task.cancel()
try:
await self._stream_task
except BaseException:
pass
else:
# Re-raise any exception the stream task stored. When the
# watchdog loop exited via ``break`` the task is done, and
# ``await`` is the cheapest way to surface its exception.
await self._stream_task
logger.debug(
"[_run_single_turn] inner_turn=%d: Stream task completed normally",
inner_turn,
)
except asyncio.CancelledError:
logger.debug("[_run_single_turn] inner_turn=%d: Stream task cancelled", inner_turn)
if accumulated_text:
await conversation.add_assistant_message(content=accumulated_text)
if accumulated_text or _partial_tc_dicts:
await conversation.add_assistant_message(
content=accumulated_text,
tool_calls=_partial_tc_dicts or None,
truncated=True,
)
# Gap 1: kill any early-dispatched tool tasks too.
# Without this, a safe tool started during streaming
# would leak past cancellation and keep running.
@@ -2568,6 +2667,100 @@ class AgentLoop(AgentProtocol):
raise
finally:
self._stream_task = None
# Continue-nudge recovery path. Runs AFTER the stream task is
# cleaned up so all state is consistent. We persist whatever
# partial text + tool-calls the model produced (as a truncated
# message so the model can see its own in-flight work on the
# next turn), cancel early tool tasks, append a terse
# continuation hint, and restart the stream.
if _watchdog_verdict != "ok":
# Kill any safe-tool tasks the stream dispatched early —
# their results would have had nowhere to land anyway
# because the assistant message was incomplete.
for _early in _early_tasks.values():
if not _early.done():
_early.cancel()
# Promote whatever we captured into a real truncated
# message. The partial checkpoint for this seq is cleared
# automatically when add_assistant_message persists.
if accumulated_text or _partial_tc_dicts:
await conversation.add_assistant_message(
content=accumulated_text,
tool_calls=_partial_tc_dicts or None,
truncated=True,
)
reason_label = (
"no tokens before TTFT budget"
if _watchdog_verdict == "ttft"
else "stream went silent after producing events"
)
if self._event_bus:
if _watchdog_verdict == "ttft":
await self._event_bus.emit_stream_ttft_exceeded(
stream_id=stream_id,
node_id=node_id,
ttft_seconds=_watchdog_elapsed,
limit_seconds=_watchdog_limit,
execution_id=execution_id,
)
else:
await self._event_bus.emit_stream_inactive(
stream_id=stream_id,
node_id=node_id,
idle_seconds=_watchdog_elapsed,
limit_seconds=_watchdog_limit,
execution_id=execution_id,
)
nudge_enabled = self._config.continue_nudge_enabled
nudge_cap = self._config.continue_nudge_max_per_turn
if nudge_enabled and _nudge_count_this_turn < nudge_cap:
_nudge_count_this_turn += 1
nudge_msg = (
f"[System: the previous stream stalled ({reason_label}, "
f"{_watchdog_elapsed:.0f}s). Continue from the last tool "
"result already in this conversation. Do NOT repeat tool "
"calls whose results are visible above — reuse them and "
"move to the next step.]"
)
await conversation.add_user_message(
nudge_msg,
is_system_nudge=True,
)
if self._event_bus:
await self._event_bus.emit_stream_nudge_sent(
stream_id=stream_id,
node_id=node_id,
reason=_watchdog_verdict,
nudge_count=_nudge_count_this_turn,
execution_id=execution_id,
)
logger.info(
"[%s] continue-nudge sent (count=%d/%d, reason=%s)",
node_id,
_nudge_count_this_turn,
nudge_cap,
_watchdog_verdict,
)
# Record how long the stalled stream ran; the loop restart
# re-arms _llm_stream_t0, so the "LLM done in Xms" log line
# reflects real work, not the nudge cycle.
_llm_stream_ms = int((time.monotonic() - _llm_stream_t0) * 1000)
logger.debug(
"[_run_single_turn] inner_turn=%d: nudge restart after %dms",
inner_turn,
_llm_stream_ms,
)
continue # restart the inner loop, re-fetches messages
# Nudge disabled or cap exhausted — fall back to the
# existing retry path so a truly dead endpoint eventually
# surfaces as an error.
raise ConnectionError(
f"LLM stream {_watchdog_verdict} for {_watchdog_elapsed:.0f}s "
f"(limit {_watchdog_limit:.0f}s) — nudge cap reached"
)
_llm_stream_ms = int((time.monotonic() - _llm_stream_t0) * 1000)
# If a recoverable stream error produced an empty response,
@@ -2667,6 +2860,12 @@ class AgentLoop(AgentProtocol):
results_by_id: dict[str, ToolResult] = {}
timing_by_id: dict[str, dict[str, Any]] = {} # tool_use_id -> {start_timestamp, duration_s}
pending_real: list[ToolCallEvent] = []
# Replay detector: per-turn map from tool_use_id -> steer prefix.
# Populated below when we detect that the model is re-emitting a
# tool call whose (name + canonical args) matches a prior success.
# Applied to the stored tool result content so the model sees the
# nudge on its next turn without losing the real execution output.
replay_prefixes_by_id: dict[str, str] = {}
for tc in tool_calls:
tool_call_count += 1
@@ -2939,6 +3138,39 @@ class AgentLoop(AgentProtocol):
)
results_by_id[tc.tool_use_id] = result
else:
# Replay detector: flag re-executions of recent
# successful calls. We still run the tool (some
# are legitimately repeated, e.g. screenshots and
# read-only evaluates) but prepend a terse steer
# onto the stored result so the model sees the
# signal on its next turn.
if self._config.replay_detector_enabled:
prior = conversation.find_completed_tool_call(
tc.tool_name,
tc.tool_input,
within_last_turns=self._config.replay_detector_within_last_turns,
)
if prior is not None:
logger.warning(
"[%s] replay detected: %s matches prior seq=%d — executing anyway",
node_id,
tc.tool_name,
prior.seq,
)
self._bump("tool_call_replay_detected")
if self._event_bus:
await self._event_bus.emit_tool_call_replay_detected(
stream_id=stream_id,
node_id=node_id,
tool_name=tc.tool_name,
prior_seq=prior.seq,
execution_id=execution_id,
)
replay_prefixes_by_id[tc.tool_use_id] = (
f"[Replay detected: {tc.tool_name} matches "
f"seq={prior.seq}. Result still produced below — "
"consider whether the retry was necessary.]\n"
)
pending_real.append(tc)
# Phase 2a: partition real tools by concurrency safety.
@@ -3136,9 +3368,18 @@ class AgentLoop(AgentProtocol):
)
image_content = None
# Apply replay-detector steer prefix if this call matched a
# recent successful invocation. Only applies to non-error
# results — an error already breaks the replay chain.
stored_content = result.content
if not result.is_error:
_prefix = replay_prefixes_by_id.get(tc.tool_use_id)
if _prefix:
stored_content = f"{_prefix}{stored_content or ''}"
await conversation.add_tool_result(
tool_use_id=tc.tool_use_id,
content=result.content,
content=stored_content,
is_error=result.is_error,
image_content=image_content,
is_skill_content=result.is_skill_content,
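Taken together, the watchdog reduces to one small decision per poll tick. A minimal standalone sketch of that decision, assuming only the two budget values from LoopConfig (the function and names here are illustrative, not the loop's actual locals):

import time

def watchdog_verdict(
    start_at: float,
    last_event_at: float,
    first_event_seen: bool,
    ttft_limit: float,
    inter_event_limit: float,
) -> tuple[str, float, float]:
    """Return (verdict, elapsed, limit) for one poll tick.

    verdict is "ok", "ttft" (silent before the first event) or
    "inactive" (silent after events started flowing).
    """
    now = time.monotonic()
    if not first_event_seen:
        # TTFT phase: only the looser budget applies.
        elapsed = now - start_at
        if ttft_limit > 0 and elapsed >= ttft_limit:
            return ("ttft", elapsed, ttft_limit)
    else:
        # Post-first-event: renewed silence is a real stall.
        idle = now - last_event_at
        if inter_event_limit > 0 and idle >= inter_event_limit:
            return ("inactive", idle, inter_event_limit)
    return ("ok", 0.0, 0.0)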
+160
View File
@@ -48,6 +48,14 @@ class Message:
is_skill_content: bool = False
# Logical worker run identifier for shared-session persistence
run_id: str | None = None
# True when this is a framework-injected continuation hint (continue-nudge
# on stream stall). Stored as a user message for API compatibility, but
# the UI should render it as a compact system notice, not user speech.
is_system_nudge: bool = False
# True when this message is a partial/truncated assistant turn reconstructed
# from a crashed or watchdog-cancelled stream. Signals that the original
# turn never finished — the model may or may not choose to redo it.
truncated: bool = False
def to_llm_dict(self) -> dict[str, Any]:
"""Convert to OpenAI-format message dict."""
@@ -109,6 +117,10 @@ class Message:
d["image_content"] = self.image_content
if self.run_id is not None:
d["run_id"] = self.run_id
if self.is_system_nudge:
d["is_system_nudge"] = self.is_system_nudge
if self.truncated:
d["truncated"] = self.truncated
return d
@classmethod
@@ -126,6 +138,8 @@ class Message:
is_client_input=data.get("is_client_input", False),
image_content=data.get("image_content"),
run_id=data.get("run_id"),
is_system_nudge=data.get("is_system_nudge", False),
truncated=data.get("truncated", False),
)
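Both flags round-trip through storage but are omitted when False, so part files written before this change parse unchanged. A quick sketch of that behavior, using only the fields defined above:

m = Message(seq=0, role="user", content="continue", is_system_nudge=True)
d = m.to_storage_dict()          # contains "is_system_nudge": True
plain = Message(seq=1, role="user", content="hi").to_storage_dict()
# Neither "is_system_nudge" nor "truncated" appears in `plain` at all.
restored = Message.from_storage_dict(d)
assert restored.is_system_nudge is True and restored.truncated is False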
@@ -317,6 +331,14 @@ class ConversationStore(Protocol):
async def delete_parts_before(self, seq: int, run_id: str | None = None) -> None: ...
async def write_partial(self, seq: int, data: dict[str, Any]) -> None: ...
async def read_partial(self, seq: int) -> dict[str, Any] | None: ...
async def read_all_partials(self) -> list[dict[str, Any]]: ...
async def clear_partial(self, seq: int) -> None: ...
async def close(self) -> None: ...
async def destroy(self) -> None: ...
@@ -462,6 +484,7 @@ class NodeConversation:
is_transition_marker: bool = False,
is_client_input: bool = False,
image_content: list[dict[str, Any]] | None = None,
is_system_nudge: bool = False,
) -> Message:
msg = Message(
seq=self._next_seq,
@@ -472,6 +495,7 @@ class NodeConversation:
is_transition_marker=is_transition_marker,
is_client_input=is_client_input,
image_content=image_content,
is_system_nudge=is_system_nudge,
)
self._messages.append(msg)
self._next_seq += 1
@@ -485,6 +509,8 @@ class NodeConversation:
self,
content: str,
tool_calls: list[dict[str, Any]] | None = None,
*,
truncated: bool = False,
) -> Message:
msg = Message(
seq=self._next_seq,
@@ -493,6 +519,7 @@ class NodeConversation:
tool_calls=tool_calls,
phase_id=self._current_phase,
run_id=self._run_id,
truncated=truncated,
)
self._messages.append(msg)
self._next_seq += 1
@@ -548,6 +575,59 @@ class NodeConversation:
# --- Query -------------------------------------------------------------
def find_completed_tool_call(
self,
name: str,
tool_input: dict[str, Any],
within_last_turns: int = 3,
) -> Message | None:
"""Return the most recent assistant message that issued a tool call
with the same (name + canonical-json args) AND received a non-error
tool result, within the last ``within_last_turns`` assistant turns.
Used by the replay detector to flag when the model is about to redo
a successful call: we prepend a steer onto the upcoming result but
still execute, so tools like browser_screenshot that are legitimately
repeated are not silently skipped.
"""
try:
target_canonical = json.dumps(tool_input, sort_keys=True, default=str)
except (TypeError, ValueError):
target_canonical = str(tool_input)
# Walk backwards over recent assistant messages
assistant_turns_seen = 0
for idx in range(len(self._messages) - 1, -1, -1):
m = self._messages[idx]
if m.role != "assistant":
continue
assistant_turns_seen += 1
if assistant_turns_seen > within_last_turns:
break
if not m.tool_calls:
continue
for tc in m.tool_calls:
func = tc.get("function", {}) if isinstance(tc, dict) else {}
tc_name = func.get("name")
if tc_name != name:
continue
args_str = func.get("arguments", "")
try:
parsed = json.loads(args_str) if isinstance(args_str, str) else args_str
canonical = json.dumps(parsed, sort_keys=True, default=str)
except (TypeError, ValueError):
canonical = str(args_str)
if canonical != target_canonical:
continue
# Found a match — now verify its result was not an error.
tc_id = tc.get("id")
for later in self._messages[idx + 1 :]:
if later.role == "tool" and later.tool_use_id == tc_id:
if not later.is_error:
return m
break
return None
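The match key is simply the tool name plus sorted-key JSON of the arguments, so key order in a re-emitted call is irrelevant. A tiny illustration (values borrowed from the tests further down):

import json

json.dumps({"b": 2, "a": 1}, sort_keys=True, default=str)
# '{"a": 1, "b": 2}', identical to the canonical form of {"a": 1, "b": 2},
# so the detector treats the two calls as the same invocation.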
def to_llm_messages(self) -> list[dict[str, Any]]:
"""Return messages as OpenAI-format dicts (system prompt excluded).
@@ -1365,6 +1445,45 @@ class NodeConversation:
await self._persist_meta()
await self._store.write_part(message.seq, message.to_storage_dict())
await self._write_next_seq()
# Any partial checkpoint for this seq is now superseded by the real
# part — clear it so a future restore doesn't resurrect stale text.
try:
await self._store.clear_partial(message.seq)
except AttributeError:
# Older stores may not implement partials; ignore.
pass
async def checkpoint_partial_assistant(
self,
accumulated_text: str,
tool_calls: list[dict[str, Any]] | None = None,
) -> None:
"""Write an in-flight assistant turn's state to disk under the next seq.
Called from the stream event loop. Safe to call repeatedly; each call
overwrites the prior checkpoint. Persisted via ``write_partial`` so it
does NOT appear in ``read_parts()`` and cannot be double-loaded. Cleared
automatically when ``add_assistant_message`` for this seq lands.
"""
if self._store is None:
return
if not self._meta_persisted:
await self._persist_meta()
payload: dict[str, Any] = {
"seq": self._next_seq,
"role": "assistant",
"content": accumulated_text,
"phase_id": self._current_phase,
"run_id": self._run_id,
"truncated": True,
}
if tool_calls:
payload["tool_calls"] = tool_calls
try:
await self._store.write_partial(self._next_seq, payload)
except AttributeError:
# Older stores may not implement partials; ignore.
pass
async def _persist_meta(self) -> None:
"""Lazily write conversation metadata to the store (called once).
@@ -1461,4 +1580,45 @@ class NodeConversation:
elif conv._messages:
conv._next_seq = conv._messages[-1].seq + 1
# Surface any leftover partial checkpoints as truncated messages so
# the next turn sees what the interrupted stream was in the middle
# of producing. Only partials whose seq is >= next_seq are meaningful;
# anything lower was already superseded by a real part.
try:
partials = await store.read_all_partials()
except AttributeError:
partials = []
for p in partials:
pseq = p.get("seq", -1)
if pseq < conv._next_seq:
# Stale — clean it up.
try:
await store.clear_partial(pseq)
except AttributeError:
pass
continue
# Only resurrect partials relevant to this run / phase.
if run_id and not is_legacy_run_id(run_id) and p.get("run_id") != run_id:
continue
if phase_id and p.get("phase_id") is not None and p.get("phase_id") != phase_id:
continue
# Reconstruct as a truncated assistant message.
msg = Message(
seq=pseq,
role="assistant",
content=p.get("content", "") or "",
tool_calls=p.get("tool_calls"),
phase_id=p.get("phase_id"),
run_id=p.get("run_id"),
truncated=True,
)
conv._messages.append(msg)
conv._next_seq = max(conv._next_seq, pseq + 1)
logger.info(
"restore: resurrected truncated partial seq=%d (text=%d chars, tool_calls=%d)",
pseq,
len(msg.content),
len(msg.tool_calls or []),
)
return conv
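End to end, the checkpoint lifecycle looks like this. A hedged driver sketch (run inside an async context; the path is illustrative) using only the APIs added in this diff:

store = FileConversationStore(Path("/tmp") / "conv")
conv = NodeConversation(system_prompt="s", store=store)
await conv.add_user_message("go")                      # a real part, seq=0
# The stream loop checkpoints mid-turn; repeat calls just overwrite.
await conv.checkpoint_partial_assistant("half-written", tool_calls=None)
# If the process dies here, the next boot promotes the partial into a
# truncated assistant message under the same seq:
fresh = await NodeConversation.restore(store)
assert fresh.messages[-1].truncated is True
# Had the turn finished instead, add_assistant_message(...) would have
# cleared the partial automatically.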
+31 -6
View File
@@ -131,14 +131,39 @@ class LoopConfig:
# Per-tool-call timeout.
tool_call_timeout_seconds: float = 60.0
# LLM stream inactivity watchdog. If no stream event (delta, tool call,
# finish) arrives within this many seconds, the stream task is cancelled
# and a transient error is raised so the retry loop can back off and
# reconnect. Prevents agents from hanging forever on a silently dead
# HTTP connection (no provider heartbeat, no exception, just silence).
# Set to 0 to disable.
# LLM stream inactivity watchdog. Split into two budgets so legitimate
# slow TTFT on large contexts doesn't get mistaken for a dead connection.
# - ttft: stream open -> first event. Large-context local models can
# legitimately take minutes before the first token arrives.
# - inter_event: last event -> now, ONLY after the first event. A stream
# that started producing and then went silent is a real stall.
# Whichever fires first cancels the stream. Set to 0 to disable that
# individual budget; set both to 0 to fully disable the watchdog.
llm_stream_ttft_timeout_seconds: float = 600.0
llm_stream_inter_event_idle_seconds: float = 120.0
# Deprecated alias — kept so existing configs keep working. If overridden
# below the new default it takes over as the inter-event budget
# (historical behavior).
llm_stream_inactivity_timeout_seconds: float = 120.0
# Continue-nudge recovery. When the idle watchdog fires on a live but
# stuck stream, cancel the stream and append a short continuation
# hint to the conversation instead of raising a ConnectionError and
# re-running the whole turn. Preserves any partial text/tool-calls the
# stream emitted before the stall.
continue_nudge_enabled: bool = True
# Cap so a truly dead endpoint eventually falls back to the error path
# instead of nudging forever.
continue_nudge_max_per_turn: int = 3
# Tool-call replay detector. When the model emits a tool call whose
# (name + canonical-args) matches a prior successful call in the last
# K assistant turns, emit telemetry and prepend a short steer onto the
# tool result — but still execute. Weaker models legitimately repeat
# read-only calls (screenshot, evaluate), so silent skipping would
# cause surprising behavior.
replay_detector_enabled: bool = True
replay_detector_within_last_turns: int = 3
# Subagent delegation timeout (wall-clock max).
subagent_timeout_seconds: float = 3600.0
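A hedged example of tuning just the stall detection while keeping the forgiving TTFT default, assuming every other LoopConfig field has a default value, as the ones shown here do:

config = LoopConfig(
    llm_stream_ttft_timeout_seconds=600.0,     # slow first token is fine
    llm_stream_inter_event_idle_seconds=60.0,  # mid-stream silence is not
    continue_nudge_enabled=True,
    continue_nudge_max_per_turn=3,             # then fall back to the error path
)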
+97
View File
@@ -111,6 +111,15 @@ class EventType(StrEnum):
# Retry tracking
NODE_RETRY = "node_retry"
# Stream-health observability. Split from NODE_RETRY so the UI can
# distinguish "slow TTFT on a huge context" (healthy, just slow) from
# "stream went silent mid-generation" (probable stall) from "we nudged
# the model to continue" (recovery), which NODE_RETRY used to conflate.
STREAM_TTFT_EXCEEDED = "stream_ttft_exceeded"
STREAM_INACTIVE = "stream_inactive"
STREAM_NUDGE_SENT = "stream_nudge_sent"
TOOL_CALL_REPLAY_DETECTED = "tool_call_replay_detected"
# Worker agent lifecycle
WORKER_COMPLETED = "worker_completed"
WORKER_FAILED = "worker_failed"
@@ -1061,6 +1070,94 @@ class EventBus:
)
)
async def emit_stream_ttft_exceeded(
self,
stream_id: str,
node_id: str,
ttft_seconds: float,
limit_seconds: float,
execution_id: str | None = None,
) -> None:
"""Emit when a stream stayed silent past the TTFT budget (no first event)."""
await self.publish(
AgentEvent(
type=EventType.STREAM_TTFT_EXCEEDED,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={
"ttft_seconds": ttft_seconds,
"limit_seconds": limit_seconds,
},
)
)
async def emit_stream_inactive(
self,
stream_id: str,
node_id: str,
idle_seconds: float,
limit_seconds: float,
execution_id: str | None = None,
) -> None:
"""Emit when a stream that had produced events went silent past budget."""
await self.publish(
AgentEvent(
type=EventType.STREAM_INACTIVE,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={
"idle_seconds": idle_seconds,
"limit_seconds": limit_seconds,
},
)
)
async def emit_stream_nudge_sent(
self,
stream_id: str,
node_id: str,
reason: str,
nudge_count: int,
execution_id: str | None = None,
) -> None:
"""Emit when the continue-nudge was injected (recovery, not retry)."""
await self.publish(
AgentEvent(
type=EventType.STREAM_NUDGE_SENT,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={
"reason": reason,
"nudge_count": nudge_count,
},
)
)
async def emit_tool_call_replay_detected(
self,
stream_id: str,
node_id: str,
tool_name: str,
prior_seq: int,
execution_id: str | None = None,
) -> None:
"""Emit when the model is about to re-execute a prior successful call."""
await self.publish(
AgentEvent(
type=EventType.TOOL_CALL_REPLAY_DETECTED,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={
"tool_name": tool_name,
"prior_seq": prior_seq,
},
)
)
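Consumers subscribe per event type. A sketch using the subscribe shape exercised in the tests below:

bus = EventBus()

async def _on_stream_health(evt):
    # evt.data carries the budget numbers set by the emitters above.
    print(evt.type, evt.node_id, evt.data)

bus.subscribe(
    [
        EventType.STREAM_TTFT_EXCEEDED,
        EventType.STREAM_INACTIVE,
        EventType.STREAM_NUDGE_SENT,
        EventType.TOOL_CALL_REPLAY_DETECTED,
    ],
    _on_stream_health,
)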
async def emit_worker_completed(
self,
stream_id: str,
@@ -43,6 +43,10 @@ class FileConversationStore:
def __init__(self, base_path: str | Path) -> None:
self._base = Path(base_path)
self._parts_dir = self._base / "parts"
# Partial checkpoints for in-flight assistant turns. Written on every
# stream event, deleted atomically when the final part lands. Kept
# in a sibling dir so the parts/ glob doesn't pick them up.
self._partials_dir = self._base / "partials"
# --- sync helpers --------------------------------------------------------
@@ -99,6 +103,44 @@ class FileConversationStore:
async def read_cursor(self) -> dict[str, Any] | None:
return await self._run(self._read_json, self._base / "cursor.json")
async def write_partial(self, seq: int, data: dict[str, Any]) -> None:
"""Checkpoint an in-flight assistant turn. Overwrites any prior partial
for the same seq. Caller is expected to clear_partial() once the real
part is written via write_part().
"""
path = self._partials_dir / f"{seq:010d}.json"
await self._run(self._write_json, path, data)
async def read_partial(self, seq: int) -> dict[str, Any] | None:
path = self._partials_dir / f"{seq:010d}.json"
return await self._run(self._read_json, path)
async def read_all_partials(self) -> list[dict[str, Any]]:
"""Return all partial checkpoints, sorted by seq. Used during restore
to surface any in-flight turn that the last process didn't finish.
"""
def _read_all() -> list[dict[str, Any]]:
if not self._partials_dir.exists():
return []
files = sorted(self._partials_dir.glob("*.json"))
partials: list[dict[str, Any]] = []
for f in files:
data = self._read_json(f)
if data is not None:
partials.append(data)
return partials
return await self._run(_read_all)
async def clear_partial(self, seq: int) -> None:
def _clear() -> None:
path = self._partials_dir / f"{seq:010d}.json"
if path.exists():
path.unlink()
await self._run(_clear)
async def delete_parts_before(self, seq: int, run_id: str | None = None) -> None:
def _delete() -> None:
if not self._parts_dir.exists():
@@ -125,6 +167,10 @@ class FileConversationStore:
if self._parts_dir.exists():
for f in self._parts_dir.glob("*.json"):
f.unlink()
# Clear partial checkpoints
if self._partials_dir.exists():
for f in self._partials_dir.glob("*.json"):
f.unlink()
# Clear cursor
cursor_path = self._base / "cursor.json"
if cursor_path.exists():
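On disk, partials mirror the zero-padded naming of parts but live in a sibling directory, so neither glob sees the other's files. A short usage sketch (async context; path illustrative):

store = FileConversationStore(Path("/tmp") / "conv")
await store.write_part(0, {"seq": 0, "content": "hi"})       # conv/parts/0000000000.json
await store.write_partial(1, {"seq": 1, "content": "wip"})   # conv/partials/0000000001.json
await store.clear_partial(1)                                 # unlinks the partial file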
+134
View File
@@ -2109,3 +2109,137 @@ class TestToolConcurrencyPartition:
# Both tools must have run: soft errors don't cascade.
assert executed == ["call_1", "call_2"]
# ===========================================================================
# Replay detector (warn + execute)
# ===========================================================================
class TestReplayDetector:
@pytest.mark.asyncio
async def test_replay_emits_event_and_prefixes_result(
self, tmp_path, runtime, node_spec, buffer
):
"""Re-emitting a tool call whose prior result succeeded fires the
TOOL_CALL_REPLAY_DETECTED event and prepends a steer onto the stored
result, but still executes the call (warn + execute)."""
node_spec.output_keys = []
async def tool_exec(tool_use: ToolUse) -> ToolResult:
return ToolResult(
tool_use_id=tool_use.id,
content=f"fresh result for {tool_use.id}",
is_error=False,
)
# Turn 1: model calls browser_setup with id=call_1
# Turn 2: model calls browser_setup AGAIN with id=call_2 (the replay)
# Turn 3: text stop
llm = MockStreamingLLM(
scenarios=[
tool_call_scenario("browser_setup", {}, tool_use_id="call_1"),
tool_call_scenario("browser_setup", {}, tool_use_id="call_2"),
text_scenario("done"),
]
)
tools = [Tool(name="browser_setup", description="", parameters={})]
# Capture events from the bus.
captured: list[Any] = []
bus = EventBus()
async def _collect(evt):
captured.append(evt)
bus.subscribe([EventType.TOOL_CALL_REPLAY_DETECTED], _collect)
ctx = build_ctx(
runtime,
node_spec,
buffer,
llm,
tools=tools,
is_subagent_mode=True,
)
store = FileConversationStore(tmp_path / "conv")
node = EventLoopNode(
tool_executor=tool_exec,
conversation_store=store,
event_bus=bus,
config=LoopConfig(max_iterations=5),
)
await node.execute(ctx)
# Exactly one replay-detected event fired for the second call.
assert len(captured) == 1
assert captured[0].data["tool_name"] == "browser_setup"
# The stored tool result for the replay carries the steer prefix,
# and the real execution output is preserved.
parts = await store.read_parts()
tool_msgs = [
p for p in parts if p.get("role") == "tool" and p.get("tool_use_id") == "call_2"
]
assert len(tool_msgs) == 1
assert tool_msgs[0]["content"].startswith("[Replay detected: browser_setup")
assert "fresh result for call_2" in tool_msgs[0]["content"]
# The first call's result is untouched.
first = [
p for p in parts if p.get("role") == "tool" and p.get("tool_use_id") == "call_1"
]
assert first[0]["content"] == "fresh result for call_1"
@pytest.mark.asyncio
async def test_replay_with_error_prior_does_not_fire(
self, tmp_path, runtime, node_spec, buffer
):
"""A prior call that errored does not count as a successful completion,
so re-emitting it is legitimate (not a replay)."""
node_spec.output_keys = []
async def tool_exec(tool_use: ToolUse) -> ToolResult:
is_err = tool_use.id == "call_1"
return ToolResult(
tool_use_id=tool_use.id,
content=("boom" if is_err else "ok"),
is_error=is_err,
)
llm = MockStreamingLLM(
scenarios=[
tool_call_scenario("flaky", {}, tool_use_id="call_1"),
tool_call_scenario("flaky", {}, tool_use_id="call_2"),
text_scenario("recovered"),
]
)
tools = [Tool(name="flaky", description="", parameters={})]
captured: list[Any] = []
bus = EventBus()
async def _collect(evt):
captured.append(evt)
bus.subscribe([EventType.TOOL_CALL_REPLAY_DETECTED], _collect)
ctx = build_ctx(
runtime,
node_spec,
buffer,
llm,
tools=tools,
is_subagent_mode=True,
)
store = FileConversationStore(tmp_path / "conv")
node = EventLoopNode(
tool_executor=tool_exec,
conversation_store=store,
event_bus=bus,
config=LoopConfig(max_iterations=5),
)
await node.execute(ctx)
assert captured == []
+206
View File
@@ -24,6 +24,7 @@ class MockConversationStore:
def __init__(self) -> None:
self._parts: dict[int, dict] = {}
self._partials: dict[int, dict] = {}
self._meta: dict | None = None
self._cursor: dict | None = None
@@ -48,6 +49,18 @@ class MockConversationStore:
async def delete_parts_before(self, seq: int, run_id: str | None = None) -> None:
self._parts = {k: v for k, v in self._parts.items() if k >= seq}
async def write_partial(self, seq: int, data: dict[str, Any]) -> None:
self._partials[seq] = data
async def read_partial(self, seq: int) -> dict[str, Any] | None:
return self._partials.get(seq)
async def read_all_partials(self) -> list[dict[str, Any]]:
return [self._partials[k] for k in sorted(self._partials)]
async def clear_partial(self, seq: int) -> None:
self._partials.pop(seq, None)
async def close(self) -> None:
pass
@@ -750,6 +763,33 @@ class TestFileConversationStore:
assert (base / "parts" / "0000000000.json").exists()
assert (base / "parts" / "0000000001.json").exists()
@pytest.mark.asyncio
async def test_partials_separate_from_parts(self, tmp_path):
"""Partial checkpoints must not pollute read_parts() and vice versa."""
store = FileConversationStore(tmp_path / "conv")
await store.write_part(0, {"seq": 0, "content": "real"})
await store.write_partial(1, {"seq": 1, "content": "inflight", "truncated": True})
parts = await store.read_parts()
assert [p["seq"] for p in parts] == [0]
partials = await store.read_all_partials()
assert [p["seq"] for p in partials] == [1]
assert partials[0]["content"] == "inflight"
assert (await store.read_partial(1))["content"] == "inflight"
assert await store.read_partial(99) is None
await store.clear_partial(1)
assert await store.read_all_partials() == []
@pytest.mark.asyncio
async def test_partials_dir_does_not_break_parts_glob(self, tmp_path):
"""delete_parts_before parses stems as int — partial files must not trip it."""
store = FileConversationStore(tmp_path / "conv")
for i in range(3):
await store.write_part(i, {"seq": i})
await store.write_partial(i + 100, {"seq": i + 100})
await store.delete_parts_before(2)
assert [p["seq"] for p in await store.read_parts()] == [2]
assert [p["seq"] for p in await store.read_all_partials()] == [100, 101, 102]
# ===================================================================
# Integration tests — real FileConversationStore, no mocks
@@ -1646,3 +1686,169 @@ class TestRepairOrphanedToolCalls:
roles = [m["role"] for m in repaired]
assert roles == ["user", "assistant", "tool", "user"]
assert repaired[2]["tool_call_id"] == "tc_2"
# ===================================================================
# Continue-nudge + replay-detector helpers (DS-14)
# ===================================================================
def _mk_assistant_with_tool_call(seq: int, tc_id: str, name: str, args: dict) -> Message:
return Message(
seq=seq,
role="assistant",
content="",
tool_calls=[
{
"id": tc_id,
"type": "function",
"function": {"name": name, "arguments": json.dumps(args)},
}
],
)
class TestFindCompletedToolCall:
def test_returns_match_when_prior_non_error_result_exists(self):
conv = NodeConversation(system_prompt="s")
conv._messages = [
Message(seq=0, role="user", content="go"),
_mk_assistant_with_tool_call(1, "tc_a", "browser_setup", {}),
Message(seq=2, role="tool", content="ok", tool_use_id="tc_a"),
]
match = conv.find_completed_tool_call("browser_setup", {})
assert match is not None
assert match.seq == 1
def test_ignores_error_result(self):
conv = NodeConversation(system_prompt="s")
conv._messages = [
Message(seq=0, role="user", content="go"),
_mk_assistant_with_tool_call(1, "tc_a", "browser_navigate", {"url": "x"}),
Message(seq=2, role="tool", content="boom", tool_use_id="tc_a", is_error=True),
]
assert conv.find_completed_tool_call("browser_navigate", {"url": "x"}) is None
def test_canonicalizes_json_args_regardless_of_key_order(self):
conv = NodeConversation(system_prompt="s")
# Prior args written in one order, new call re-emits in different order.
conv._messages = [
Message(seq=0, role="user", content="go"),
_mk_assistant_with_tool_call(1, "tc_a", "fetch", {"b": 2, "a": 1}),
Message(seq=2, role="tool", content="ok", tool_use_id="tc_a"),
]
assert conv.find_completed_tool_call("fetch", {"a": 1, "b": 2}) is not None
# Different args should NOT match.
assert conv.find_completed_tool_call("fetch", {"a": 1, "b": 3}) is None
def test_respects_within_last_turns_window(self):
conv = NodeConversation(system_prompt="s")
# Prior successful call, then 4 newer assistant turns of noise.
conv._messages = [
Message(seq=0, role="user", content="go"),
_mk_assistant_with_tool_call(1, "tc_a", "browser_setup", {}),
Message(seq=2, role="tool", content="ok", tool_use_id="tc_a"),
]
# 4 newer assistant turns (no tool calls that match)
for i in range(3, 7):
conv._messages.append(
Message(seq=i, role="assistant", content=f"noise {i}")
)
# Window=3 → prior assistant with browser_setup is at turn index 5
# backwards (noise, noise, noise, noise, setup) — skipped.
assert (
conv.find_completed_tool_call("browser_setup", {}, within_last_turns=3)
is None
)
# Window=10 → found.
assert (
conv.find_completed_tool_call("browser_setup", {}, within_last_turns=10)
is not None
)
class TestPartialCheckpoint:
@pytest.mark.asyncio
async def test_checkpoint_is_cleared_when_real_part_lands(self, tmp_path):
"""A partial for seq N is wiped once add_assistant_message(seq=N) persists."""
store = FileConversationStore(tmp_path / "c")
conv = NodeConversation(system_prompt="s", store=store)
await conv.add_user_message("hi")
# Seed a partial for the would-be next assistant seq.
await conv.checkpoint_partial_assistant("half-written...")
partials = await store.read_all_partials()
assert len(partials) == 1
assert partials[0]["content"] == "half-written..."
# Commit the real assistant turn — partial should be swept.
await conv.add_assistant_message("fully written")
assert await store.read_all_partials() == []
@pytest.mark.asyncio
async def test_restore_surfaces_partial_as_truncated_message(self, tmp_path):
"""A partial left behind by a crashed stream is resurrected on restore."""
store = FileConversationStore(tmp_path / "c")
conv = NodeConversation(system_prompt="s", store=store)
await conv.add_user_message("hi")
# Simulate a stream that produced some text + a tool call, then died
# before finishing. The checkpoint captures both.
await conv.checkpoint_partial_assistant(
"I was working on this when the stream died",
tool_calls=[
{
"id": "tc_x",
"type": "function",
"function": {"name": "browser_click", "arguments": "{}"},
}
],
)
# Fresh process — restore from disk.
fresh = await NodeConversation.restore(store)
assert fresh is not None
# The user message is there, plus the truncated assistant resurrected
# from the partial.
roles = [m.role for m in fresh.messages]
assert roles == ["user", "assistant"]
last = fresh.messages[-1]
assert last.truncated is True
assert last.content == "I was working on this when the stream died"
assert last.tool_calls and last.tool_calls[0]["function"]["name"] == "browser_click"
@pytest.mark.asyncio
async def test_restore_cleans_stale_partials(self, tmp_path):
"""A partial whose seq was already committed as a real part is discarded."""
store = FileConversationStore(tmp_path / "c")
conv = NodeConversation(system_prompt="s", store=store)
await conv.add_user_message("hi")
await conv.add_assistant_message("real") # seq=1
# Manually plant a stale partial at seq=1 (already committed).
await store.write_partial(
1, {"seq": 1, "role": "assistant", "content": "stale", "truncated": True}
)
fresh = await NodeConversation.restore(store)
assert fresh is not None
assert [m.content for m in fresh.messages] == ["hi", "real"]
# Stale partial swept by restore.
assert await store.read_all_partials() == []
class TestMessageFlags:
def test_is_system_nudge_roundtrip(self):
m = Message(seq=0, role="user", content="nudge", is_system_nudge=True)
d = m.to_storage_dict()
assert d.get("is_system_nudge") is True
r = Message.from_storage_dict(d)
assert r.is_system_nudge is True
assert r.role == "user"
def test_truncated_roundtrip(self):
m = Message(seq=0, role="assistant", content="half", truncated=True)
d = m.to_storage_dict()
assert d.get("truncated") is True
r = Message.from_storage_dict(d)
assert r.truncated is True
def test_defaults_omit_flags_from_storage(self):
m = Message(seq=0, role="user", content="plain")
d = m.to_storage_dict()
assert "is_system_nudge" not in d
assert "truncated" not in d