feat: robust run id

This commit is contained in:
Richard Tang
2026-04-02 12:35:16 -07:00
parent 00c55d5fb2
commit 60d094464a
8 changed files with 105 additions and 63 deletions
+33 -11
View File
@@ -11,6 +11,11 @@ from typing import Any, Literal, Protocol, runtime_checkable
LEGACY_RUN_ID = "__legacy_run__" LEGACY_RUN_ID = "__legacy_run__"
def is_legacy_run_id(run_id: str | None) -> bool:
"""True when run_id represents pre-migration (no run boundary) data."""
return run_id is None or run_id == LEGACY_RUN_ID
@dataclass @dataclass
class Message: class Message:
"""A single message in a conversation. """A single message in a conversation.
@@ -1131,17 +1136,28 @@ class NodeConversation:
await self._write_next_seq() await self._write_next_seq()
async def _persist_meta(self) -> None: async def _persist_meta(self) -> None:
"""Lazily write conversation metadata to the store (called once).""" """Lazily write conversation metadata to the store (called once).
When ``self._run_id`` is set, metadata is keyed under
``meta["runs"][run_id]`` so multiple runs can coexist in the same
session. Legacy (no run_id) sessions write flat for backward compat.
"""
if self._store is None: if self._store is None:
return return
await self._store.write_meta( run_meta = {
{ "system_prompt": self._system_prompt,
"system_prompt": self._system_prompt, "max_context_tokens": self._max_context_tokens,
"max_context_tokens": self._max_context_tokens, "compaction_threshold": self._compaction_threshold,
"compaction_threshold": self._compaction_threshold, "output_keys": self._output_keys,
"output_keys": self._output_keys, }
} if self._run_id:
) existing = await self._store.read_meta() or {}
runs = dict(existing.get("runs", {}))
runs[self._run_id] = run_meta
existing["runs"] = runs
await self._store.write_meta(existing)
else:
await self._store.write_meta(run_meta)
self._meta_persisted = True self._meta_persisted = True
async def _write_next_seq(self) -> None: async def _write_next_seq(self) -> None:
@@ -1175,6 +1191,12 @@ class NodeConversation:
if meta is None: if meta is None:
return None return None
# Extract run-scoped metadata when available
if run_id and isinstance(meta.get("runs"), dict):
run_meta = meta["runs"].get(run_id)
if run_meta is not None:
meta = run_meta
conv = cls( conv = cls(
system_prompt=meta.get("system_prompt", ""), system_prompt=meta.get("system_prompt", ""),
max_context_tokens=meta.get("max_context_tokens", 32000), max_context_tokens=meta.get("max_context_tokens", 32000),
@@ -1187,8 +1209,8 @@ class NodeConversation:
parts = await store.read_parts() parts = await store.read_parts()
if run_id is not None: if run_id is not None:
if run_id == LEGACY_RUN_ID: if is_legacy_run_id(run_id):
parts = [p for p in parts if p.get("run_id") in (None, LEGACY_RUN_ID)] parts = [p for p in parts if is_legacy_run_id(p.get("run_id"))]
else: else:
parts = [p for p in parts if p.get("run_id") == run_id] parts = [p for p in parts if p.get("run_id") == run_id]
if phase_id: if phase_id:
@@ -61,17 +61,17 @@ async def restore(
conversation = await NodeConversation.restore( conversation = await NodeConversation.restore(
conversation_store, conversation_store,
phase_id=phase_filter, phase_id=phase_filter,
run_id=ctx.run_id or None, run_id=ctx.effective_run_id,
) )
if conversation is None: if conversation is None:
return None return None
accumulator = await OutputAccumulator.restore(conversation_store, run_id=ctx.run_id or None) accumulator = await OutputAccumulator.restore(conversation_store, run_id=ctx.effective_run_id)
accumulator.spillover_dir = config.spillover_dir accumulator.spillover_dir = config.spillover_dir
accumulator.max_value_chars = config.max_output_value_chars accumulator.max_value_chars = config.max_output_value_chars
cursor = await conversation_store.read_cursor() cursor = await conversation_store.read_cursor()
run_cursor = get_run_cursor(cursor, ctx.run_id or None) run_cursor = get_run_cursor(cursor, ctx.effective_run_id)
start_iteration = run_cursor.get("iteration", 0) + 1 if run_cursor else 0 start_iteration = run_cursor.get("iteration", 0) + 1 if run_cursor else 0
# Restore stall/doom-loop detection state # Restore stall/doom-loop detection state
@@ -128,7 +128,7 @@ async def write_cursor(
run_cursor["recent_tool_fingerprints"] = [ run_cursor["recent_tool_fingerprints"] = [
[list(pair) for pair in fps] for fps in recent_tool_fingerprints [list(pair) for pair in fps] for fps in recent_tool_fingerprints
] ]
await conversation_store.write_cursor(update_run_cursor(cursor, ctx.run_id or None, run_cursor)) await conversation_store.write_cursor(update_run_cursor(cursor, ctx.effective_run_id, run_cursor))
async def drain_injection_queue( async def drain_injection_queue(
+9 -3
View File
@@ -367,7 +367,7 @@ class EventLoopNode(NodeProtocol):
store=self._conversation_store, store=self._conversation_store,
spillover_dir=self._config.spillover_dir, spillover_dir=self._config.spillover_dir,
max_value_chars=self._config.max_output_value_chars, max_value_chars=self._config.max_output_value_chars,
run_id=ctx.run_id or None, run_id=ctx.effective_run_id,
) )
start_iteration = 0 start_iteration = 0
_restored_recent_responses: list[str] = [] _restored_recent_responses: list[str] = []
@@ -418,6 +418,12 @@ class EventLoopNode(NodeProtocol):
if conversation.system_prompt != _current_prompt: if conversation.system_prompt != _current_prompt:
conversation.update_system_prompt(_current_prompt) conversation.update_system_prompt(_current_prompt)
logger.info("Refreshed system prompt for restored conversation") logger.info("Refreshed system prompt for restored conversation")
# Refresh other meta fields that may differ across runs
conversation._max_context_tokens = self._config.max_context_tokens
if ctx.node_spec.output_keys:
conversation._output_keys = ctx.node_spec.output_keys
conversation._meta_persisted = False # Force re-persist with updated values
else: else:
_restored_recent_responses = [] _restored_recent_responses = []
_restored_tool_fingerprints = [] _restored_tool_fingerprints = []
@@ -481,7 +487,7 @@ class EventLoopNode(NodeProtocol):
max_context_tokens=self._config.max_context_tokens, max_context_tokens=self._config.max_context_tokens,
output_keys=ctx.node_spec.output_keys or None, output_keys=ctx.node_spec.output_keys or None,
store=self._conversation_store, store=self._conversation_store,
run_id=ctx.run_id or None, run_id=ctx.effective_run_id,
) )
# Stamp phase for first node in continuous mode # Stamp phase for first node in continuous mode
if _is_continuous: if _is_continuous:
@@ -490,7 +496,7 @@ class EventLoopNode(NodeProtocol):
store=self._conversation_store, store=self._conversation_store,
spillover_dir=self._config.spillover_dir, spillover_dir=self._config.spillover_dir,
max_value_chars=self._config.max_output_value_chars, max_value_chars=self._config.max_output_value_chars,
run_id=ctx.run_id or None, run_id=ctx.effective_run_id,
) )
start_iteration = 0 start_iteration = 0
+10
View File
@@ -481,6 +481,16 @@ class NodeContext:
execution_id: str = "" execution_id: str = ""
run_id: str = "" run_id: str = ""
@property
def effective_run_id(self) -> str | None:
"""Normalized run_id: returns run_id if truthy, otherwise None.
The field defaults to ``""``; callers should use this property
instead of ``self.run_id or None`` to avoid silently falling
back to session-scoped storage.
"""
return self.run_id or None
# Stream identity — the ExecutionStream this node runs within. # Stream identity — the ExecutionStream this node runs within.
# Falls back to node_id when not set (legacy / standalone executor). # Falls back to node_id when not set (legacy / standalone executor).
stream_id: str = "" stream_id: str = ""
+35 -35
View File
@@ -250,7 +250,10 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
def _load_resume_state( def _load_resume_state(
agent_path: str, session_id: str, checkpoint_id: str | None = None agent_path: str, session_id: str, checkpoint_id: str | None = None
) -> dict | None: ) -> dict | None:
"""Load session or checkpoint state for headless resume. """Load checkpoint state for headless resume.
All resumes require a checkpoint. If ``checkpoint_id`` is not provided
the latest checkpoint is auto-discovered.
Args: Args:
agent_path: Path to the agent folder (e.g., exports/my_agent) agent_path: Path to the agent folder (e.g., exports/my_agent)
@@ -258,7 +261,7 @@ def _load_resume_state(
checkpoint_id: Optional checkpoint ID within the session checkpoint_id: Optional checkpoint ID within the session
Returns: Returns:
session_state dict for executor, or None if not found session_state dict for executor, or None if no checkpoint found
""" """
agent_name = Path(agent_path).name agent_name = Path(agent_path).name
agent_work_dir = Path.home() / ".hive" / "agents" / agent_name agent_work_dir = Path.home() / ".hive" / "agents" / agent_name
@@ -267,40 +270,37 @@ def _load_resume_state(
if not session_dir.exists(): if not session_dir.exists():
return None return None
if checkpoint_id: # Auto-discover latest checkpoint when not specified
# Checkpoint-based resume: load checkpoint and extract state if not checkpoint_id:
cp_path = session_dir / "checkpoints" / f"{checkpoint_id}.json" cp_dir = session_dir / "checkpoints"
if not cp_path.exists(): if cp_dir.exists():
checkpoints = sorted(
cp_dir.glob("*.json"),
key=lambda p: p.stat().st_mtime,
reverse=True,
)
if checkpoints:
checkpoint_id = checkpoints[0].stem
if not checkpoint_id:
return None return None
try:
cp_data = json.loads(cp_path.read_text(encoding="utf-8")) cp_path = session_dir / "checkpoints" / f"{checkpoint_id}.json"
except (json.JSONDecodeError, OSError): if not cp_path.exists():
return None return None
return { try:
"resume_session_id": session_id, cp_data = json.loads(cp_path.read_text(encoding="utf-8"))
"data_buffer": cp_data.get("data_buffer", cp_data.get("shared_memory", {})), except (json.JSONDecodeError, OSError):
"paused_at": cp_data.get("next_node") or cp_data.get("current_node"), return None
"execution_path": cp_data.get("execution_path", []),
"node_visit_counts": {}, return {
} "resume_session_id": session_id,
else: "resume_from_checkpoint": checkpoint_id,
# Session state resume: load state.json "run_id": cp_data.get("run_id") or None,
state_path = session_dir / "state.json" "data_buffer": cp_data.get("data_buffer", cp_data.get("shared_memory", {})),
if not state_path.exists(): "paused_at": cp_data.get("next_node") or cp_data.get("current_node"),
return None "execution_path": cp_data.get("execution_path", []),
try: "node_visit_counts": cp_data.get("node_visit_counts", {}),
state_data = json.loads(state_path.read_text(encoding="utf-8")) }
except (json.JSONDecodeError, OSError):
return None
progress = state_data.get("progress", {})
paused_at = progress.get("paused_at") or progress.get("resume_from")
return {
"resume_session_id": session_id,
"data_buffer": state_data.get("data_buffer", state_data.get("memory", {})),
"paused_at": paused_at,
"execution_path": progress.get("path", []),
"node_visit_counts": progress.get("node_visit_counts", {}),
}
def _prompt_before_start(agent_path: str, runner, model: str | None = None): def _prompt_before_start(agent_path: str, runner, model: str | None = None):
+8 -5
View File
@@ -169,11 +169,10 @@ class SessionState(BaseModel):
def is_resumable(self) -> bool: def is_resumable(self) -> bool:
"""Can this session be resumed? """Can this session be resumed?
Every non-completed session is resumable. If resume_from/paused_at Only sessions with a valid checkpoint can be resumed.
aren't set, the executor falls back to the graph entry point — State-based resume (without a checkpoint) is no longer supported.
so we don't gate on those. Even catastrophic failures are resumable.
""" """
return self.status != SessionStatus.COMPLETED return self.is_resumable_from_checkpoint
@computed_field @computed_field
@property @property
@@ -294,7 +293,11 @@ class SessionState(BaseModel):
) )
def to_session_state_dict(self) -> dict[str, Any]: def to_session_state_dict(self) -> dict[str, Any]:
"""Convert to session_state format for GraphExecutor.execute().""" """Convert to session_state format for GraphExecutor.execute().
NOTE: state-based resume via paused_at/resume_from is deprecated.
Use checkpoint-based resume (``resume_from_checkpoint`` key) instead.
"""
# Derive resume target: explicit > last node in path > entry point # Derive resume target: explicit > last node in path > entry point
resume_from = ( resume_from = (
self.progress.resume_from self.progress.resume_from
+3 -3
View File
@@ -29,7 +29,7 @@ import shutil
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from framework.graph.conversation import LEGACY_RUN_ID from framework.graph.conversation import LEGACY_RUN_ID, is_legacy_run_id
class FileConversationStore: class FileConversationStore:
@@ -109,8 +109,8 @@ class FileConversationStore:
continue continue
data = self._read_json(f) or {} data = self._read_json(f) or {}
part_run_id = data.get("run_id") part_run_id = data.get("run_id")
if run_id == LEGACY_RUN_ID: if is_legacy_run_id(run_id):
if part_run_id in (None, LEGACY_RUN_ID): if is_legacy_run_id(part_run_id):
f.unlink() f.unlink()
elif part_run_id == run_id: elif part_run_id == run_id:
f.unlink() f.unlink()
+3 -2
View File
@@ -12,6 +12,7 @@ from framework.graph.conversation import (
Message, Message,
NodeConversation, NodeConversation,
extract_tool_call_history, extract_tool_call_history,
is_legacy_run_id,
) )
from framework.storage.conversation_store import FileConversationStore from framework.storage.conversation_store import FileConversationStore
@@ -55,8 +56,8 @@ class MockConversationStore:
if run_id is None: if run_id is None:
continue continue
part_run_id = value.get("run_id") part_run_id = value.get("run_id")
if run_id == LEGACY_RUN_ID: if is_legacy_run_id(run_id):
if part_run_id not in (None, LEGACY_RUN_ID): if not is_legacy_run_id(part_run_id):
kept[key] = value kept[key] = value
elif part_run_id != run_id: elif part_run_id != run_id:
kept[key] = value kept[key] = value