fix: partial parts and system nudge
@@ -55,7 +55,9 @@
      "mcp__gcu-tools__browser_click_coordinate",
      "mcp__gcu-tools__browser_get_rect",
      "mcp__gcu-tools__browser_type_focused",
      "mcp__gcu-tools__browser_wait"
      "mcp__gcu-tools__browser_wait",
      "Bash(python3 -c ' *)",
      "Bash(python3 scripts/debug_queen_prompt.py independent)"
    ],
    "additionalDirectories": [
      "/home/timothy/.hive/skills/writing-hive-skills",

@@ -2335,6 +2335,11 @@ class AgentLoop(AgentProtocol):
            execution_id,
        )

        # Continue-nudge counter: how many times we've re-streamed within this
        # _run_single_turn because the idle/TTFT watchdog fired. Caps to avoid
        # nudging forever when the endpoint is genuinely dead.
        _nudge_count_this_turn = 0

        # Inner tool loop: stream may produce tool calls requiring re-invocation
        while True:
            # Pre-send guard: if context is at or over budget, compact before
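
The counter above caps a stream / watchdog / nudge / restart cycle. A minimal sketch of that control flow, with `run_stream` and `append_nudge` as illustrative stand-ins for the real loop machinery (not the actual API):

```python
# Sketch: capped nudge-and-restart control flow.
async def turn_with_nudges(run_stream, append_nudge, cap: int = 3) -> str:
    nudges = 0
    while True:
        verdict = await run_stream()  # "ok", "ttft", or "inactive"
        if verdict == "ok":
            return verdict
        if nudges >= cap:
            # Cap exhausted — surface the stall as a hard error.
            raise ConnectionError(f"stream {verdict}, nudge cap reached")
        nudges += 1
        await append_nudge(verdict)  # terse continuation hint
        # loop around: restart the stream with the nudge in context
```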

@@ -2423,7 +2428,16 @@ class AgentLoop(AgentProtocol):
            # Capture loop-scoped variables as defaults to satisfy B023.
            # _stream_last_event_at is bumped on every event; the watchdog
            # below uses it to detect silently hung HTTP connections.
            _stream_last_event_at = time.monotonic()
            _stream_start_at = time.monotonic()
            _stream_last_event_at = _stream_start_at
            # None until the first event arrives. Before first event, the
            # watchdog uses the (much looser) TTFT budget — large-context
            # local models legitimately take minutes to first token. Once
            # any event has been observed, tight inter-event idle applies.
            _first_event_at: float | None = None
            # Partial tool_calls accumulated so far, as OpenAI-format dicts
            # ready for persistence if the stream is cut short.
            _partial_tc_dicts: list[dict[str, Any]] = []

            async def _do_stream(
                _msgs: list = messages,  # noqa: B006

@@ -2432,8 +2446,10 @@ class AgentLoop(AgentProtocol):
                _safe_names: set = _early_safe_names,  # noqa: B006,B008
                _tasks: dict = _early_tasks,  # noqa: B006,B008
                _exec_fn=_timed_execute,
                _partial_dicts: list[dict[str, Any]] = _partial_tc_dicts,  # noqa: B006,B008
            ) -> None:
                nonlocal accumulated_text, _stream_error, _stream_last_event_at
                nonlocal _first_event_at
                _clean_snapshot = ""  # visible-only text for the frontend

                async for event in ctx.llm.stream(

@@ -2443,6 +2459,8 @@ class AgentLoop(AgentProtocol):
                    max_tokens=ctx.max_tokens,
                ):
                    _stream_last_event_at = time.monotonic()
                    if _first_event_at is None:
                        _first_event_at = _stream_last_event_at
                    if isinstance(event, TextDeltaEvent):
                        accumulated_text = event.snapshot
                        # Strip internal reasoning tags from the full

@@ -2462,9 +2480,46 @@ class AgentLoop(AgentProtocol):
                            iteration=iteration,
                            inner_turn=inner_turn,
                        )
                        # Checkpoint partial state so a watchdog cancel or
                        # crash doesn't discard whatever the model has
                        # produced so far. Cheap — one atomic file write.
                        try:
                            await conversation.checkpoint_partial_assistant(
                                accumulated_text,
                                _partial_dicts or None,
                            )
                        except Exception as _cp_err:  # noqa: BLE001
                            logger.debug(
                                "[_run_single_turn] partial checkpoint failed: %s",
                                _cp_err,
                            )

                    elif isinstance(event, ToolCallEvent):
                        _tc.append(event)
                        _partial_dicts.append(
                            {
                                "id": event.tool_use_id,
                                "type": "function",
                                "function": {
                                    "name": event.tool_name,
                                    "arguments": json.dumps(event.tool_input),
                                },
                            }
                        )
                        # Checkpoint now that a tool call has landed —
                        # this is the important one: if the stream dies
                        # right after a tool call but before FinishEvent,
                        # we still have the intent recorded.
                        try:
                            await conversation.checkpoint_partial_assistant(
                                accumulated_text,
                                _partial_dicts or None,
                            )
                        except Exception as _cp_err:  # noqa: BLE001
                            logger.debug(
                                "[_run_single_turn] partial checkpoint failed: %s",
                                _cp_err,
                            )
                        # Gap 1: start concurrency-safe tools immediately
                        # while the rest of the stream is still arriving,
                        # so read-heavy turns don't stall after the last
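
Each `ToolCallEvent` is frozen into an OpenAI-format tool-call dict so the checkpoint can later be replayed as a normal assistant message. The conversion in isolation, assuming only the event fields the hunk above actually uses (`tool_use_id`, `tool_name`, `tool_input`):

```python
import json

# How a streamed tool-call event becomes an OpenAI-format dict ready
# for persistence (mirrors the append in the hunk above).
def to_openai_tool_call(tool_use_id: str, tool_name: str, tool_input: dict) -> dict:
    return {
        "id": tool_use_id,
        "type": "function",
        "function": {
            "name": tool_name,
            # Arguments are stored as a JSON string, per the OpenAI schema.
            "arguments": json.dumps(tool_input),
        },
    }
```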

@@ -2492,55 +2547,99 @@ class AgentLoop(AgentProtocol):
            _llm_stream_t0 = time.monotonic()
            self._stream_task = asyncio.create_task(_do_stream())
            logger.debug("[_run_single_turn] inner_turn=%d: Stream task created, waiting...", inner_turn)
            _inactivity_limit = self._config.llm_stream_inactivity_timeout_seconds

            # Watchdog budgets — see LoopConfig docstring for rationale.
            _ttft_limit = self._config.llm_stream_ttft_timeout_seconds
            _inter_event_limit = self._config.llm_stream_inter_event_idle_seconds
            # Back-compat: if the legacy inactivity knob was overridden to
            # a value below the new default, respect it as the inter-event
            # budget (historic behaviour) so existing configs don't regress.
            _legacy = self._config.llm_stream_inactivity_timeout_seconds
            if _legacy and _legacy > 0 and _legacy < _inter_event_limit:
                _inter_event_limit = _legacy
            _watchdog_active = (_ttft_limit and _ttft_limit > 0) or (
                _inter_event_limit and _inter_event_limit > 0
            )
            # Result of the watchdog: "ok" (stream finished), "ttft" (no first
            # event in budget), "inactive" (silence after first event).
            _watchdog_verdict: str = "ok"
            _watchdog_elapsed: float = 0.0
            _watchdog_limit: float = 0.0

            try:
                if _inactivity_limit and _inactivity_limit > 0:
                    # Heartbeat-aware wait: poll the task and cancel it if
                    # no stream event has been observed within the window.
                    # A silently dead HTTP connection otherwise hangs here
                    # forever — no exception, no delta, no timeout.
                    #
                    # Must use asyncio.wait (not wait_for) so we can tell
                    # "poll interval elapsed" apart from "task raised a
                    # TimeoutError of its own" — wait_for conflates them.
                    _check_interval = min(5.0, _inactivity_limit / 2)
                if _watchdog_active:
                    # Poll cheapest-valid interval: at most every 5s, at least
                    # half the tighter budget. Must use asyncio.wait (not
                    # wait_for) so "poll interval elapsed" and "task raised
                    # TimeoutError of its own" stay distinguishable.
                    _tight = min(
                        _ttft_limit or float("inf"),
                        _inter_event_limit or float("inf"),
                    )
                    _check_interval = max(1.0, min(5.0, _tight / 2))
                    while True:
                        done, _pending = await asyncio.wait({self._stream_task}, timeout=_check_interval)
                        done, _pending = await asyncio.wait(
                            {self._stream_task}, timeout=_check_interval
                        )
                        if self._stream_task in done:
                            # Let any exception the task raised propagate
                            # naturally via the outer ``await`` below.
                            break
                        idle = time.monotonic() - _stream_last_event_at
                        if idle >= _inactivity_limit:
                            logger.warning(
                                "[_run_single_turn] inner_turn=%d: "
                                "stream inactivity %.0fs >= %.0fs — "
                                "cancelling stream task",
                                inner_turn,
                                idle,
                                _inactivity_limit,
                            )
                            self._bump("stream_inactivity_watchdog")
                            self._stream_task.cancel()
                            try:
                                await self._stream_task
                            except BaseException:
                                pass
                            raise ConnectionError(
                                f"LLM stream idle for {idle:.0f}s "
                                f"(inactivity limit {_inactivity_limit:.0f}s) — "
                                "connection presumed dead"
                            ) from None
                        now = time.monotonic()
                        if _first_event_at is None:
                            # TTFT phase — stream open but silent. Use the
                            # looser budget; don't confuse slow models with
                            # dead connections.
                            elapsed = now - _stream_start_at
                            if _ttft_limit and _ttft_limit > 0 and elapsed >= _ttft_limit:
                                _watchdog_verdict = "ttft"
                                _watchdog_elapsed = elapsed
                                _watchdog_limit = _ttft_limit
                                break
                        else:
                            # Post-first-event silence. A stream that produced
                            # events and then went quiet is a real stall.
                            idle = now - _stream_last_event_at
                            if (
                                _inter_event_limit
                                and _inter_event_limit > 0
                                and idle >= _inter_event_limit
                            ):
                                _watchdog_verdict = "inactive"
                                _watchdog_elapsed = idle
                                _watchdog_limit = _inter_event_limit
                                break
                        # Still active — keep polling.
                # Re-raise any exception the stream task stored. When the
                # watchdog loop exited via ``break`` the task is done, and
                # ``await`` is the cheapest way to surface its exception.
                await self._stream_task
                logger.debug("[_run_single_turn] inner_turn=%d: Stream task completed normally", inner_turn)

                if _watchdog_verdict != "ok":
                    logger.warning(
                        "[_run_single_turn] inner_turn=%d: watchdog=%s %.0fs >= %.0fs — cancelling stream",
                        inner_turn,
                        _watchdog_verdict,
                        _watchdog_elapsed,
                        _watchdog_limit,
                    )
                    self._bump(f"stream_watchdog_{_watchdog_verdict}")
                    self._stream_task.cancel()
                    try:
                        await self._stream_task
                    except BaseException:
                        pass
                else:
                    # Re-raise any exception the stream task stored. When the
                    # watchdog loop exited via ``break`` the task is done, and
                    # ``await`` is the cheapest way to surface its exception.
                    await self._stream_task
                    logger.debug(
                        "[_run_single_turn] inner_turn=%d: Stream task completed normally",
                        inner_turn,
                    )
            except asyncio.CancelledError:
                logger.debug("[_run_single_turn] inner_turn=%d: Stream task cancelled", inner_turn)
                if accumulated_text:
                    await conversation.add_assistant_message(content=accumulated_text)
                if accumulated_text or _partial_tc_dicts:
                    await conversation.add_assistant_message(
                        content=accumulated_text,
                        tool_calls=_partial_tc_dicts or None,
                        truncated=True,
                    )
                # Gap 1: kill any early-dispatched tool tasks too.
                # Without this, a safe tool started during streaming
                # would leak past cancellation and keep running.

@@ -2568,6 +2667,100 @@ class AgentLoop(AgentProtocol):
                raise
            finally:
                self._stream_task = None

            # Continue-nudge recovery path. Runs AFTER the stream task is
            # cleaned up so all state is consistent. We persist whatever
            # partial text + tool-calls the model produced (as a truncated
            # message so the model can see its own in-flight work on the
            # next turn), cancel early tool tasks, append a terse
            # continuation hint, and restart the stream.
            if _watchdog_verdict != "ok":
                # Kill any safe-tool tasks the stream dispatched early —
                # their results would have had nowhere to land anyway
                # because the assistant message was incomplete.
                for _early in _early_tasks.values():
                    if not _early.done():
                        _early.cancel()
                # Promote whatever we captured into a real truncated
                # message. The partial checkpoint for this seq is cleared
                # automatically when add_assistant_message persists.
                if accumulated_text or _partial_tc_dicts:
                    await conversation.add_assistant_message(
                        content=accumulated_text,
                        tool_calls=_partial_tc_dicts or None,
                        truncated=True,
                    )

                reason_label = (
                    "no tokens before TTFT budget"
                    if _watchdog_verdict == "ttft"
                    else "stream went silent after producing events"
                )
                if self._event_bus:
                    if _watchdog_verdict == "ttft":
                        await self._event_bus.emit_stream_ttft_exceeded(
                            stream_id=stream_id,
                            node_id=node_id,
                            ttft_seconds=_watchdog_elapsed,
                            limit_seconds=_watchdog_limit,
                            execution_id=execution_id,
                        )
                    else:
                        await self._event_bus.emit_stream_inactive(
                            stream_id=stream_id,
                            node_id=node_id,
                            idle_seconds=_watchdog_elapsed,
                            limit_seconds=_watchdog_limit,
                            execution_id=execution_id,
                        )

                nudge_enabled = self._config.continue_nudge_enabled
                nudge_cap = self._config.continue_nudge_max_per_turn
                if nudge_enabled and _nudge_count_this_turn < nudge_cap:
                    _nudge_count_this_turn += 1
                    nudge_msg = (
                        f"[System: the previous stream stalled ({reason_label}, "
                        f"{_watchdog_elapsed:.0f}s). Continue from the last tool "
                        "result already in this conversation. Do NOT repeat tool "
                        "calls whose results are visible above — reuse them and "
                        "move to the next step.]"
                    )
                    await conversation.add_user_message(
                        nudge_msg,
                        is_system_nudge=True,
                    )
                    if self._event_bus:
                        await self._event_bus.emit_stream_nudge_sent(
                            stream_id=stream_id,
                            node_id=node_id,
                            reason=_watchdog_verdict,
                            nudge_count=_nudge_count_this_turn,
                            execution_id=execution_id,
                        )
                    logger.info(
                        "[%s] continue-nudge sent (count=%d/%d, reason=%s)",
                        node_id,
                        _nudge_count_this_turn,
                        nudge_cap,
                        _watchdog_verdict,
                    )
                    # Reset the outer _turn_t0 timer so the "LLM done in
                    # Xms" log line reflects real work not the nudge cycle.
                    _llm_stream_ms = int((time.monotonic() - _llm_stream_t0) * 1000)
                    logger.debug(
                        "[_run_single_turn] inner_turn=%d: nudge restart after %dms",
                        inner_turn,
                        _llm_stream_ms,
                    )
                    continue  # restart the inner loop, re-fetches messages
                # Nudge disabled or cap exhausted — fall back to the
                # existing retry path so a truly dead endpoint eventually
                # surfaces as an error.
                raise ConnectionError(
                    f"LLM stream {_watchdog_verdict} for {_watchdog_elapsed:.0f}s "
                    f"(limit {_watchdog_limit:.0f}s) — nudge cap reached"
                )

            _llm_stream_ms = int((time.monotonic() - _llm_stream_t0) * 1000)

            # If a recoverable stream error produced an empty response,
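
The watchdog's core decision — a loose TTFT budget before the first event, a tight inter-event budget after it — distilled into a standalone sketch. The `state` dict stands in for the closure variables `_first_event_at` / `_stream_last_event_at` that the stream task bumps:

```python
import asyncio
import time

# Sketch: two-budget stream watchdog. Returns "ok", "ttft", or "inactive".
async def watch(task: asyncio.Task, state: dict, ttft: float, idle: float) -> str:
    start = time.monotonic()
    # Poll at most every 5s, at least half the tighter budget.
    interval = max(1.0, min(5.0, min(ttft, idle) / 2))
    while True:
        done, _pending = await asyncio.wait({task}, timeout=interval)
        if task in done:
            return "ok"  # stream finished within budget
        now = time.monotonic()
        if state.get("first_event_at") is None:
            # TTFT phase: stream open but silent — apply the loose budget.
            if now - start >= ttft:
                return "ttft"
        elif now - state["last_event_at"] >= idle:
            # Produced events, then went quiet: a real stall.
            return "inactive"
```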

@@ -2667,6 +2860,12 @@ class AgentLoop(AgentProtocol):
            results_by_id: dict[str, ToolResult] = {}
            timing_by_id: dict[str, dict[str, Any]] = {}  # tool_use_id -> {start_timestamp, duration_s}
            pending_real: list[ToolCallEvent] = []
            # Replay detector: per-turn map from tool_use_id -> steer prefix.
            # Populated below when we detect that the model is re-emitting a
            # tool call whose (name + canonical args) matches a prior success.
            # Applied to the stored tool result content so the model sees the
            # nudge on its next turn without losing the real execution output.
            replay_prefixes_by_id: dict[str, str] = {}

            for tc in tool_calls:
                tool_call_count += 1

@@ -2939,6 +3138,39 @@ class AgentLoop(AgentProtocol):
                    )
                    results_by_id[tc.tool_use_id] = result
                else:
                    # Replay detector: flag re-executions of recent
                    # successful calls. We still run the tool (some
                    # are legitimately repeated, e.g. screenshots and
                    # read-only evaluates) but prepend a terse steer
                    # onto the stored result so the model sees the
                    # signal on its next turn.
                    if self._config.replay_detector_enabled:
                        prior = conversation.find_completed_tool_call(
                            tc.tool_name,
                            tc.tool_input,
                            within_last_turns=self._config.replay_detector_within_last_turns,
                        )
                        if prior is not None:
                            logger.warning(
                                "[%s] replay detected: %s matches prior seq=%d — executing anyway",
                                node_id,
                                tc.tool_name,
                                prior.seq,
                            )
                            self._bump("tool_call_replay_detected")
                            if self._event_bus:
                                await self._event_bus.emit_tool_call_replay_detected(
                                    stream_id=stream_id,
                                    node_id=node_id,
                                    tool_name=tc.tool_name,
                                    prior_seq=prior.seq,
                                    execution_id=execution_id,
                                )
                            replay_prefixes_by_id[tc.tool_use_id] = (
                                f"[Replay detected: {tc.tool_name} matches "
                                f"seq={prior.seq}. Result still produced below — "
                                "consider whether the retry was necessary.]\n"
                            )
                    pending_real.append(tc)

            # Phase 2a: partition real tools by concurrency safety.
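
The replay detector's match criterion is tool name plus canonical-JSON arguments, so argument key order cannot defeat the match. The canonicalization in isolation (a helper written for illustration; the real logic lives in `find_completed_tool_call` below):

```python
import json

# Sketch: the "same call" key used by the replay detector.
def replay_key(name: str, tool_input: dict) -> str:
    try:
        canonical = json.dumps(tool_input, sort_keys=True, default=str)
    except (TypeError, ValueError):
        # Non-JSON-serializable args fall back to repr-style matching.
        canonical = str(tool_input)
    return f"{name}:{canonical}"

# Key order does not matter; values do.
assert replay_key("fetch", {"b": 2, "a": 1}) == replay_key("fetch", {"a": 1, "b": 2})
assert replay_key("fetch", {"a": 1, "b": 3}) != replay_key("fetch", {"a": 1, "b": 2})
```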

@@ -3136,9 +3368,18 @@ class AgentLoop(AgentProtocol):
                )
                image_content = None

                # Apply replay-detector steer prefix if this call matched a
                # recent successful invocation. Only applies to non-error
                # results — an error already breaks the replay chain.
                stored_content = result.content
                if not result.is_error:
                    _prefix = replay_prefixes_by_id.get(tc.tool_use_id)
                    if _prefix:
                        stored_content = f"{_prefix}{stored_content or ''}"

                await conversation.add_tool_result(
                    tool_use_id=tc.tool_use_id,
                    content=result.content,
                    content=stored_content,
                    is_error=result.is_error,
                    image_content=image_content,
                    is_skill_content=result.is_skill_content,

@@ -48,6 +48,14 @@ class Message:
    is_skill_content: bool = False
    # Logical worker run identifier for shared-session persistence
    run_id: str | None = None
    # True when this is a framework-injected continuation hint (continue-nudge
    # on stream stall). Stored as a user message for API compatibility, but
    # the UI should render it as a compact system notice, not user speech.
    is_system_nudge: bool = False
    # True when this message is a partial/truncated assistant turn reconstructed
    # from a crashed or watchdog-cancelled stream. Signals that the original
    # turn never finished — the model may or may not choose to redo it.
    truncated: bool = False

    def to_llm_dict(self) -> dict[str, Any]:
        """Convert to OpenAI-format message dict."""

@@ -109,6 +117,10 @@ class Message:
            d["image_content"] = self.image_content
        if self.run_id is not None:
            d["run_id"] = self.run_id
        if self.is_system_nudge:
            d["is_system_nudge"] = self.is_system_nudge
        if self.truncated:
            d["truncated"] = self.truncated
        return d

    @classmethod

@@ -126,6 +138,8 @@ class Message:
            is_client_input=data.get("is_client_input", False),
            image_content=data.get("image_content"),
            run_id=data.get("run_id"),
            is_system_nudge=data.get("is_system_nudge", False),
            truncated=data.get("truncated", False),
        )
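
The flag handling above follows a sparse-serialization pattern: default-valued flags are omitted on write and defaulted on read, so records written before this commit stay loadable and new flags cost nothing until set. A simplified round-trip sketch with the fields trimmed to the essentials:

```python
from dataclasses import dataclass
from typing import Any

# Sketch: sparse flag serialization as used by Message (simplified).
@dataclass
class Msg:
    seq: int
    role: str
    content: str
    is_system_nudge: bool = False
    truncated: bool = False

    def to_dict(self) -> dict[str, Any]:
        d: dict[str, Any] = {"seq": self.seq, "role": self.role, "content": self.content}
        if self.is_system_nudge:  # written only when non-default
            d["is_system_nudge"] = True
        if self.truncated:
            d["truncated"] = True
        return d

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "Msg":
        return cls(
            seq=d["seq"],
            role=d["role"],
            content=d["content"],
            is_system_nudge=d.get("is_system_nudge", False),  # old records: default
            truncated=d.get("truncated", False),
        )
```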

@@ -317,6 +331,14 @@ class ConversationStore(Protocol):

    async def delete_parts_before(self, seq: int, run_id: str | None = None) -> None: ...

    async def write_partial(self, seq: int, data: dict[str, Any]) -> None: ...

    async def read_partial(self, seq: int) -> dict[str, Any] | None: ...

    async def read_all_partials(self) -> list[dict[str, Any]]: ...

    async def clear_partial(self, seq: int) -> None: ...

    async def close(self) -> None: ...

    async def destroy(self) -> None: ...
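
An in-memory implementation of just the partial-checkpoint surface makes the expected semantics concrete — overwrite on write, seq-sorted listing, idempotent clear. The test suite's mock store, further down in this commit, behaves the same way:

```python
from typing import Any

# Sketch: minimal in-memory partial-checkpoint store satisfying the
# four partial methods of the protocol above.
class InMemoryPartials:
    def __init__(self) -> None:
        self._partials: dict[int, dict[str, Any]] = {}

    async def write_partial(self, seq: int, data: dict[str, Any]) -> None:
        self._partials[seq] = data  # overwrite: last checkpoint wins

    async def read_partial(self, seq: int) -> dict[str, Any] | None:
        return self._partials.get(seq)

    async def read_all_partials(self) -> list[dict[str, Any]]:
        return [self._partials[k] for k in sorted(self._partials)]

    async def clear_partial(self, seq: int) -> None:
        self._partials.pop(seq, None)  # idempotent: missing seq is fine
```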

@@ -462,6 +484,7 @@ class NodeConversation:
        is_transition_marker: bool = False,
        is_client_input: bool = False,
        image_content: list[dict[str, Any]] | None = None,
        is_system_nudge: bool = False,
    ) -> Message:
        msg = Message(
            seq=self._next_seq,

@@ -472,6 +495,7 @@ class NodeConversation:
            is_transition_marker=is_transition_marker,
            is_client_input=is_client_input,
            image_content=image_content,
            is_system_nudge=is_system_nudge,
        )
        self._messages.append(msg)
        self._next_seq += 1

@@ -485,6 +509,8 @@ class NodeConversation:
        self,
        content: str,
        tool_calls: list[dict[str, Any]] | None = None,
        *,
        truncated: bool = False,
    ) -> Message:
        msg = Message(
            seq=self._next_seq,

@@ -493,6 +519,7 @@ class NodeConversation:
            tool_calls=tool_calls,
            phase_id=self._current_phase,
            run_id=self._run_id,
            truncated=truncated,
        )
        self._messages.append(msg)
        self._next_seq += 1
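
How the new keyword-only `truncated` flag is used by the recovery path in the agent-loop hunk earlier in this commit — whatever partial text and tool calls arrived are committed as a single truncated assistant message:

```python
# Usage sketch: persisting an interrupted turn (mirrors the recovery
# path above; `conversation` is a NodeConversation).
async def persist_interrupted_turn(conversation, accumulated_text, partial_tc_dicts):
    if accumulated_text or partial_tc_dicts:
        await conversation.add_assistant_message(
            content=accumulated_text,
            tool_calls=partial_tc_dicts or None,
            truncated=True,
        )
```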

@@ -548,6 +575,59 @@ class NodeConversation:

    # --- Query -------------------------------------------------------------

    def find_completed_tool_call(
        self,
        name: str,
        tool_input: dict[str, Any],
        within_last_turns: int = 3,
    ) -> Message | None:
        """Return the most recent assistant message that issued a tool call
        with the same (name + canonical-json args) AND received a non-error
        tool result, within the last ``within_last_turns`` assistant turns.

        Used by the replay detector to flag when the model is about to redo
        a successful call — we prepend a steer onto the upcoming result but
        still execute, so tools like browser_screenshot that are legitimately
        repeated are not silently skipped.
        """
        try:
            target_canonical = json.dumps(tool_input, sort_keys=True, default=str)
        except (TypeError, ValueError):
            target_canonical = str(tool_input)

        # Walk backwards over recent assistant messages
        assistant_turns_seen = 0
        for idx in range(len(self._messages) - 1, -1, -1):
            m = self._messages[idx]
            if m.role != "assistant":
                continue
            assistant_turns_seen += 1
            if assistant_turns_seen > within_last_turns:
                break
            if not m.tool_calls:
                continue
            for tc in m.tool_calls:
                func = tc.get("function", {}) if isinstance(tc, dict) else {}
                tc_name = func.get("name")
                if tc_name != name:
                    continue
                args_str = func.get("arguments", "")
                try:
                    parsed = json.loads(args_str) if isinstance(args_str, str) else args_str
                    canonical = json.dumps(parsed, sort_keys=True, default=str)
                except (TypeError, ValueError):
                    canonical = str(args_str)
                if canonical != target_canonical:
                    continue
                # Found a match — now verify its result was not an error.
                tc_id = tc.get("id")
                for later in self._messages[idx + 1 :]:
                    if later.role == "tool" and later.tool_use_id == tc_id:
                        if not later.is_error:
                            return m
                        break
        return None

    def to_llm_messages(self) -> list[dict[str, Any]]:
        """Return messages as OpenAI-format dicts (system prompt excluded).

@@ -1365,6 +1445,45 @@ class NodeConversation:
        await self._persist_meta()
        await self._store.write_part(message.seq, message.to_storage_dict())
        await self._write_next_seq()
        # Any partial checkpoint for this seq is now superseded by the real
        # part — clear it so a future restore doesn't resurrect stale text.
        try:
            await self._store.clear_partial(message.seq)
        except AttributeError:
            # Older stores may not implement partials; ignore.
            pass

    async def checkpoint_partial_assistant(
        self,
        accumulated_text: str,
        tool_calls: list[dict[str, Any]] | None = None,
    ) -> None:
        """Write an in-flight assistant turn's state to disk under the next seq.

        Called from the stream event loop. Safe to call repeatedly — each call
        overwrites the prior checkpoint. Persisted via ``write_partial`` so it
        does NOT appear in ``read_parts()`` and cannot be double-loaded. Cleared
        automatically when ``add_assistant_message`` for this seq lands.
        """
        if self._store is None:
            return
        if not self._meta_persisted:
            await self._persist_meta()
        payload: dict[str, Any] = {
            "seq": self._next_seq,
            "role": "assistant",
            "content": accumulated_text,
            "phase_id": self._current_phase,
            "run_id": self._run_id,
            "truncated": True,
        }
        if tool_calls:
            payload["tool_calls"] = tool_calls
        try:
            await self._store.write_partial(self._next_seq, payload)
        except AttributeError:
            # Older stores may not implement partials; ignore.
            pass

    async def _persist_meta(self) -> None:
        """Lazily write conversation metadata to the store (called once).

@@ -1461,4 +1580,45 @@ class NodeConversation:
        elif conv._messages:
            conv._next_seq = conv._messages[-1].seq + 1

        # Surface any leftover partial checkpoints as truncated messages so
        # the next turn sees what the interrupted stream was in the middle
        # of producing. Only partials whose seq is >= next_seq are meaningful;
        # anything lower was already superseded by a real part.
        try:
            partials = await store.read_all_partials()
        except AttributeError:
            partials = []
        for p in partials:
            pseq = p.get("seq", -1)
            if pseq < conv._next_seq:
                # Stale — clean it up.
                try:
                    await store.clear_partial(pseq)
                except AttributeError:
                    pass
                continue
            # Only resurrect partials relevant to this run / phase.
            if run_id and not is_legacy_run_id(run_id) and p.get("run_id") != run_id:
                continue
            if phase_id and p.get("phase_id") is not None and p.get("phase_id") != phase_id:
                continue
            # Reconstruct as a truncated assistant message.
            msg = Message(
                seq=pseq,
                role="assistant",
                content=p.get("content", "") or "",
                tool_calls=p.get("tool_calls"),
                phase_id=p.get("phase_id"),
                run_id=p.get("run_id"),
                truncated=True,
            )
            conv._messages.append(msg)
            conv._next_seq = max(conv._next_seq, pseq + 1)
            logger.info(
                "restore: resurrected truncated partial seq=%d (text=%d chars, tool_calls=%d)",
                pseq,
                len(msg.content),
                len(msg.tool_calls or []),
            )

        return conv

@@ -131,14 +131,39 @@ class LoopConfig:
    # Per-tool-call timeout.
    tool_call_timeout_seconds: float = 60.0

    # LLM stream inactivity watchdog. If no stream event (delta, tool call,
    # finish) arrives within this many seconds, the stream task is cancelled
    # and a transient error is raised so the retry loop can back off and
    # reconnect. Prevents agents from hanging forever on a silently dead
    # HTTP connection (no provider heartbeat, no exception, just silence).
    # Set to 0 to disable.
    # LLM stream inactivity watchdog. Split into two budgets so legitimate
    # slow TTFT on large contexts doesn't get mistaken for a dead connection.
    #   - ttft: stream open -> first event. Large-context local models can
    #     legitimately take minutes before the first token arrives.
    #   - inter_event: last event -> now, ONLY after the first event. A stream
    #     that started producing and then went silent is a real stall.
    # Whichever fires first cancels the stream. Set to 0 to disable that
    # individual budget; set both to 0 to fully disable the watchdog.
    llm_stream_ttft_timeout_seconds: float = 600.0
    llm_stream_inter_event_idle_seconds: float = 120.0
    # Deprecated alias — kept so existing configs keep working. If set to a
    # non-default value it overrides inter_event_idle (historical behavior).
    llm_stream_inactivity_timeout_seconds: float = 120.0

    # Continue-nudge recovery. When the idle watchdog fires on a live but
    # stuck stream, cancel the stream and append a short continuation
    # hint to the conversation instead of raising a ConnectionError and
    # re-running the whole turn. Preserves any partial text/tool-calls the
    # stream emitted before the stall.
    continue_nudge_enabled: bool = True
    # Cap so a truly dead endpoint eventually falls back to the error path
    # instead of nudging forever.
    continue_nudge_max_per_turn: int = 3

    # Tool-call replay detector. When the model emits a tool call whose
    # (name + canonical-args) matches a prior successful call in the last
    # K assistant turns, emit telemetry and prepend a short steer onto the
    # tool result — but still execute. Weaker models legitimately repeat
    # read-only calls (screenshot, evaluate), so silent skipping would
    # cause surprising behavior.
    replay_detector_enabled: bool = True
    replay_detector_within_last_turns: int = 3

    # Subagent delegation timeout (wall-clock max).
    subagent_timeout_seconds: float = 3600.0
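
Putting the three watchdog knobs together: the effective budgets the loop actually runs with, including the deprecated-alias override rule from the watchdog hunk above, reduce to a small pure function (`cfg` is a LoopConfig):

```python
# Sketch: effective (ttft, inter_event) budgets from a LoopConfig,
# mirroring the back-compat logic in the agent loop.
def effective_budgets(cfg) -> tuple[float, float]:
    ttft = cfg.llm_stream_ttft_timeout_seconds
    inter_event = cfg.llm_stream_inter_event_idle_seconds
    legacy = cfg.llm_stream_inactivity_timeout_seconds
    # A legacy override tighter than the new default wins (historic behaviour).
    if legacy and 0 < legacy < inter_event:
        inter_event = legacy
    return ttft, inter_event
```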

@@ -111,6 +111,15 @@ class EventType(StrEnum):
    # Retry tracking
    NODE_RETRY = "node_retry"

    # Stream-health observability. Split from NODE_RETRY so the UI can
    # distinguish "slow TTFT on a huge context" (healthy, just slow) from
    # "stream went silent mid-generation" (probable stall) from "we nudged
    # the model to continue" (recovery), which NODE_RETRY used to conflate.
    STREAM_TTFT_EXCEEDED = "stream_ttft_exceeded"
    STREAM_INACTIVE = "stream_inactive"
    STREAM_NUDGE_SENT = "stream_nudge_sent"
    TOOL_CALL_REPLAY_DETECTED = "tool_call_replay_detected"

    # Worker agent lifecycle
    WORKER_COMPLETED = "worker_completed"
    WORKER_FAILED = "worker_failed"
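
Consuming the new events uses the same subscribe pattern the tests below use for TOOL_CALL_REPLAY_DETECTED. A hypothetical handler wiring all three stream-health signals (assuming `EventBus` and `EventType` are imported as in those tests):

```python
# Sketch: observing stream health from the bus.
async def log_stream_health(evt) -> None:
    # e.g. "stream_inactive node-7 {'idle_seconds': 130.0, 'limit_seconds': 120.0}"
    print(evt.type, evt.node_id, evt.data)

bus = EventBus()
bus.subscribe(
    [
        EventType.STREAM_TTFT_EXCEEDED,
        EventType.STREAM_INACTIVE,
        EventType.STREAM_NUDGE_SENT,
    ],
    log_stream_health,
)
```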

@@ -1061,6 +1070,94 @@ class EventBus:
            )
        )

    async def emit_stream_ttft_exceeded(
        self,
        stream_id: str,
        node_id: str,
        ttft_seconds: float,
        limit_seconds: float,
        execution_id: str | None = None,
    ) -> None:
        """Emit when a stream stayed silent past the TTFT budget (no first event)."""
        await self.publish(
            AgentEvent(
                type=EventType.STREAM_TTFT_EXCEEDED,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={
                    "ttft_seconds": ttft_seconds,
                    "limit_seconds": limit_seconds,
                },
            )
        )

    async def emit_stream_inactive(
        self,
        stream_id: str,
        node_id: str,
        idle_seconds: float,
        limit_seconds: float,
        execution_id: str | None = None,
    ) -> None:
        """Emit when a stream that had produced events went silent past budget."""
        await self.publish(
            AgentEvent(
                type=EventType.STREAM_INACTIVE,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={
                    "idle_seconds": idle_seconds,
                    "limit_seconds": limit_seconds,
                },
            )
        )

    async def emit_stream_nudge_sent(
        self,
        stream_id: str,
        node_id: str,
        reason: str,
        nudge_count: int,
        execution_id: str | None = None,
    ) -> None:
        """Emit when the continue-nudge was injected (recovery, not retry)."""
        await self.publish(
            AgentEvent(
                type=EventType.STREAM_NUDGE_SENT,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={
                    "reason": reason,
                    "nudge_count": nudge_count,
                },
            )
        )

    async def emit_tool_call_replay_detected(
        self,
        stream_id: str,
        node_id: str,
        tool_name: str,
        prior_seq: int,
        execution_id: str | None = None,
    ) -> None:
        """Emit when the model is about to re-execute a prior successful call."""
        await self.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_REPLAY_DETECTED,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={
                    "tool_name": tool_name,
                    "prior_seq": prior_seq,
                },
            )
        )

    async def emit_worker_completed(
        self,
        stream_id: str,

@@ -43,6 +43,10 @@ class FileConversationStore:
    def __init__(self, base_path: str | Path) -> None:
        self._base = Path(base_path)
        self._parts_dir = self._base / "parts"
        # Partial checkpoints for in-flight assistant turns. Written on every
        # stream event, deleted atomically when the final part lands. Kept
        # in a sibling dir so the parts/ glob doesn't pick them up.
        self._partials_dir = self._base / "partials"

    # --- sync helpers --------------------------------------------------------

@@ -99,6 +103,44 @@ class FileConversationStore:
    async def read_cursor(self) -> dict[str, Any] | None:
        return await self._run(self._read_json, self._base / "cursor.json")

    async def write_partial(self, seq: int, data: dict[str, Any]) -> None:
        """Checkpoint an in-flight assistant turn. Overwrites any prior partial
        for the same seq. Caller is expected to clear_partial() once the real
        part is written via write_part().
        """
        path = self._partials_dir / f"{seq:010d}.json"
        await self._run(self._write_json, path, data)

    async def read_partial(self, seq: int) -> dict[str, Any] | None:
        path = self._partials_dir / f"{seq:010d}.json"
        return await self._run(self._read_json, path)

    async def read_all_partials(self) -> list[dict[str, Any]]:
        """Return all partial checkpoints, sorted by seq. Used during restore
        to surface any in-flight turn that the last process didn't finish.
        """

        def _read_all() -> list[dict[str, Any]]:
            if not self._partials_dir.exists():
                return []
            files = sorted(self._partials_dir.glob("*.json"))
            partials: list[dict[str, Any]] = []
            for f in files:
                data = self._read_json(f)
                if data is not None:
                    partials.append(data)
            return partials

        return await self._run(_read_all)

    async def clear_partial(self, seq: int) -> None:
        def _clear() -> None:
            path = self._partials_dir / f"{seq:010d}.json"
            if path.exists():
                path.unlink()

        await self._run(_clear)

    async def delete_parts_before(self, seq: int, run_id: str | None = None) -> None:
        def _delete() -> None:
            if not self._parts_dir.exists():
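
These methods delegate the actual I/O to `_write_json` / `_read_json` helpers not shown in this hunk. Given the "deleted atomically" and "one atomic file write" comments elsewhere in the commit, a plausible shape for the write side is temp-file-plus-rename — a sketch under that assumption, not the actual helper:

```python
import json
import os
from pathlib import Path

# Assumed sketch of an atomic JSON write: a crash mid-write can never
# leave a half-written checkpoint, because os.replace is atomic on POSIX.
def write_json_atomic(path: Path, data: dict) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(".json.tmp")
    tmp.write_text(json.dumps(data))
    os.replace(tmp, path)  # rename over the destination in one step
```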

@@ -125,6 +167,10 @@ class FileConversationStore:
        if self._parts_dir.exists():
            for f in self._parts_dir.glob("*.json"):
                f.unlink()
        # Clear partial checkpoints
        if self._partials_dir.exists():
            for f in self._partials_dir.glob("*.json"):
                f.unlink()
        # Clear cursor
        cursor_path = self._base / "cursor.json"
        if cursor_path.exists():

@@ -2109,3 +2109,137 @@ class TestToolConcurrencyPartition:

        # Both tools must have run: soft errors don't cascade.
        assert executed == ["call_1", "call_2"]


# ===========================================================================
# Replay detector (warn + execute)
# ===========================================================================


class TestReplayDetector:
    @pytest.mark.asyncio
    async def test_replay_emits_event_and_prefixes_result(
        self, tmp_path, runtime, node_spec, buffer
    ):
        """Re-emitting a tool call whose prior result succeeded fires the
        TOOL_CALL_REPLAY_DETECTED event and prepends a steer onto the stored
        result, but still executes the call (warn + execute)."""
        node_spec.output_keys = []

        async def tool_exec(tool_use: ToolUse) -> ToolResult:
            return ToolResult(
                tool_use_id=tool_use.id,
                content=f"fresh result for {tool_use.id}",
                is_error=False,
            )

        # Turn 1: model calls browser_setup with id=call_1
        # Turn 2: model calls browser_setup AGAIN with id=call_2 (the replay)
        # Turn 3: text stop
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("browser_setup", {}, tool_use_id="call_1"),
                tool_call_scenario("browser_setup", {}, tool_use_id="call_2"),
                text_scenario("done"),
            ]
        )

        tools = [Tool(name="browser_setup", description="", parameters={})]

        # Capture events from the bus.
        captured: list[Any] = []
        bus = EventBus()

        async def _collect(evt):
            captured.append(evt)

        bus.subscribe([EventType.TOOL_CALL_REPLAY_DETECTED], _collect)

        ctx = build_ctx(
            runtime,
            node_spec,
            buffer,
            llm,
            tools=tools,
            is_subagent_mode=True,
        )
        store = FileConversationStore(tmp_path / "conv")
        node = EventLoopNode(
            tool_executor=tool_exec,
            conversation_store=store,
            event_bus=bus,
            config=LoopConfig(max_iterations=5),
        )
        await node.execute(ctx)

        # Exactly one replay-detected event fired for the second call.
        assert len(captured) == 1
        assert captured[0].data["tool_name"] == "browser_setup"

        # The stored tool result for the replay carries the steer prefix,
        # and the real execution output is preserved.
        parts = await store.read_parts()
        tool_msgs = [
            p for p in parts if p.get("role") == "tool" and p.get("tool_use_id") == "call_2"
        ]
        assert len(tool_msgs) == 1
        assert tool_msgs[0]["content"].startswith("[Replay detected: browser_setup")
        assert "fresh result for call_2" in tool_msgs[0]["content"]

        # The first call's result is untouched.
        first = [
            p for p in parts if p.get("role") == "tool" and p.get("tool_use_id") == "call_1"
        ]
        assert first[0]["content"] == "fresh result for call_1"

    @pytest.mark.asyncio
    async def test_replay_with_error_prior_does_not_fire(
        self, tmp_path, runtime, node_spec, buffer
    ):
        """A prior call that errored does not count as a successful completion,
        so re-emitting it is legitimate (not a replay)."""
        node_spec.output_keys = []

        async def tool_exec(tool_use: ToolUse) -> ToolResult:
            is_err = tool_use.id == "call_1"
            return ToolResult(
                tool_use_id=tool_use.id,
                content=("boom" if is_err else "ok"),
                is_error=is_err,
            )

        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("flaky", {}, tool_use_id="call_1"),
                tool_call_scenario("flaky", {}, tool_use_id="call_2"),
                text_scenario("recovered"),
            ]
        )
        tools = [Tool(name="flaky", description="", parameters={})]

        captured: list[Any] = []
        bus = EventBus()

        async def _collect(evt):
            captured.append(evt)

        bus.subscribe([EventType.TOOL_CALL_REPLAY_DETECTED], _collect)

        ctx = build_ctx(
            runtime,
            node_spec,
            buffer,
            llm,
            tools=tools,
            is_subagent_mode=True,
        )
        store = FileConversationStore(tmp_path / "conv")
        node = EventLoopNode(
            tool_executor=tool_exec,
            conversation_store=store,
            event_bus=bus,
            config=LoopConfig(max_iterations=5),
        )
        await node.execute(ctx)

        assert captured == []

@@ -24,6 +24,7 @@ class MockConversationStore:

    def __init__(self) -> None:
        self._parts: dict[int, dict] = {}
        self._partials: dict[int, dict] = {}
        self._meta: dict | None = None
        self._cursor: dict | None = None

@@ -48,6 +49,18 @@ class MockConversationStore:
    async def delete_parts_before(self, seq: int, run_id: str | None = None) -> None:
        self._parts = {k: v for k, v in self._parts.items() if k >= seq}

    async def write_partial(self, seq: int, data: dict[str, Any]) -> None:
        self._partials[seq] = data

    async def read_partial(self, seq: int) -> dict[str, Any] | None:
        return self._partials.get(seq)

    async def read_all_partials(self) -> list[dict[str, Any]]:
        return [self._partials[k] for k in sorted(self._partials)]

    async def clear_partial(self, seq: int) -> None:
        self._partials.pop(seq, None)

    async def close(self) -> None:
        pass

@@ -750,6 +763,33 @@ class TestFileConversationStore:
        assert (base / "parts" / "0000000000.json").exists()
        assert (base / "parts" / "0000000001.json").exists()

    @pytest.mark.asyncio
    async def test_partials_separate_from_parts(self, tmp_path):
        """Partial checkpoints must not pollute read_parts() and vice versa."""
        store = FileConversationStore(tmp_path / "conv")
        await store.write_part(0, {"seq": 0, "content": "real"})
        await store.write_partial(1, {"seq": 1, "content": "inflight", "truncated": True})
        parts = await store.read_parts()
        assert [p["seq"] for p in parts] == [0]
        partials = await store.read_all_partials()
        assert [p["seq"] for p in partials] == [1]
        assert partials[0]["content"] == "inflight"
        assert (await store.read_partial(1))["content"] == "inflight"
        assert await store.read_partial(99) is None
        await store.clear_partial(1)
        assert await store.read_all_partials() == []

    @pytest.mark.asyncio
    async def test_partials_dir_does_not_break_parts_glob(self, tmp_path):
        """delete_parts_before parses stems as int — partial files must not trip it."""
        store = FileConversationStore(tmp_path / "conv")
        for i in range(3):
            await store.write_part(i, {"seq": i})
            await store.write_partial(i + 100, {"seq": i + 100})
        await store.delete_parts_before(2)
        assert [p["seq"] for p in await store.read_parts()] == [2]
        assert [p["seq"] for p in await store.read_all_partials()] == [100, 101, 102]


# ===================================================================
# Integration tests — real FileConversationStore, no mocks

@@ -1646,3 +1686,169 @@ class TestRepairOrphanedToolCalls:
        roles = [m["role"] for m in repaired]
        assert roles == ["user", "assistant", "tool", "user"]
        assert repaired[2]["tool_call_id"] == "tc_2"


# ===================================================================
# Continue-nudge + replay-detector helpers (DS-14)
# ===================================================================


def _mk_assistant_with_tool_call(seq: int, tc_id: str, name: str, args: dict) -> Message:
    return Message(
        seq=seq,
        role="assistant",
        content="",
        tool_calls=[
            {
                "id": tc_id,
                "type": "function",
                "function": {"name": name, "arguments": json.dumps(args)},
            }
        ],
    )


class TestFindCompletedToolCall:
    def test_returns_match_when_prior_non_error_result_exists(self):
        conv = NodeConversation(system_prompt="s")
        conv._messages = [
            Message(seq=0, role="user", content="go"),
            _mk_assistant_with_tool_call(1, "tc_a", "browser_setup", {}),
            Message(seq=2, role="tool", content="ok", tool_use_id="tc_a"),
        ]
        match = conv.find_completed_tool_call("browser_setup", {})
        assert match is not None
        assert match.seq == 1

    def test_ignores_error_result(self):
        conv = NodeConversation(system_prompt="s")
        conv._messages = [
            Message(seq=0, role="user", content="go"),
            _mk_assistant_with_tool_call(1, "tc_a", "browser_navigate", {"url": "x"}),
            Message(seq=2, role="tool", content="boom", tool_use_id="tc_a", is_error=True),
        ]
        assert conv.find_completed_tool_call("browser_navigate", {"url": "x"}) is None

    def test_canonicalizes_json_args_regardless_of_key_order(self):
        conv = NodeConversation(system_prompt="s")
        # Prior args written in one order, new call re-emits in different order.
        conv._messages = [
            Message(seq=0, role="user", content="go"),
            _mk_assistant_with_tool_call(1, "tc_a", "fetch", {"b": 2, "a": 1}),
            Message(seq=2, role="tool", content="ok", tool_use_id="tc_a"),
        ]
        assert conv.find_completed_tool_call("fetch", {"a": 1, "b": 2}) is not None
        # Different args should NOT match.
        assert conv.find_completed_tool_call("fetch", {"a": 1, "b": 3}) is None

    def test_respects_within_last_turns_window(self):
        conv = NodeConversation(system_prompt="s")
        # Prior successful call, then 4 newer assistant turns of noise.
        conv._messages = [
            Message(seq=0, role="user", content="go"),
            _mk_assistant_with_tool_call(1, "tc_a", "browser_setup", {}),
            Message(seq=2, role="tool", content="ok", tool_use_id="tc_a"),
        ]
        # 4 newer assistant turns (no tool calls that match)
        for i in range(3, 7):
            conv._messages.append(
                Message(seq=i, role="assistant", content=f"noise {i}")
            )
        # Window=3 → prior assistant with browser_setup is at turn index 5
        # backwards (noise, noise, noise, noise, setup) — skipped.
        assert (
            conv.find_completed_tool_call("browser_setup", {}, within_last_turns=3)
            is None
        )
        # Window=10 → found.
        assert (
            conv.find_completed_tool_call("browser_setup", {}, within_last_turns=10)
            is not None
        )


class TestPartialCheckpoint:
    @pytest.mark.asyncio
    async def test_checkpoint_is_cleared_when_real_part_lands(self, tmp_path):
        """A partial for seq N is wiped once add_assistant_message(seq=N) persists."""
        store = FileConversationStore(tmp_path / "c")
        conv = NodeConversation(system_prompt="s", store=store)
        await conv.add_user_message("hi")
        # Seed a partial for the would-be next assistant seq.
        await conv.checkpoint_partial_assistant("half-written...")
        partials = await store.read_all_partials()
        assert len(partials) == 1
        assert partials[0]["content"] == "half-written..."
        # Commit the real assistant turn — partial should be swept.
        await conv.add_assistant_message("fully written")
        assert await store.read_all_partials() == []

    @pytest.mark.asyncio
    async def test_restore_surfaces_partial_as_truncated_message(self, tmp_path):
        """A partial left behind by a crashed stream is resurrected on restore."""
        store = FileConversationStore(tmp_path / "c")
        conv = NodeConversation(system_prompt="s", store=store)
        await conv.add_user_message("hi")
        # Simulate a stream that produced some text + a tool call, then died
        # before finishing. The checkpoint captures both.
        await conv.checkpoint_partial_assistant(
            "I was working on this when the stream died",
            tool_calls=[
                {
                    "id": "tc_x",
                    "type": "function",
                    "function": {"name": "browser_click", "arguments": "{}"},
                }
            ],
        )
        # Fresh process — restore from disk.
        fresh = await NodeConversation.restore(store)
        assert fresh is not None
        # The user message is there, plus the truncated assistant resurrected
        # from the partial.
        roles = [m.role for m in fresh.messages]
        assert roles == ["user", "assistant"]
        last = fresh.messages[-1]
        assert last.truncated is True
        assert last.content == "I was working on this when the stream died"
        assert last.tool_calls and last.tool_calls[0]["function"]["name"] == "browser_click"

    @pytest.mark.asyncio
    async def test_restore_cleans_stale_partials(self, tmp_path):
        """A partial whose seq was already committed as a real part is discarded."""
        store = FileConversationStore(tmp_path / "c")
        conv = NodeConversation(system_prompt="s", store=store)
        await conv.add_user_message("hi")
        await conv.add_assistant_message("real")  # seq=1
        # Manually plant a stale partial at seq=1 (already committed).
        await store.write_partial(
            1, {"seq": 1, "role": "assistant", "content": "stale", "truncated": True}
        )
        fresh = await NodeConversation.restore(store)
        assert fresh is not None
        assert [m.content for m in fresh.messages] == ["hi", "real"]
        # Stale partial swept by restore.
        assert await store.read_all_partials() == []


class TestMessageFlags:
    def test_is_system_nudge_roundtrip(self):
        m = Message(seq=0, role="user", content="nudge", is_system_nudge=True)
        d = m.to_storage_dict()
        assert d.get("is_system_nudge") is True
        r = Message.from_storage_dict(d)
        assert r.is_system_nudge is True
        assert r.role == "user"

    def test_truncated_roundtrip(self):
        m = Message(seq=0, role="assistant", content="half", truncated=True)
        d = m.to_storage_dict()
        assert d.get("truncated") is True
        r = Message.from_storage_dict(d)
        assert r.truncated is True

    def test_defaults_omit_flags_from_storage(self):
        m = Message(seq=0, role="user", content="plain")
        d = m.to_storage_dict()
        assert "is_system_nudge" not in d
        assert "truncated" not in d