Merge pull request #4141 from TimothyZhang7/feature/resumable-sessions

Feature/resumable sessions

Release candidate for v0.4.2
This commit is contained in:
Timothy @aden
2026-02-08 20:40:33 -08:00
committed by GitHub
15 changed files with 2850 additions and 21 deletions
+85
View File
@@ -495,11 +495,96 @@ max_node_visits=3 # Prevent getting stuck
- Confirm it calls set_output eventually
```
#### Template 6: Checkpoint Recovery (Post-Fix Resume)
```markdown
## Recovery Strategy: Resume from Last Clean Checkpoint
**Situation:** You've fixed the issue, but the failed session is stuck mid-execution
**Solution:** Resume execution from a checkpoint before the failure
### Option A: Auto-Resume from Latest Checkpoint (Recommended)
Use CLI arguments to auto-resume when launching TUI:
```bash
PYTHONPATH=core:exports python -m {agent_name} --tui \
--resume-session {session_id}
```
This will:
- Load session state from `state.json`
- Continue from where it paused/failed
- Apply your fixes immediately
### Option B: Resume from Specific Checkpoint (Time-Travel)
If you need to go back to an earlier point:
```bash
PYTHONPATH=core:exports python -m {agent_name} --tui \
--resume-session {session_id} \
--checkpoint {checkpoint_id}
```
Example:
```bash
PYTHONPATH=core:exports python -m deep_research_agent --tui \
--resume-session session_20260208_143022_abc12345 \
--checkpoint cp_node_complete_intake_143030
```
### Option C: Use TUI Commands
Alternatively, launch TUI normally and use commands:
```bash
# Launch TUI
PYTHONPATH=core:exports python -m {agent_name} --tui
# In TUI, use commands:
/resume {session_id} # Resume from session state
/recover {session_id} {checkpoint_id} # Recover from specific checkpoint
```
### When to Use Each Option:
**Use `/resume` (or --resume-session) when:**
- You fixed credentials and want to retry
- Agent paused and you want to continue
- Agent failed and you want to retry from last state
**Use `/recover` (or --resume-session + --checkpoint) when:**
- You need to go back to an earlier checkpoint
- You want to try a different path from a specific point
- Debugging requires time-travel to earlier state
### Find Available Checkpoints:
```bash
# In TUI:
/sessions {session_id}
# This shows all checkpoints with timestamps:
Available Checkpoints: (3)
1. cp_node_complete_intake_143030
2. cp_node_complete_research_143115
3. cp_pause_research_143130
```
**Verification:**
- Use `--resume-session` to test your fix immediately
- No need to re-run from the beginning
- Session continues with your code changes applied
```
**Selecting the right template:**
- Match the issue category from Stage 4
- Customize with specific details from Stage 5
- Include actual error messages and code snippets
- Provide file paths and line numbers when possible
- **Always include recovery commands** (Template 6) after providing fix recommendations
---
+85
View File
@@ -0,0 +1,85 @@
"""
Checkpoint Configuration - Controls checkpoint behavior during execution.
"""
from dataclasses import dataclass
@dataclass
class CheckpointConfig:
"""
Configuration for checkpoint behavior during graph execution.
Controls when checkpoints are created, how they're stored,
and when they're pruned.
"""
# Enable/disable checkpointing
enabled: bool = True
# When to checkpoint
checkpoint_on_node_start: bool = True
checkpoint_on_node_complete: bool = True
# Pruning (time-based)
checkpoint_max_age_days: int = 7 # Prune checkpoints older than 1 week
prune_every_n_nodes: int = 10 # Check for pruning every N nodes
# Performance
async_checkpoint: bool = True # Don't block execution on checkpoint writes
# What to include in checkpoints
include_full_memory: bool = True
include_metrics: bool = True
def should_checkpoint_node_start(self) -> bool:
"""Check if should checkpoint before node execution."""
return self.enabled and self.checkpoint_on_node_start
def should_checkpoint_node_complete(self) -> bool:
"""Check if should checkpoint after node execution."""
return self.enabled and self.checkpoint_on_node_complete
def should_prune_checkpoints(self, nodes_executed: int) -> bool:
"""
Check if should prune checkpoints based on execution progress.
Args:
nodes_executed: Number of nodes executed so far
Returns:
True if should check for old checkpoints and prune them
"""
return (
self.enabled
and self.prune_every_n_nodes > 0
and nodes_executed % self.prune_every_n_nodes == 0
)
# Default configuration for most agents
DEFAULT_CHECKPOINT_CONFIG = CheckpointConfig(
enabled=True,
checkpoint_on_node_start=True,
checkpoint_on_node_complete=True,
checkpoint_max_age_days=7,
prune_every_n_nodes=10,
async_checkpoint=True,
)
# Minimal configuration (only checkpoint at node completion)
MINIMAL_CHECKPOINT_CONFIG = CheckpointConfig(
enabled=True,
checkpoint_on_node_start=False,
checkpoint_on_node_complete=True,
checkpoint_max_age_days=7,
prune_every_n_nodes=20,
async_checkpoint=True,
)
# Disabled configuration (no checkpointing)
DISABLED_CHECKPOINT_CONFIG = CheckpointConfig(
enabled=False,
)
+16 -2
View File
@@ -1763,7 +1763,19 @@ class EventLoopNode(NodeProtocol):
conversation: NodeConversation,
iteration: int,
) -> bool:
"""Check if pause has been requested. Returns True if paused."""
"""
Check if pause has been requested. Returns True if paused.
Note: This check happens BEFORE starting iteration N, after completing N-1.
If paused, the node exits having completed {iteration} iterations (0 to iteration-1).
"""
# Check executor-level pause event (for /pause command, Ctrl+Z)
if ctx.pause_event and ctx.pause_event.is_set():
completed = iteration # 0-indexed: iteration=3 means 3 iterations completed (0,1,2)
logger.info(f"⏸ Pausing after {completed} iteration(s) completed (executor-level)")
return True
# Check context-level pause flags (legacy/alternative methods)
pause_requested = ctx.input_data.get("pause_requested", False)
if not pause_requested:
try:
@@ -1771,8 +1783,10 @@ class EventLoopNode(NodeProtocol):
except (PermissionError, KeyError):
pause_requested = False
if pause_requested:
logger.info(f"Pause requested at iteration {iteration}")
completed = iteration
logger.info(f"⏸ Pausing after {completed} iteration(s) completed (context-level)")
return True
return False
# -------------------------------------------------------------------
+343 -2
View File
@@ -17,6 +17,7 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from framework.graph.checkpoint_config import CheckpointConfig
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.goal import Goal
from framework.graph.node import (
@@ -33,6 +34,8 @@ from framework.graph.output_cleaner import CleansingConfig, OutputCleaner
from framework.graph.validator import OutputValidator
from framework.llm.provider import LLMProvider, Tool
from framework.runtime.core import Runtime
from framework.schemas.checkpoint import Checkpoint
from framework.storage.checkpoint_store import CheckpointStore
@dataclass
@@ -179,6 +182,9 @@ class GraphExecutor:
self.enable_parallel_execution = enable_parallel_execution
self._parallel_config = parallel_config or ParallelExecutionConfig()
# Pause/resume control
self._pause_requested = asyncio.Event()
def _validate_tools(self, graph: GraphSpec) -> list[str]:
"""
Validate that all tools declared by nodes are available.
@@ -208,6 +214,7 @@ class GraphExecutor:
goal: Goal,
input_data: dict[str, Any] | None = None,
session_state: dict[str, Any] | None = None,
checkpoint_config: "CheckpointConfig | None" = None,
) -> ExecutionResult:
"""
Execute a graph for a goal.
@@ -246,6 +253,12 @@ class GraphExecutor:
# Initialize execution state
memory = SharedMemory()
# Initialize checkpoint store if checkpointing is enabled
checkpoint_store: CheckpointStore | None = None
if checkpoint_config and checkpoint_config.enabled and self._storage_path:
checkpoint_store = CheckpointStore(self._storage_path)
self.logger.info("✓ Checkpointing enabled")
# Restore session state if provided
if session_state and "memory" in session_state:
memory_data = session_state["memory"]
@@ -273,8 +286,110 @@ class GraphExecutor:
node_visit_counts: dict[str, int] = {} # Track visits for feedback loops
_is_retry = False # True when looping back for a retry (not a new visit)
# Restore node_visit_counts from session state if available
if session_state and "node_visit_counts" in session_state:
node_visit_counts = dict(session_state["node_visit_counts"])
if node_visit_counts:
self.logger.info(f"📥 Restored node visit counts: {node_visit_counts}")
# If resuming at a specific node (paused_at), that node was counted
# but never completed, so decrement its count
paused_at = session_state.get("paused_at")
if (
paused_at
and paused_at in node_visit_counts
and node_visit_counts[paused_at] > 0
):
old_count = node_visit_counts[paused_at]
node_visit_counts[paused_at] -= 1
self.logger.info(
f"📥 Decremented visit count for paused node '{paused_at}': "
f"{old_count} -> {node_visit_counts[paused_at]}"
)
# Determine entry point (may differ if resuming)
current_node_id = graph.get_entry_point(session_state)
# Check if resuming from checkpoint
if session_state and session_state.get("resume_from_checkpoint") and checkpoint_store:
checkpoint_id = session_state["resume_from_checkpoint"]
try:
checkpoint = await checkpoint_store.load_checkpoint(checkpoint_id)
if checkpoint:
self.logger.info(
f"🔄 Resuming from checkpoint: {checkpoint_id} "
f"(node: {checkpoint.current_node})"
)
# Restore memory from checkpoint
for key, value in checkpoint.shared_memory.items():
memory.write(key, value, validate=False)
# Start from checkpoint's next node or current node
current_node_id = (
checkpoint.next_node or checkpoint.current_node or graph.entry_node
)
# Restore execution path
path.extend(checkpoint.execution_path)
self.logger.info(
f"📥 Restored memory with {len(checkpoint.shared_memory)} keys, "
f"resuming at node: {current_node_id}"
)
else:
self.logger.warning(
f"Checkpoint {checkpoint_id} not found, resuming from normal entry point"
)
# Check if resuming from paused_at (fallback to session state)
paused_at = session_state.get("paused_at") if session_state else None
if paused_at and graph.get_node(paused_at) is not None:
current_node_id = paused_at
self.logger.info(f"🔄 Resuming from paused node: {paused_at}")
else:
current_node_id = graph.get_entry_point(session_state)
except Exception as e:
self.logger.error(
f"Failed to load checkpoint {checkpoint_id}: {e}, "
f"resuming from normal entry point"
)
# Check if resuming from paused_at (fallback to session state)
paused_at = session_state.get("paused_at") if session_state else None
if paused_at and graph.get_node(paused_at) is not None:
current_node_id = paused_at
self.logger.info(f"🔄 Resuming from paused node: {paused_at}")
else:
current_node_id = graph.get_entry_point(session_state)
else:
# Check if resuming from paused_at (session state resume)
paused_at = session_state.get("paused_at") if session_state else None
node_ids = [n.id for n in graph.nodes]
self.logger.info(f"🔍 Debug: paused_at={paused_at}, available node IDs={node_ids}")
if paused_at and graph.get_node(paused_at) is not None:
# Resume from paused_at node directly (works for any node, not just pause_nodes)
current_node_id = paused_at
# Restore execution path from session state if available
if session_state:
execution_path = session_state.get("execution_path", [])
if execution_path:
path.extend(execution_path)
self.logger.info(
f"🔄 Resuming from paused node: {paused_at} "
f"(restored path: {execution_path})"
)
else:
self.logger.info(f"🔄 Resuming from paused node: {paused_at}")
else:
self.logger.info(f"🔄 Resuming from paused node: {paused_at}")
else:
# Fall back to normal entry point logic
self.logger.warning(
f"⚠ paused_at={paused_at} is not a valid node, falling back to entry point"
)
current_node_id = graph.get_entry_point(session_state)
steps = 0
if session_state and current_node_id != graph.entry_node:
@@ -313,6 +428,45 @@ class GraphExecutor:
while steps < graph.max_steps:
steps += 1
# Check for pause request
if self._pause_requested.is_set():
self.logger.info("⏸ Pause detected - stopping at node boundary")
# Create session state for pause
saved_memory = memory.read_all()
pause_session_state: dict[str, Any] = {
"memory": saved_memory, # Include memory for resume
"execution_path": list(path),
"node_visit_counts": dict(node_visit_counts),
}
# Create a pause checkpoint
if checkpoint_store:
pause_checkpoint = self._create_checkpoint(
checkpoint_type="pause",
current_node=current_node_id,
execution_path=path,
memory=memory,
next_node=current_node_id,
is_clean=True,
)
await checkpoint_store.save_checkpoint(pause_checkpoint)
pause_session_state["latest_checkpoint_id"] = pause_checkpoint.checkpoint_id
pause_session_state["resume_from_checkpoint"] = (
pause_checkpoint.checkpoint_id
)
# Return with paused status
return ExecutionResult(
success=False,
output=saved_memory,
path=path,
paused_at=current_node_id,
error="Execution paused by user request",
session_state=pause_session_state,
node_visit_counts=dict(node_visit_counts),
)
# Get current node
node_spec = graph.get_node(current_node_id)
if node_spec is None:
@@ -391,6 +545,27 @@ class GraphExecutor:
description=f"Validation errors for {current_node_id}: {validation_errors}",
)
# CHECKPOINT: node_start
if (
checkpoint_store
and checkpoint_config
and checkpoint_config.should_checkpoint_node_start()
):
checkpoint = self._create_checkpoint(
checkpoint_type="node_start",
current_node=node_spec.id,
execution_path=list(path),
memory=memory,
is_clean=(sum(node_retry_counts.values()) == 0),
)
if checkpoint_config.async_checkpoint:
# Non-blocking checkpoint save
asyncio.create_task(checkpoint_store.save_checkpoint(checkpoint))
else:
# Blocking checkpoint save
await checkpoint_store.save_checkpoint(checkpoint)
# Emit node-started event (skip event_loop nodes — they emit their own)
if self._event_bus and node_spec.node_type != "event_loop":
await self._event_bus.emit_node_loop_started(
@@ -564,13 +739,21 @@ class GraphExecutor:
execution_quality="failed",
)
# Save memory for potential resume
saved_memory = memory.read_all()
failure_session_state = {
"memory": saved_memory,
"execution_path": list(path),
"node_visit_counts": dict(node_visit_counts),
}
return ExecutionResult(
success=False,
error=(
f"Node '{node_spec.name}' failed after "
f"{max_retries} attempts: {result.error}"
),
output=memory.read_all(),
output=saved_memory,
steps_executed=steps,
total_tokens=total_tokens,
total_latency_ms=total_latency,
@@ -581,6 +764,7 @@ class GraphExecutor:
had_partial_failures=len(nodes_failed) > 0,
execution_quality="failed",
node_visit_counts=dict(node_visit_counts),
session_state=failure_session_state,
)
# Check if we just executed a pause node - if so, save state and return
@@ -703,6 +887,39 @@ class GraphExecutor:
break
next_spec = graph.get_node(next_node)
self.logger.info(f" → Next: {next_spec.name if next_spec else next_node}")
# CHECKPOINT: node_complete (after determining next node)
if (
checkpoint_store
and checkpoint_config
and checkpoint_config.should_checkpoint_node_complete()
):
checkpoint = self._create_checkpoint(
checkpoint_type="node_complete",
current_node=node_spec.id,
execution_path=list(path),
memory=memory,
next_node=next_node,
is_clean=(sum(node_retry_counts.values()) == 0),
)
if checkpoint_config.async_checkpoint:
asyncio.create_task(checkpoint_store.save_checkpoint(checkpoint))
else:
await checkpoint_store.save_checkpoint(checkpoint)
# Periodic checkpoint pruning
if (
checkpoint_store
and checkpoint_config
and checkpoint_config.should_prune_checkpoints(len(path))
):
asyncio.create_task(
checkpoint_store.prune_checkpoints(
max_age_days=checkpoint_config.checkpoint_max_age_days
)
)
current_node_id = next_node
# Update input_data for next node
@@ -760,6 +977,50 @@ class GraphExecutor:
node_visit_counts=dict(node_visit_counts),
)
except asyncio.CancelledError:
# Handle cancellation (e.g., TUI quit) - save as paused instead of failed
self.logger.info("⏸ Execution cancelled - saving state for resume")
# Save memory and state for resume
saved_memory = memory.read_all()
session_state_out: dict[str, Any] = {
"memory": saved_memory,
"execution_path": list(path),
"node_visit_counts": dict(node_visit_counts),
}
# Calculate quality metrics
total_retries_count = sum(node_retry_counts.values())
nodes_failed = [nid for nid, count in node_retry_counts.items() if count > 0]
exec_quality = "degraded" if total_retries_count > 0 else "clean"
if self.runtime_logger:
await self.runtime_logger.end_run(
status="paused",
duration_ms=total_latency,
node_path=path,
execution_quality=exec_quality,
)
# Return with paused status
return ExecutionResult(
success=False,
error="Execution paused by user",
output=saved_memory,
steps_executed=steps,
total_tokens=total_tokens,
total_latency_ms=total_latency,
path=path,
paused_at=current_node_id, # Save where we were
session_state=session_state_out,
total_retries=total_retries_count,
nodes_with_failures=nodes_failed,
retry_details=dict(node_retry_counts),
had_partial_failures=len(nodes_failed) > 0,
execution_quality=exec_quality,
node_visit_counts=dict(node_visit_counts),
)
except Exception as e:
import traceback
@@ -797,9 +1058,40 @@ class GraphExecutor:
execution_quality="failed",
)
# Save memory and state for potential resume
saved_memory = memory.read_all()
session_state_out: dict[str, Any] = {
"memory": saved_memory,
"execution_path": list(path),
"node_visit_counts": dict(node_visit_counts),
}
# Mark latest checkpoint for resume on failure
if checkpoint_store:
try:
checkpoints = await checkpoint_store.list_checkpoints()
if checkpoints:
# Find latest clean checkpoint
index = await checkpoint_store.load_index()
if index:
latest_clean = index.get_latest_clean_checkpoint()
if latest_clean:
session_state_out["resume_from_checkpoint"] = (
latest_clean.checkpoint_id
)
session_state_out["latest_checkpoint_id"] = (
latest_clean.checkpoint_id
)
self.logger.info(
f"💾 Marked checkpoint for resume: {latest_clean.checkpoint_id}"
)
except Exception as checkpoint_err:
self.logger.warning(f"Failed to mark checkpoint for resume: {checkpoint_err}")
return ExecutionResult(
success=False,
error=str(e),
output=saved_memory,
steps_executed=steps,
path=path,
total_retries=total_retries_count,
@@ -808,6 +1100,7 @@ class GraphExecutor:
had_partial_failures=len(nodes_failed) > 0,
execution_quality="failed",
node_visit_counts=dict(node_visit_counts),
session_state=session_state_out,
)
finally:
@@ -848,6 +1141,7 @@ class GraphExecutor:
goal=goal, # Pass Goal object for LLM-powered routers
max_tokens=max_tokens,
runtime_logger=self.runtime_logger,
pause_event=self._pause_requested, # Pass pause event for granular control
)
# Valid node types - no ambiguous "llm" type allowed
@@ -1360,3 +1654,50 @@ class GraphExecutor:
def register_function(self, node_id: str, func: Callable) -> None:
"""Register a function as a node."""
self.node_registry[node_id] = FunctionNode(func)
def request_pause(self) -> None:
"""
Request graceful pause of the current execution.
The execution will pause at the next node boundary after the current
node completes. A checkpoint will be saved at the pause point, allowing
the execution to be resumed later.
This method is safe to call from any thread.
"""
self._pause_requested.set()
self.logger.info("⏸ Pause requested - will pause at next node boundary")
def _create_checkpoint(
self,
checkpoint_type: str,
current_node: str,
execution_path: list[str],
memory: SharedMemory,
next_node: str | None = None,
is_clean: bool = True,
) -> Checkpoint:
"""
Create a checkpoint from current execution state.
Args:
checkpoint_type: Type of checkpoint (node_start, node_complete)
current_node: Current node ID
execution_path: Nodes executed so far
memory: SharedMemory instance
next_node: Next node to execute (for node_complete checkpoints)
is_clean: Whether execution was clean up to this point
Returns:
New Checkpoint instance
"""
return Checkpoint.create(
checkpoint_type=checkpoint_type,
session_id=self._storage_path.name if self._storage_path else "unknown",
current_node=current_node,
execution_path=execution_path,
shared_memory=memory.read_all(),
next_node=next_node,
is_clean=is_clean,
)
+3
View File
@@ -480,6 +480,9 @@ class NodeContext:
# Runtime logging (optional)
runtime_logger: Any = None # RuntimeLogger | None — uses Any to avoid import
# Pause control (optional) - asyncio.Event for pause requests
pause_event: Any = None # asyncio.Event | None
@dataclass
class NodeResult:
+190 -1
View File
@@ -63,6 +63,18 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
default=None,
help="LLM model to use (any LiteLLM-compatible name)",
)
run_parser.add_argument(
"--resume-session",
type=str,
default=None,
help="Resume from a specific session ID",
)
run_parser.add_argument(
"--checkpoint",
type=str,
default=None,
help="Resume from a specific checkpoint (requires --resume-session)",
)
run_parser.set_defaults(func=cmd_run)
# info command
@@ -196,6 +208,129 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
)
tui_parser.set_defaults(func=cmd_tui)
# sessions command group (checkpoint/resume management)
sessions_parser = subparsers.add_parser(
"sessions",
help="Manage agent sessions",
description="List, inspect, and manage agent execution sessions.",
)
sessions_subparsers = sessions_parser.add_subparsers(
dest="sessions_cmd",
help="Session management commands",
)
# sessions list
sessions_list_parser = sessions_subparsers.add_parser(
"list",
help="List agent sessions",
description="List all sessions for an agent.",
)
sessions_list_parser.add_argument(
"agent_path",
type=str,
help="Path to agent folder",
)
sessions_list_parser.add_argument(
"--status",
choices=["all", "active", "failed", "completed", "paused"],
default="all",
help="Filter by session status (default: all)",
)
sessions_list_parser.add_argument(
"--has-checkpoints",
action="store_true",
help="Show only sessions with checkpoints",
)
sessions_list_parser.set_defaults(func=cmd_sessions_list)
# sessions show
sessions_show_parser = sessions_subparsers.add_parser(
"show",
help="Show session details",
description="Display detailed information about a specific session.",
)
sessions_show_parser.add_argument(
"agent_path",
type=str,
help="Path to agent folder",
)
sessions_show_parser.add_argument(
"session_id",
type=str,
help="Session ID to inspect",
)
sessions_show_parser.add_argument(
"--json",
action="store_true",
help="Output as JSON",
)
sessions_show_parser.set_defaults(func=cmd_sessions_show)
# sessions checkpoints
sessions_checkpoints_parser = sessions_subparsers.add_parser(
"checkpoints",
help="List session checkpoints",
description="List all checkpoints for a session.",
)
sessions_checkpoints_parser.add_argument(
"agent_path",
type=str,
help="Path to agent folder",
)
sessions_checkpoints_parser.add_argument(
"session_id",
type=str,
help="Session ID",
)
sessions_checkpoints_parser.set_defaults(func=cmd_sessions_checkpoints)
# pause command
pause_parser = subparsers.add_parser(
"pause",
help="Pause running session",
description="Request graceful pause of a running agent session.",
)
pause_parser.add_argument(
"agent_path",
type=str,
help="Path to agent folder",
)
pause_parser.add_argument(
"session_id",
type=str,
help="Session ID to pause",
)
pause_parser.set_defaults(func=cmd_pause)
# resume command
resume_parser = subparsers.add_parser(
"resume",
help="Resume session from checkpoint",
description="Resume a paused or failed session from a checkpoint.",
)
resume_parser.add_argument(
"agent_path",
type=str,
help="Path to agent folder",
)
resume_parser.add_argument(
"session_id",
type=str,
help="Session ID to resume",
)
resume_parser.add_argument(
"--checkpoint",
"-c",
type=str,
help="Specific checkpoint ID to resume from (default: latest)",
)
resume_parser.add_argument(
"--tui",
action="store_true",
help="Resume in TUI dashboard mode",
)
resume_parser.set_defaults(func=cmd_resume)
def cmd_run(args: argparse.Namespace) -> int:
"""Run an exported agent."""
@@ -253,7 +388,11 @@ def cmd_run(args: argparse.Namespace) -> int:
if runner._agent_runtime and not runner._agent_runtime.is_running:
await runner._agent_runtime.start()
app = AdenTUI(runner._agent_runtime)
app = AdenTUI(
runner._agent_runtime,
resume_session=getattr(args, "resume_session", None),
resume_checkpoint=getattr(args, "checkpoint", None),
)
# TUI handles execution via ChatRepl — user submits input,
# ChatRepl calls runtime.trigger_and_wait(). No auto-launch.
@@ -1432,3 +1571,53 @@ def _interactive_multi(agents_dir: Path) -> int:
orchestrator.cleanup()
return 0
def cmd_sessions_list(args: argparse.Namespace) -> int:
"""List agent sessions."""
print("⚠ Sessions list command not yet implemented")
print("This will be available once checkpoint infrastructure is complete.")
print(f"\nAgent: {args.agent_path}")
print(f"Status filter: {args.status}")
print(f"Has checkpoints: {args.has_checkpoints}")
return 1
def cmd_sessions_show(args: argparse.Namespace) -> int:
"""Show detailed session information."""
print("⚠ Session show command not yet implemented")
print("This will be available once checkpoint infrastructure is complete.")
print(f"\nAgent: {args.agent_path}")
print(f"Session: {args.session_id}")
return 1
def cmd_sessions_checkpoints(args: argparse.Namespace) -> int:
"""List checkpoints for a session."""
print("⚠ Session checkpoints command not yet implemented")
print("This will be available once checkpoint infrastructure is complete.")
print(f"\nAgent: {args.agent_path}")
print(f"Session: {args.session_id}")
return 1
def cmd_pause(args: argparse.Namespace) -> int:
"""Pause a running session."""
print("⚠ Pause command not yet implemented")
print("This will be available once executor pause integration is complete.")
print(f"\nAgent: {args.agent_path}")
print(f"Session: {args.session_id}")
return 1
def cmd_resume(args: argparse.Namespace) -> int:
"""Resume a session from checkpoint."""
print("⚠ Resume command not yet implemented")
print("This will be available once checkpoint resume integration is complete.")
print(f"\nAgent: {args.agent_path}")
print(f"Session: {args.session_id}")
if args.checkpoint:
print(f"Checkpoint: {args.checkpoint}")
if args.tui:
print("Mode: TUI")
return 1
+12
View File
@@ -741,6 +741,17 @@ class AgentRunner:
# Create AgentRuntime with all entry points
log_store = RuntimeLogStore(base_path=self._storage_path / "runtime_logs")
# Enable checkpointing by default for resumable sessions
from framework.graph.checkpoint_config import CheckpointConfig
checkpoint_config = CheckpointConfig(
enabled=True,
checkpoint_on_node_start=False, # Only checkpoint after nodes complete
checkpoint_on_node_complete=True,
checkpoint_max_age_days=7,
async_checkpoint=True, # Non-blocking
)
self._agent_runtime = create_agent_runtime(
graph=self.graph,
goal=self.goal,
@@ -750,6 +761,7 @@ class AgentRunner:
tools=tools,
tool_executor=tool_executor,
runtime_log_store=log_store,
checkpoint_config=checkpoint_config,
)
async def run(
@@ -0,0 +1,842 @@
# Resumable Sessions Design
## Problem Statement
Currently, when an agent encounters a failure during execution (e.g., credential validation, API errors, tool failures), the entire session is lost. This creates a poor user experience, especially when:
1. The agent has completed significant work before the failure
2. The failure is recoverable (e.g., adding missing credentials)
3. The user wants to retry from the exact failure point without redoing work
## Design Goals
1. **Crash Recovery**: Sessions can resume after process crashes or errors
2. **Partial Completion**: Preserve work done by nodes that completed successfully
3. **Flexible Resume Points**: Resume from exact failure point or previous checkpoints
4. **State Consistency**: Guarantee consistent SharedMemory and conversation state
5. **Minimal Overhead**: Checkpointing shouldn't significantly impact performance
6. **User Control**: Users can inspect, modify, and resume sessions explicitly
## Architecture
### 1. Checkpoint System
#### Checkpoint Types
**Automatic Checkpoints** (saved automatically by framework):
- `node_start`: Before each node begins execution
- `node_complete`: After each node successfully completes
- `edge_transition`: Before traversing to next node
- `loop_iteration`: At each iteration in EventLoopNode (optional)
**Manual Checkpoints** (triggered by agent designer):
- `safe_point`: Explicitly marked safe points in graph
- `user_checkpoint`: Before awaiting user input in client-facing nodes
#### Checkpoint Data Structure
```python
@dataclass
class Checkpoint:
"""Single checkpoint in execution timeline."""
# Identity
checkpoint_id: str # Format: checkpoint_{timestamp}_{uuid_short}
session_id: str
checkpoint_type: str # "node_start", "node_complete", etc.
# Timestamps
created_at: str # ISO 8601
# Execution state
current_node: str | None
next_node: str | None # For edge_transition checkpoints
execution_path: list[str] # Nodes executed so far
# Memory state (snapshot)
shared_memory: dict[str, Any] # Full SharedMemory._data
# Per-node conversation state references
# (actual conversations stored separately, reference by node_id)
conversation_states: dict[str, str] # {node_id: conversation_checkpoint_id}
# Output accumulator state
accumulated_outputs: dict[str, Any]
# Execution metrics (for resuming quality tracking)
metrics_snapshot: dict[str, Any]
# Metadata
is_clean: bool # True if no failures/retries before this checkpoint
can_resume_from: bool # False if checkpoint is in unstable state
description: str # Human-readable checkpoint description
```
#### Storage Structure
```
~/.hive/agents/{agent_name}/
└── sessions/
└── session_YYYYMMDD_HHMMSS_{uuid}/
├── state.json # Session state (existing)
├── checkpoints/
│ ├── index.json # Checkpoint index/manifest
│ ├── checkpoint_1.json # Individual checkpoints
│ ├── checkpoint_2.json
│ └── checkpoint_N.json
├── conversations/ # Per-node conversation state (existing)
│ ├── node_id_1/
│ │ ├── parts/
│ │ ├── meta.json
│ │ └── cursor.json
│ └── node_id_2/...
├── data/ # Spillover artifacts (existing)
└── logs/ # L1/L2/L3 logs (existing)
```
**Checkpoint Index Format** (`checkpoints/index.json`):
```json
{
"session_id": "session_20260208_143022_abc12345",
"checkpoints": [
{
"checkpoint_id": "checkpoint_20260208_143030_xyz123",
"type": "node_complete",
"created_at": "2026-02-08T14:30:30.123Z",
"current_node": "collector",
"is_clean": true,
"can_resume_from": true,
"description": "Completed collector node successfully"
},
{
"checkpoint_id": "checkpoint_20260208_143045_abc789",
"type": "node_start",
"created_at": "2026-02-08T14:30:45.456Z",
"current_node": "analyzer",
"is_clean": true,
"can_resume_from": true,
"description": "Starting analyzer node"
}
],
"latest_checkpoint_id": "checkpoint_20260208_143045_abc789",
"total_checkpoints": 2
}
```
### 2. Resume Mechanism
#### Resume Flow
```python
# High-level resume flow
async def resume_session(
session_id: str,
checkpoint_id: str | None = None, # None = resume from latest
modifications: dict[str, Any] | None = None, # Override memory values
) -> ExecutionResult:
"""
Resume a session from a checkpoint.
Args:
session_id: Session to resume
checkpoint_id: Specific checkpoint (None = latest)
modifications: Optional memory/state modifications before resume
Returns:
ExecutionResult with resumed execution
"""
# 1. Load session state
session_state = await session_store.read_state(session_id)
# 2. Verify session is resumable
if not session_state.is_resumable:
raise ValueError(f"Session {session_id} is not resumable")
# 3. Load checkpoint
checkpoint = await checkpoint_store.load_checkpoint(
session_id,
checkpoint_id or session_state.progress.resume_from
)
# 4. Restore state
# - Restore SharedMemory from checkpoint.shared_memory
# - Restore per-node conversations from checkpoint.conversation_states
# - Restore output accumulator from checkpoint.accumulated_outputs
# - Apply modifications if provided
# 5. Resume execution from checkpoint.next_node or checkpoint.current_node
result = await executor.execute(
graph=graph,
goal=goal,
memory=restored_memory,
entry_point=checkpoint.next_node or checkpoint.current_node,
session_state=restored_session_state,
)
# 6. Update session state with resumed execution
await session_store.write_state(session_id, updated_state)
return result
```
#### Checkpoint Restoration
```python
@dataclass
class CheckpointStore:
"""Manages checkpoint storage and retrieval."""
async def save_checkpoint(
self,
session_id: str,
checkpoint: Checkpoint,
) -> None:
"""Save a checkpoint atomically."""
# 1. Write checkpoint file: checkpoints/checkpoint_{id}.json
# 2. Update index: checkpoints/index.json
# 3. Use atomic write for crash safety
async def load_checkpoint(
self,
session_id: str,
checkpoint_id: str | None = None,
) -> Checkpoint | None:
"""Load a checkpoint by ID or latest."""
# 1. Read checkpoint index
# 2. Find checkpoint by ID (or latest if None)
# 3. Load and deserialize checkpoint file
async def list_checkpoints(
self,
session_id: str,
checkpoint_type: str | None = None,
is_clean: bool | None = None,
) -> list[Checkpoint]:
"""List all checkpoints for a session with optional filters."""
async def delete_checkpoint(
self,
session_id: str,
checkpoint_id: str,
) -> bool:
"""Delete a specific checkpoint."""
async def prune_checkpoints(
self,
session_id: str,
keep_count: int = 10,
keep_clean_only: bool = False,
) -> int:
"""Prune old checkpoints, keeping most recent N."""
```
### 3. GraphExecutor Integration
#### Modified Execution Loop
```python
# In GraphExecutor.execute()
async def execute(
self,
graph: GraphSpec,
goal: Goal,
memory: SharedMemory | None = None,
entry_point: str = "start",
session_state: dict[str, Any] | None = None,
checkpoint_config: CheckpointConfig | None = None,
) -> ExecutionResult:
"""
Execute graph with checkpointing support.
New parameters:
checkpoint_config: Configuration for checkpointing behavior
"""
# Initialize checkpoint store
checkpoint_store = CheckpointStore(storage_path / "checkpoints")
# Restore from checkpoint if session_state indicates resume
if session_state and session_state.get("resume_from"):
checkpoint = await checkpoint_store.load_checkpoint(
session_id,
session_state["resume_from"]
)
memory = self._restore_memory_from_checkpoint(checkpoint)
entry_point = checkpoint.next_node or checkpoint.current_node
current_node = entry_point
while current_node:
# CHECKPOINT: node_start
if checkpoint_config and checkpoint_config.checkpoint_on_node_start:
await self._save_checkpoint(
checkpoint_store,
checkpoint_type="node_start",
current_node=current_node,
memory=memory,
# ... other state
)
try:
# Execute node
result = await self._execute_node(current_node, memory, context)
# CHECKPOINT: node_complete
if checkpoint_config and checkpoint_config.checkpoint_on_node_complete:
await self._save_checkpoint(
checkpoint_store,
checkpoint_type="node_complete",
current_node=current_node,
memory=memory,
# ... other state
)
except Exception as e:
# On failure, mark current checkpoint as resume point
await self._mark_failure_checkpoint(
checkpoint_store,
current_node=current_node,
error=str(e),
)
raise
# Find next edge
next_node = self._find_next_node(current_node, result, memory)
# CHECKPOINT: edge_transition
if next_node and checkpoint_config and checkpoint_config.checkpoint_on_edge:
await self._save_checkpoint(
checkpoint_store,
checkpoint_type="edge_transition",
current_node=current_node,
next_node=next_node,
memory=memory,
# ... other state
)
current_node = next_node
```
### 4. EventLoopNode Integration
#### Conversation State Checkpointing
EventLoopNode already has conversation persistence via `ConversationStore`. For resumability:
```python
class EventLoopNode:
async def execute(self, ctx: NodeContext) -> NodeResult:
"""Execute with checkpoint support."""
# Try to restore from checkpoint
if ctx.checkpoint_id:
conversation = await self._restore_conversation(ctx.checkpoint_id)
output_accumulator = await OutputAccumulator.restore(self.store)
else:
# Fresh start
conversation = await self._initialize_conversation(ctx)
output_accumulator = OutputAccumulator(store=self.store)
# Event loop with periodic checkpointing
iteration = 0
while iteration < self.config.max_iterations:
# Optional: checkpoint every N iterations
if self.config.checkpoint_every_n_iterations:
if iteration % self.config.checkpoint_every_n_iterations == 0:
await self._save_loop_checkpoint(
conversation,
output_accumulator,
iteration,
)
# ... rest of event loop
iteration += 1
```
**Note**: EventLoopNode conversation state is already persisted to disk after each turn via `ConversationStore`, so it's naturally resumable. We just need to:
1. Track which conversation checkpoint to restore from
2. Ensure output accumulator state is also restored
### 5. User-Facing API
#### MCP Tools for Resume
```python
# In tools/src/aden_tools/tools/session_management/
@tool
async def list_resumable_sessions(
agent_work_dir: str,
status: str = "failed", # "failed", "paused", "cancelled"
limit: int = 20,
) -> dict:
"""
List sessions that can be resumed.
Returns:
{
"sessions": [
{
"session_id": "session_20260208_143022_abc12345",
"status": "failed",
"error": "Missing API key: OPENAI_API_KEY",
"failed_at_node": "analyzer",
"last_checkpoint": "checkpoint_20260208_143045_abc789",
"created_at": "2026-02-08T14:30:22Z",
"updated_at": "2026-02-08T14:30:45Z"
}
],
"total": 1
}
"""
@tool
async def list_session_checkpoints(
agent_work_dir: str,
session_id: str,
checkpoint_type: str = "", # Filter by type
clean_only: bool = False, # Only show clean checkpoints
) -> dict:
"""
List all checkpoints for a session.
Returns:
{
"session_id": "session_20260208_143022_abc12345",
"checkpoints": [
{
"checkpoint_id": "checkpoint_20260208_143030_xyz123",
"type": "node_complete",
"created_at": "2026-02-08T14:30:30Z",
"current_node": "collector",
"is_clean": true,
"can_resume_from": true,
"description": "Completed collector node successfully"
},
...
]
}
"""
@tool
async def inspect_checkpoint(
agent_work_dir: str,
session_id: str,
checkpoint_id: str,
include_memory: bool = False, # Include full memory state
) -> dict:
"""
Inspect a checkpoint's detailed state.
Returns:
{
"checkpoint_id": "checkpoint_20260208_143030_xyz123",
"type": "node_complete",
"current_node": "collector",
"execution_path": ["start", "collector"],
"accumulated_outputs": {
"twitter_handles": ["@user1", "@user2"]
},
"memory": {...}, # If include_memory=True
"metrics_snapshot": {
"total_retries": 2,
"nodes_with_failures": []
}
}
"""
@tool
async def resume_session(
agent_work_dir: str,
session_id: str,
checkpoint_id: str = "", # Empty = latest checkpoint
memory_modifications: str = "{}", # JSON string of memory overrides
) -> dict:
"""
Resume a session from a checkpoint.
Args:
agent_work_dir: Path to agent workspace
session_id: Session to resume
checkpoint_id: Specific checkpoint (empty = latest)
memory_modifications: JSON object with memory key overrides
Returns:
{
"session_id": "session_20260208_143022_abc12345",
"resumed_from": "checkpoint_20260208_143045_abc789",
"status": "active", # Now actively running
"message": "Session resumed successfully from checkpoint_20260208_143045_abc789"
}
"""
```
#### CLI Commands
```bash
# List resumable sessions
hive sessions list --agent twitter_outreach --status failed
# Show checkpoints for a session
hive sessions checkpoints session_20260208_143022_abc12345
# Inspect a checkpoint
hive sessions inspect session_20260208_143022_abc12345 checkpoint_20260208_143045_abc789
# Resume a session
hive sessions resume session_20260208_143022_abc12345
# Resume from specific checkpoint
hive sessions resume session_20260208_143022_abc12345 --checkpoint checkpoint_20260208_143030_xyz123
# Resume with memory modifications (e.g., after adding credentials)
hive sessions resume session_20260208_143022_abc12345 --set api_key=sk-...
```
### 6. Configuration
#### CheckpointConfig
```python
@dataclass
class CheckpointConfig:
"""Configuration for checkpoint behavior."""
# When to checkpoint
checkpoint_on_node_start: bool = True
checkpoint_on_node_complete: bool = True
checkpoint_on_edge: bool = False # Usually redundant with node_start
checkpoint_on_loop_iteration: bool = False # Can be expensive
checkpoint_every_n_iterations: int = 0 # 0 = disabled
# Pruning
max_checkpoints_per_session: int = 100
prune_after_node_count: int = 10 # Prune every N nodes
keep_clean_checkpoints_only: bool = False
# Performance
async_checkpoint: bool = True # Don't block execution on checkpoint writes
# What to include
include_conversation_snapshots: bool = True
include_full_memory: bool = True
```
#### Agent-Level Configuration
```python
# In agent.py or config.py
class MyAgent(Agent):
def get_checkpoint_config(self) -> CheckpointConfig:
"""Override to customize checkpoint behavior."""
return CheckpointConfig(
checkpoint_on_node_start=True,
checkpoint_on_node_complete=True,
checkpoint_every_n_iterations=5, # Checkpoint every 5 iterations in loops
max_checkpoints_per_session=50,
)
```
## Implementation Plan
### Phase 1: Core Checkpoint Infrastructure (Week 1)
1. **Create checkpoint schemas**
- `Checkpoint` dataclass
- `CheckpointIndex` for manifest
- Serialization/deserialization
2. **Implement CheckpointStore**
- `save_checkpoint()` with atomic writes
- `load_checkpoint()` with deserialization
- `list_checkpoints()` with filtering
- `prune_checkpoints()` for cleanup
3. **Update SessionState schema**
- Add `resume_from_checkpoint_id` field
- Add `checkpoints_enabled` flag
### Phase 2: GraphExecutor Integration (Week 2)
1. **Modify GraphExecutor**
- Add `CheckpointConfig` parameter
- Implement checkpoint saving at node boundaries
- Implement checkpoint restoration logic
- Handle memory state snapshots
2. **Update execution loop**
- Checkpoint before node execution
- Checkpoint after successful completion
- Mark failure checkpoints on errors
### Phase 3: EventLoopNode Integration (Week 3)
1. **Enhance conversation restoration**
- Link checkpoints to conversation states
- Ensure OutputAccumulator is checkpointed
- Test loop resumption from middle of execution
2. **Add optional loop iteration checkpoints**
- Configurable iteration frequency
- Balance between granularity and performance
### Phase 4: User-Facing Features (Week 4)
1. **Implement MCP tools**
- `list_resumable_sessions`
- `list_session_checkpoints`
- `inspect_checkpoint`
- `resume_session`
2. **Add CLI commands**
- `hive sessions list`
- `hive sessions checkpoints`
- `hive sessions inspect`
- `hive sessions resume`
3. **Update TUI**
- Show resumable sessions in UI
- Allow resume from TUI interface
### Phase 5: Testing & Documentation (Week 5)
1. **Write comprehensive tests**
- Unit tests for CheckpointStore
- Integration tests for resume flow
- Edge case testing (concurrent checkpoints, corruption, etc.)
2. **Performance testing**
- Measure checkpoint overhead
- Optimize async checkpoint writing
- Test with large memory states
3. **Documentation**
- Update skills with resume patterns
- Document checkpoint configuration
- Add troubleshooting guide
## Performance Considerations
### Checkpoint Overhead
**Estimated overhead per checkpoint**:
- Memory serialization: ~5-10ms for typical state (< 1MB)
- File I/O: ~10-20ms for atomic write
- Total: ~15-30ms per checkpoint
**Mitigation strategies**:
1. **Async checkpointing**: Don't block execution on writes
2. **Selective checkpointing**: Only checkpoint at important boundaries
3. **Incremental checkpoints**: Store deltas instead of full state (future)
4. **Compression**: Compress large memory states before writing
### Storage Size
**Typical checkpoint size**:
- Small memory state (< 100KB): ~50-100KB per checkpoint
- Medium memory state (< 1MB): ~500KB-1MB per checkpoint
- Large memory state (> 1MB): ~1-5MB per checkpoint
**Mitigation strategies**:
1. **Pruning**: Keep only N most recent checkpoints
2. **Clean-only retention**: Only keep checkpoints from clean execution
3. **Compression**: Use gzip for checkpoint files
4. **Archiving**: Move old checkpoints to archive storage
## Error Handling
### Checkpoint Save Failures
**Scenarios**:
- Disk full
- Permission errors
- Serialization failures
- Concurrent writes
**Handling**:
```python
try:
await checkpoint_store.save_checkpoint(session_id, checkpoint)
except CheckpointSaveError as e:
# Log warning but don't fail execution
logger.warning(f"Failed to save checkpoint: {e}")
# Continue execution without checkpoint
```
### Checkpoint Load Failures
**Scenarios**:
- Checkpoint file corrupted
- Checkpoint format incompatible
- Referenced conversation state missing
**Handling**:
```python
try:
checkpoint = await checkpoint_store.load_checkpoint(session_id, checkpoint_id)
except CheckpointLoadError as e:
# Try to find previous valid checkpoint
checkpoints = await checkpoint_store.list_checkpoints(session_id)
for cp in reversed(checkpoints):
try:
checkpoint = await checkpoint_store.load_checkpoint(session_id, cp.checkpoint_id)
logger.info(f"Fell back to checkpoint {cp.checkpoint_id}")
break
except CheckpointLoadError:
continue
else:
raise ValueError(f"No valid checkpoints found for session {session_id}")
```
### Resume Failures
**Scenarios**:
- Checkpoint state inconsistent with current graph
- Node no longer exists in updated agent code
- Memory keys missing required values
**Handling**:
1. **Validation**: Verify checkpoint compatibility before resume
2. **Graceful degradation**: Resume from earlier checkpoint if possible
3. **User notification**: Clear error messages about why resume failed
## Migration Path
### Backward Compatibility
**Existing sessions** (without checkpoints):
- Can still be executed normally
- Checkpoint system is opt-in per agent
- No breaking changes to existing APIs
**Enabling checkpoints**:
```python
# Option 1: Agent-level default
class MyAgent(Agent):
checkpoint_config = CheckpointConfig(
checkpoint_on_node_complete=True,
)
# Option 2: Runtime override
runtime = create_agent_runtime(
agent=my_agent,
checkpoint_config=CheckpointConfig(...),
)
# Option 3: Per-execution
result = await executor.execute(
graph=graph,
goal=goal,
checkpoint_config=CheckpointConfig(...),
)
```
### Gradual Rollout
1. **Phase 1**: Core infrastructure, no user-facing features
2. **Phase 2**: Opt-in for specific agents via config
3. **Phase 3**: User-facing MCP tools and CLI
4. **Phase 4**: Enable by default for all new agents
5. **Phase 5**: TUI integration
## Future Enhancements
### 1. Incremental Checkpoints
Instead of full state snapshots, store only deltas:
```python
@dataclass
class IncrementalCheckpoint:
"""Checkpoint with only changed state."""
base_checkpoint_id: str # Parent checkpoint
memory_delta: dict[str, Any] # Only changed keys
added_outputs: dict[str, Any] # Only new outputs
```
### 2. Distributed Checkpointing
For long-running agents, checkpoint to cloud storage:
```python
checkpoint_config = CheckpointConfig(
storage_backend="s3", # or "gcs", "azure"
storage_url="s3://my-bucket/checkpoints/",
)
```
### 3. Checkpoint Compression
Compress large memory states:
```python
checkpoint_config = CheckpointConfig(
compress=True,
compression_threshold_bytes=100_000, # Compress if > 100KB
)
```
### 4. Smart Checkpoint Selection
Use heuristics to decide when to checkpoint:
```python
class SmartCheckpointStrategy:
def should_checkpoint(self, context: ExecutionContext) -> bool:
# Checkpoint after expensive nodes
if context.node_latency_ms > 30_000:
return True
# Checkpoint before risky operations
if context.node_id in ["api_call", "external_tool"]:
return True
# Checkpoint after significant memory changes
if context.memory_delta_size > 10:
return True
return False
```
## Security Considerations
### 1. Sensitive Data in Checkpoints
**Problem**: Checkpoints may contain sensitive data (API keys, credentials, PII)
**Mitigation**:
```python
@dataclass
class CheckpointConfig:
# Exclude sensitive keys from checkpoint
exclude_memory_keys: list[str] = field(default_factory=lambda: [
"api_key",
"credentials",
"access_token",
])
# Encrypt checkpoint files
encrypt_checkpoints: bool = True
encryption_key_source: str = "keychain" # or "env_var", "file"
```
### 2. Checkpoint Tampering
**Problem**: Malicious modification of checkpoint files
**Mitigation**:
```python
@dataclass
class Checkpoint:
# Add cryptographic signature
signature: str # HMAC of checkpoint content
def verify_signature(self, secret_key: str) -> bool:
"""Verify checkpoint hasn't been tampered with."""
...
```
## References
- [RUNTIME_LOGGING.md](./RUNTIME_LOGGING.md) - Current logging system
- [session_state.py](../schemas/session_state.py) - Session state schema
- [session_store.py](../storage/session_store.py) - Session storage
- [executor.py](../graph/executor.py) - Graph executor
- [event_loop_node.py](../graph/event_loop_node.py) - EventLoop implementation
+9
View File
@@ -12,6 +12,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
from framework.graph.checkpoint_config import CheckpointConfig
from framework.graph.executor import ExecutionResult
from framework.runtime.event_bus import EventBus
from framework.runtime.execution_stream import EntryPointSpec, ExecutionStream
@@ -102,6 +103,7 @@ class AgentRuntime:
tool_executor: Callable | None = None,
config: AgentRuntimeConfig | None = None,
runtime_log_store: Any = None,
checkpoint_config: CheckpointConfig | None = None,
):
"""
Initialize agent runtime.
@@ -115,11 +117,13 @@ class AgentRuntime:
tool_executor: Function to execute tools
config: Optional runtime configuration
runtime_log_store: Optional RuntimeLogStore for per-execution logging
checkpoint_config: Optional checkpoint configuration for resumable sessions
"""
self.graph = graph
self.goal = goal
self._config = config or AgentRuntimeConfig()
self._runtime_log_store = runtime_log_store
self._checkpoint_config = checkpoint_config
# Initialize storage
storage_path_obj = Path(storage_path) if isinstance(storage_path, str) else storage_path
@@ -222,6 +226,7 @@ class AgentRuntime:
result_retention_ttl_seconds=self._config.execution_result_ttl_seconds,
runtime_log_store=self._runtime_log_store,
session_store=self._session_store,
checkpoint_config=self._checkpoint_config,
)
await stream.start()
self._streams[ep_id] = stream
@@ -460,6 +465,7 @@ def create_agent_runtime(
config: AgentRuntimeConfig | None = None,
runtime_log_store: Any = None,
enable_logging: bool = True,
checkpoint_config: CheckpointConfig | None = None,
) -> AgentRuntime:
"""
Create and configure an AgentRuntime with entry points.
@@ -480,6 +486,8 @@ def create_agent_runtime(
If None and enable_logging=True, creates one automatically.
enable_logging: Whether to enable runtime logging (default: True).
Set to False to disable logging entirely.
checkpoint_config: Optional checkpoint configuration for resumable sessions.
If None, uses default checkpointing behavior.
Returns:
Configured AgentRuntime (not yet started)
@@ -500,6 +508,7 @@ def create_agent_runtime(
tool_executor=tool_executor,
config=config,
runtime_log_store=runtime_log_store,
checkpoint_config=checkpoint_config,
)
for spec in entry_points:
+46 -3
View File
@@ -17,6 +17,7 @@ from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Any
from framework.graph.checkpoint_config import CheckpointConfig
from framework.graph.executor import ExecutionResult, GraphExecutor
from framework.runtime.shared_state import IsolationLevel, SharedStateManager
from framework.runtime.stream_runtime import StreamRuntime, StreamRuntimeAdapter
@@ -115,6 +116,7 @@ class ExecutionStream:
result_retention_ttl_seconds: float | None = None,
runtime_log_store: Any = None,
session_store: "SessionStore | None" = None,
checkpoint_config: CheckpointConfig | None = None,
):
"""
Initialize execution stream.
@@ -133,6 +135,7 @@ class ExecutionStream:
tool_executor: Function to execute tools
runtime_log_store: Optional RuntimeLogStore for per-execution logging
session_store: Optional SessionStore for unified session storage
checkpoint_config: Optional checkpoint configuration for resumable sessions
"""
self.stream_id = stream_id
self.entry_spec = entry_spec
@@ -148,6 +151,7 @@ class ExecutionStream:
self._result_retention_max = result_retention_max
self._result_retention_ttl_seconds = result_retention_ttl_seconds
self._runtime_log_store = runtime_log_store
self._checkpoint_config = checkpoint_config
self._session_store = session_store
# Create stream-scoped runtime
@@ -400,6 +404,7 @@ class ExecutionStream:
goal=self.goal,
input_data=ctx.input_data,
session_state=ctx.session_state,
checkpoint_config=self._checkpoint_config,
)
# Clean up executor reference
@@ -437,8 +442,42 @@ class ExecutionStream:
logger.debug(f"Execution {execution_id} completed: success={result.success}")
except asyncio.CancelledError:
ctx.status = "cancelled"
raise
# Execution was cancelled
# The executor catches CancelledError and returns a paused result,
# but if cancellation happened before executor started, we won't have a result
logger.info(f"Execution {execution_id} cancelled")
# Check if we have a result (executor completed and returned)
try:
_ = result # Check if result variable exists
has_result = True
except NameError:
has_result = False
result = ExecutionResult(
success=False,
error="Execution cancelled",
)
# Update context status based on result
if has_result and result.paused_at:
ctx.status = "paused"
ctx.completed_at = datetime.now()
else:
ctx.status = "cancelled"
# Clean up executor reference
self._active_executors.pop(execution_id, None)
# Store result with retention
self._record_execution_result(execution_id, result)
# Write session state
if has_result and result.paused_at:
await self._write_session_state(execution_id, ctx, result=result)
else:
await self._write_session_state(execution_id, ctx, error="Execution cancelled")
# Don't re-raise - we've handled it and saved state
except Exception as e:
ctx.status = "failed"
@@ -511,7 +550,11 @@ class ExecutionStream:
else:
status = SessionStatus.FAILED
elif error:
status = SessionStatus.FAILED
# Check if this is a cancellation
if ctx.status == "cancelled" or "cancelled" in error.lower():
status = SessionStatus.CANCELLED
else:
status = SessionStatus.FAILED
else:
status = SessionStatus.ACTIVE
+178
View File
@@ -0,0 +1,178 @@
"""
Checkpoint Schema - Execution state snapshots for resumability.
Checkpoints capture the execution state at strategic points (node boundaries,
iterations) to enable crash recovery and resume-from-failure scenarios.
"""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class Checkpoint(BaseModel):
"""
Single checkpoint in execution timeline.
Captures complete execution state at a specific point to enable
resuming from that exact point after failures or pauses.
"""
# Identity
checkpoint_id: str # Format: cp_{type}_{node_id}_{timestamp}
checkpoint_type: str # "node_start" | "node_complete" | "loop_iteration"
session_id: str
# Timestamps
created_at: str # ISO 8601 format
# Execution state
current_node: str | None = None
next_node: str | None = None # For edge_transition checkpoints
execution_path: list[str] = Field(default_factory=list) # Nodes executed so far
# State snapshots
shared_memory: dict[str, Any] = Field(default_factory=dict) # Full SharedMemory._data
accumulated_outputs: dict[str, Any] = Field(default_factory=dict) # Outputs accumulated so far
# Execution metrics (for resuming quality tracking)
metrics_snapshot: dict[str, Any] = Field(default_factory=dict)
# Metadata
is_clean: bool = True # True if no failures/retries before this checkpoint
description: str = "" # Human-readable checkpoint description
model_config = {"extra": "allow"}
@classmethod
def create(
cls,
checkpoint_type: str,
session_id: str,
current_node: str,
execution_path: list[str],
shared_memory: dict[str, Any],
next_node: str | None = None,
accumulated_outputs: dict[str, Any] | None = None,
metrics_snapshot: dict[str, Any] | None = None,
is_clean: bool = True,
description: str = "",
) -> "Checkpoint":
"""
Create a new checkpoint with generated ID and timestamp.
Args:
checkpoint_type: Type of checkpoint (node_start, node_complete, etc.)
session_id: Session this checkpoint belongs to
current_node: Node ID at checkpoint time
execution_path: List of node IDs executed so far
shared_memory: Full memory state snapshot
next_node: Next node to execute (for node_complete checkpoints)
accumulated_outputs: Outputs accumulated so far
metrics_snapshot: Execution metrics at checkpoint time
is_clean: Whether execution was clean up to this point
description: Human-readable description
Returns:
New Checkpoint instance
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_id = f"cp_{checkpoint_type}_{current_node}_{timestamp}"
if not description:
description = f"{checkpoint_type.replace('_', ' ').title()}: {current_node}"
return cls(
checkpoint_id=checkpoint_id,
checkpoint_type=checkpoint_type,
session_id=session_id,
created_at=datetime.now().isoformat(),
current_node=current_node,
next_node=next_node,
execution_path=execution_path,
shared_memory=shared_memory,
accumulated_outputs=accumulated_outputs or {},
metrics_snapshot=metrics_snapshot or {},
is_clean=is_clean,
description=description,
)
class CheckpointSummary(BaseModel):
"""
Lightweight checkpoint metadata for index listings.
Used in checkpoint index to provide fast scanning without
loading full checkpoint data.
"""
checkpoint_id: str
checkpoint_type: str
created_at: str
current_node: str | None = None
next_node: str | None = None
is_clean: bool = True
description: str = ""
model_config = {"extra": "allow"}
@classmethod
def from_checkpoint(cls, checkpoint: Checkpoint) -> "CheckpointSummary":
"""Create summary from full checkpoint."""
return cls(
checkpoint_id=checkpoint.checkpoint_id,
checkpoint_type=checkpoint.checkpoint_type,
created_at=checkpoint.created_at,
current_node=checkpoint.current_node,
next_node=checkpoint.next_node,
is_clean=checkpoint.is_clean,
description=checkpoint.description,
)
class CheckpointIndex(BaseModel):
"""
Manifest of all checkpoints for a session.
Provides fast lookup and filtering without loading
full checkpoint files.
"""
session_id: str
checkpoints: list[CheckpointSummary] = Field(default_factory=list)
latest_checkpoint_id: str | None = None
total_checkpoints: int = 0
model_config = {"extra": "allow"}
def add_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Add a checkpoint to the index."""
summary = CheckpointSummary.from_checkpoint(checkpoint)
self.checkpoints.append(summary)
self.latest_checkpoint_id = checkpoint.checkpoint_id
self.total_checkpoints = len(self.checkpoints)
def get_checkpoint_summary(self, checkpoint_id: str) -> CheckpointSummary | None:
"""Get checkpoint summary by ID."""
for summary in self.checkpoints:
if summary.checkpoint_id == checkpoint_id:
return summary
return None
def filter_by_type(self, checkpoint_type: str) -> list[CheckpointSummary]:
"""Filter checkpoints by type."""
return [cp for cp in self.checkpoints if cp.checkpoint_type == checkpoint_type]
def filter_by_node(self, node_id: str) -> list[CheckpointSummary]:
"""Filter checkpoints by current_node."""
return [cp for cp in self.checkpoints if cp.current_node == node_id]
def get_clean_checkpoints(self) -> list[CheckpointSummary]:
"""Get all clean checkpoints (no failures before them)."""
return [cp for cp in self.checkpoints if cp.is_clean]
def get_latest_clean_checkpoint(self) -> CheckpointSummary | None:
"""Get the most recent clean checkpoint."""
clean = self.get_clean_checkpoints()
return clean[-1] if clean else None
+14 -1
View File
@@ -91,10 +91,11 @@ class SessionState(BaseModel):
Version History:
- v1.0: Initial schema (2026-02-06)
- v1.1: Added checkpoint support (2026-02-08)
"""
# Schema version for forward/backward compatibility
schema_version: str = "1.0"
schema_version: str = "1.1"
# Identity
session_id: str # Format: session_YYYYMMDD_HHMMSS_{uuid_8char}
@@ -136,6 +137,10 @@ class SessionState(BaseModel):
# Isolation level (from ExecutionContext)
isolation_level: str = "shared"
# Checkpointing (for crash recovery and resume-from-failure)
checkpoint_enabled: bool = False
latest_checkpoint_id: str | None = None
model_config = {"extra": "allow"}
@computed_field
@@ -154,6 +159,14 @@ class SessionState(BaseModel):
"""Can this session be resumed?"""
return self.status == SessionStatus.PAUSED and self.progress.resume_from is not None
@computed_field
@property
def is_resumable_from_checkpoint(self) -> bool:
"""Can this session be resumed from a checkpoint?"""
# ANY session with checkpoints can be resumed (not just failed ones)
# This enables: pause/resume, iterative execution, continuation after completion
return self.checkpoint_enabled and self.latest_checkpoint_id is not None
@classmethod
def from_execution_result(
cls,
+325
View File
@@ -0,0 +1,325 @@
"""
Checkpoint Store - Manages checkpoint storage with atomic writes.
Handles saving, loading, listing, and pruning of execution checkpoints
for session resumability.
"""
import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path
from framework.schemas.checkpoint import Checkpoint, CheckpointIndex, CheckpointSummary
from framework.utils.io import atomic_write
logger = logging.getLogger(__name__)
class CheckpointStore:
"""
Manages checkpoint storage with atomic writes.
Stores checkpoints in a session's checkpoints/ directory with
an index for fast lookup and filtering.
Directory structure:
checkpoints/
index.json # Checkpoint manifest
cp_{type}_{node}_{timestamp}.json # Individual checkpoints
"""
def __init__(self, base_path: Path):
"""
Initialize checkpoint store.
Args:
base_path: Session directory (e.g., ~/.hive/agents/agent_name/sessions/session_ID/)
"""
self.base_path = Path(base_path)
self.checkpoints_dir = self.base_path / "checkpoints"
self.index_path = self.checkpoints_dir / "index.json"
self._index_lock = asyncio.Lock()
async def save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""
Atomically save checkpoint and update index.
Uses temp file + rename for crash safety. Updates index
after checkpoint is persisted.
Args:
checkpoint: Checkpoint to save
Raises:
OSError: If file write fails
"""
def _write():
# Ensure directory exists
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
# Write checkpoint file atomically
checkpoint_path = self.checkpoints_dir / f"{checkpoint.checkpoint_id}.json"
with atomic_write(checkpoint_path) as f:
f.write(checkpoint.model_dump_json(indent=2))
logger.debug(f"Saved checkpoint {checkpoint.checkpoint_id}")
# Write checkpoint file (blocking I/O in thread)
await asyncio.to_thread(_write)
# Update index (with lock to prevent concurrent modifications)
async with self._index_lock:
await self._update_index_add(checkpoint)
async def load_checkpoint(
self,
checkpoint_id: str | None = None,
) -> Checkpoint | None:
"""
Load checkpoint by ID or latest.
Args:
checkpoint_id: Checkpoint ID to load, or None for latest
Returns:
Checkpoint object, or None if not found
"""
def _read(checkpoint_id: str) -> Checkpoint | None:
checkpoint_path = self.checkpoints_dir / f"{checkpoint_id}.json"
if not checkpoint_path.exists():
logger.warning(f"Checkpoint file not found: {checkpoint_path}")
return None
try:
return Checkpoint.model_validate_json(checkpoint_path.read_text())
except Exception as e:
logger.error(f"Failed to load checkpoint {checkpoint_id}: {e}")
return None
# Load index to get checkpoint ID if not provided
if checkpoint_id is None:
index = await self.load_index()
if not index or not index.latest_checkpoint_id:
logger.warning("No checkpoints found in index")
return None
checkpoint_id = index.latest_checkpoint_id
return await asyncio.to_thread(_read, checkpoint_id)
async def load_index(self) -> CheckpointIndex | None:
"""
Load checkpoint index.
Returns:
CheckpointIndex or None if not found
"""
def _read() -> CheckpointIndex | None:
if not self.index_path.exists():
return None
try:
return CheckpointIndex.model_validate_json(self.index_path.read_text())
except Exception as e:
logger.error(f"Failed to load checkpoint index: {e}")
return None
return await asyncio.to_thread(_read)
async def list_checkpoints(
self,
checkpoint_type: str | None = None,
is_clean: bool | None = None,
) -> list[CheckpointSummary]:
"""
List checkpoints with optional filters.
Args:
checkpoint_type: Filter by type (node_start, node_complete)
is_clean: Filter by clean status
Returns:
List of CheckpointSummary objects
"""
index = await self.load_index()
if not index:
return []
checkpoints = index.checkpoints
# Apply filters
if checkpoint_type:
checkpoints = [cp for cp in checkpoints if cp.checkpoint_type == checkpoint_type]
if is_clean is not None:
checkpoints = [cp for cp in checkpoints if cp.is_clean == is_clean]
return checkpoints
async def delete_checkpoint(self, checkpoint_id: str) -> bool:
"""
Delete a specific checkpoint.
Args:
checkpoint_id: Checkpoint ID to delete
Returns:
True if deleted, False if not found
"""
def _delete(checkpoint_id: str) -> bool:
checkpoint_path = self.checkpoints_dir / f"{checkpoint_id}.json"
if not checkpoint_path.exists():
logger.warning(f"Checkpoint file not found: {checkpoint_path}")
return False
try:
checkpoint_path.unlink()
logger.info(f"Deleted checkpoint {checkpoint_id}")
return True
except Exception as e:
logger.error(f"Failed to delete checkpoint {checkpoint_id}: {e}")
return False
# Delete checkpoint file
deleted = await asyncio.to_thread(_delete, checkpoint_id)
if deleted:
# Update index (with lock)
async with self._index_lock:
await self._update_index_remove(checkpoint_id)
return deleted
async def prune_checkpoints(
self,
max_age_days: int = 7,
) -> int:
"""
Prune checkpoints older than max_age_days.
Args:
max_age_days: Maximum age in days (default 7)
Returns:
Number of checkpoints deleted
"""
index = await self.load_index()
if not index or not index.checkpoints:
return 0
# Calculate cutoff datetime
cutoff = datetime.now() - timedelta(days=max_age_days)
# Find old checkpoints
old_checkpoints = []
for cp in index.checkpoints:
try:
created = datetime.fromisoformat(cp.created_at)
if created < cutoff:
old_checkpoints.append(cp.checkpoint_id)
except Exception as e:
logger.warning(f"Failed to parse timestamp for {cp.checkpoint_id}: {e}")
# Delete old checkpoints
deleted_count = 0
for checkpoint_id in old_checkpoints:
if await self.delete_checkpoint(checkpoint_id):
deleted_count += 1
if deleted_count > 0:
logger.info(f"Pruned {deleted_count} checkpoints older than {max_age_days} days")
return deleted_count
async def checkpoint_exists(self, checkpoint_id: str) -> bool:
"""
Check if a checkpoint exists.
Args:
checkpoint_id: Checkpoint ID
Returns:
True if checkpoint exists
"""
def _check(checkpoint_id: str) -> bool:
checkpoint_path = self.checkpoints_dir / f"{checkpoint_id}.json"
return checkpoint_path.exists()
return await asyncio.to_thread(_check, checkpoint_id)
async def _update_index_add(self, checkpoint: Checkpoint) -> None:
"""
Update index after adding a checkpoint.
Should be called with _index_lock held.
Args:
checkpoint: Checkpoint that was added
"""
def _write(index: CheckpointIndex):
# Ensure directory exists
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
# Write index atomically
with atomic_write(self.index_path) as f:
f.write(index.model_dump_json(indent=2))
# Load or create index
index = await self.load_index()
if not index:
index = CheckpointIndex(
session_id=checkpoint.session_id,
checkpoints=[],
)
# Add checkpoint to index
index.add_checkpoint(checkpoint)
# Write updated index
await asyncio.to_thread(_write, index)
logger.debug(f"Updated index with checkpoint {checkpoint.checkpoint_id}")
async def _update_index_remove(self, checkpoint_id: str) -> None:
"""
Update index after removing a checkpoint.
Should be called with _index_lock held.
Args:
checkpoint_id: Checkpoint ID that was removed
"""
def _write(index: CheckpointIndex):
with atomic_write(self.index_path) as f:
f.write(index.model_dump_json(indent=2))
# Load index
index = await self.load_index()
if not index:
return
# Remove checkpoint from index
index.checkpoints = [cp for cp in index.checkpoints if cp.checkpoint_id != checkpoint_id]
# Update totals
index.total_checkpoints = len(index.checkpoints)
# Update latest_checkpoint_id if we removed the latest
if index.latest_checkpoint_id == checkpoint_id:
index.latest_checkpoint_id = (
index.checkpoints[-1].checkpoint_id if index.checkpoints else None
)
# Write updated index
await asyncio.to_thread(_write, index)
logger.debug(f"Removed checkpoint {checkpoint_id} from index")
+101 -4
View File
@@ -6,7 +6,7 @@ import time
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, Vertical
from textual.widgets import Footer, Label
from textual.widgets import Footer, Input, Label
from framework.runtime.agent_runtime import AgentRuntime
from framework.runtime.event_bus import AgentEvent, EventType
@@ -208,17 +208,24 @@ class AdenTUI(App):
Binding("ctrl+c", "ctrl_c", "Interrupt", show=False, priority=True),
Binding("super+c", "ctrl_c", "Copy", show=False, priority=True),
Binding("ctrl+s", "screenshot", "Screenshot (SVG)", show=True, priority=True),
Binding("ctrl+z", "pause_execution", "Pause", show=True, priority=True),
Binding("ctrl+r", "show_sessions", "Sessions", show=True, priority=True),
Binding("tab", "focus_next", "Next Panel", show=True),
Binding("shift+tab", "focus_previous", "Previous Panel", show=False),
]
def __init__(self, runtime: AgentRuntime):
def __init__(
self,
runtime: AgentRuntime,
resume_session: str | None = None,
resume_checkpoint: str | None = None,
):
super().__init__()
self.runtime = runtime
self.log_pane = LogPane()
self.graph_view = GraphOverview(runtime)
self.chat_repl = ChatRepl(runtime)
self.chat_repl = ChatRepl(runtime, resume_session, resume_checkpoint)
self.status_bar = StatusBar(graph_id=runtime.graph.id)
self.is_ready = False
@@ -528,9 +535,99 @@ class AdenTUI(App):
except Exception as e:
self.notify(f"Screenshot failed: {e}", severity="error", timeout=5)
def action_pause_execution(self) -> None:
"""Immediately pause execution by cancelling task (bound to Ctrl+Z)."""
try:
chat_repl = self.query_one(ChatRepl)
if not chat_repl._current_exec_id:
self.notify(
"No active execution to pause",
severity="information",
timeout=3,
)
return
# Find and cancel the execution task - executor will catch and save state
task_cancelled = False
for stream in self.runtime._streams.values():
exec_id = chat_repl._current_exec_id
task = stream._execution_tasks.get(exec_id)
if task and not task.done():
task.cancel()
task_cancelled = True
self.notify(
"⏸ Execution paused - state saved",
severity="information",
timeout=3,
)
break
if not task_cancelled:
self.notify(
"Execution already completed",
severity="information",
timeout=2,
)
except Exception as e:
self.notify(
f"Error pausing execution: {e}",
severity="error",
timeout=5,
)
def action_show_sessions(self) -> None:
"""Show sessions list (bound to Ctrl+R)."""
# Send /sessions command to chat input
try:
chat_repl = self.query_one(ChatRepl)
chat_input = chat_repl.query_one("#chat-input", Input)
chat_input.value = "/sessions"
# Trigger submission
self.notify(
"💡 Type /sessions in the chat to see all sessions",
severity="information",
timeout=3,
)
except Exception:
self.notify(
"Use /sessions command to see all sessions",
severity="information",
timeout=3,
)
async def on_unmount(self) -> None:
"""Cleanup on app shutdown."""
"""Cleanup on app shutdown - cancel execution which will save state."""
self.is_ready = False
# Cancel any active execution - the executor will catch CancelledError
# and save current state as paused (no waiting needed!)
try:
import asyncio
chat_repl = self.query_one(ChatRepl)
if chat_repl._current_exec_id:
# Find the stream with this execution
for stream in self.runtime._streams.values():
exec_id = chat_repl._current_exec_id
task = stream._execution_tasks.get(exec_id)
if task and not task.done():
# Cancel the task - executor will catch and save state
task.cancel()
try:
# Wait for executor to save state (may take a few seconds)
# Longer timeout for quit to ensure state is properly saved
await asyncio.wait_for(task, timeout=5.0)
except (TimeoutError, asyncio.CancelledError):
# Expected - task was cancelled
# If timeout, state may not be fully saved
pass
except Exception:
# Ignore other exceptions during cleanup
pass
break
except Exception:
pass
try:
if hasattr(self, "_subscription_id"):
self.runtime.unsubscribe_from_events(self._subscription_id)
+601 -8
View File
@@ -17,6 +17,7 @@ Client-facing input:
import asyncio
import re
import threading
from pathlib import Path
from typing import Any
from textual.app import ComposeResult
@@ -69,13 +70,20 @@ class ChatRepl(Vertical):
}
"""
def __init__(self, runtime: AgentRuntime):
def __init__(
self,
runtime: AgentRuntime,
resume_session: str | None = None,
resume_checkpoint: str | None = None,
):
super().__init__()
self.runtime = runtime
self._current_exec_id: str | None = None
self._streaming_snapshot: str = ""
self._waiting_for_input: bool = False
self._input_node_id: str | None = None
self._resume_session = resume_session
self._resume_checkpoint = resume_checkpoint
# Dedicated event loop for agent execution.
# Keeps blocking runtime code (LLM calls, MCP tools) off
@@ -121,10 +129,589 @@ class ChatRepl(Vertical):
if was_at_bottom:
history.scroll_end(animate=False)
async def _handle_command(self, command: str) -> None:
"""Handle slash commands for session and checkpoint operations."""
parts = command.split(maxsplit=2)
cmd = parts[0].lower()
if cmd == "/help":
self._write_history("""[bold cyan]Available Commands:[/bold cyan]
[bold]/sessions[/bold] - List all sessions for this agent
[bold]/sessions[/bold] <session_id> - Show session details and checkpoints
[bold]/resume[/bold] - Resume latest paused/failed session
[bold]/resume[/bold] <session_id> - Resume session from where it stopped
[bold]/recover[/bold] <session_id> <cp_id> - Recover from specific checkpoint
[bold]/pause[/bold] - Pause current execution (Ctrl+Z)
[bold]/help[/bold] - Show this help message
[dim]Examples:[/dim]
/sessions [dim]# List all sessions[/dim]
/sessions session_20260208_143022 [dim]# Show session details[/dim]
/resume [dim]# Resume latest session (from state)[/dim]
/resume session_20260208_143022 [dim]# Resume specific session (from state)[/dim]
/recover session_20260208_143022 cp_xxx [dim]# Recover from specific checkpoint[/dim]
/pause [dim]# Pause (or Ctrl+Z)[/dim]
""")
elif cmd == "/sessions":
session_id = parts[1].strip() if len(parts) > 1 else None
await self._cmd_sessions(session_id)
elif cmd == "/resume":
# Resume from session state (not checkpoint-based)
if len(parts) < 2:
session_id = await self._find_latest_resumable_session()
if not session_id:
self._write_history("[bold red]No resumable sessions found[/bold red]")
self._write_history(" Tip: Use [bold]/sessions[/bold] to see all sessions")
return
else:
session_id = parts[1].strip()
await self._cmd_resume(session_id)
elif cmd == "/recover":
# Recover from specific checkpoint
if len(parts) < 3:
self._write_history(
"[bold red]Error:[/bold red] /recover requires session_id and checkpoint_id"
)
self._write_history(" Usage: [bold]/recover <session_id> <checkpoint_id>[/bold]")
self._write_history(
" Tip: Use [bold]/sessions <session_id>[/bold] to see checkpoints"
)
return
session_id = parts[1].strip()
checkpoint_id = parts[2].strip()
await self._cmd_recover(session_id, checkpoint_id)
elif cmd == "/pause":
await self._cmd_pause()
else:
self._write_history(
f"[bold red]Unknown command:[/bold red] {cmd}\n"
"Type [bold]/help[/bold] for available commands"
)
async def _cmd_sessions(self, session_id: str | None) -> None:
"""List sessions or show details of a specific session."""
try:
# Get storage path from runtime
storage_path = self.runtime._storage.base_path
if session_id:
# Show details of specific session including checkpoints
await self._show_session_details(storage_path, session_id)
else:
# List all sessions
await self._list_sessions(storage_path)
except Exception as e:
self._write_history(f"[bold red]Error:[/bold red] {e}")
self._write_history(" Could not access session data")
async def _find_latest_resumable_session(self) -> str | None:
"""Find the most recent paused or failed session."""
try:
storage_path = self.runtime._storage.base_path
sessions_dir = storage_path / "sessions"
if not sessions_dir.exists():
return None
# Get all sessions, most recent first
session_dirs = sorted(
[d for d in sessions_dir.iterdir() if d.is_dir()],
key=lambda d: d.name,
reverse=True,
)
# Find first paused, failed, or cancelled session
import json
for session_dir in session_dirs:
state_file = session_dir / "state.json"
if not state_file.exists():
continue
with open(state_file) as f:
state = json.load(f)
status = state.get("status", "").lower()
# Check if resumable (any non-completed status)
if status in ["paused", "failed", "cancelled", "active"]:
return session_dir.name
return None
except Exception:
return None
async def _list_sessions(self, storage_path: Path) -> None:
"""List all sessions for the agent."""
self._write_history("[bold cyan]Available Sessions:[/bold cyan]")
# Find all session directories
sessions_dir = storage_path / "sessions"
if not sessions_dir.exists():
self._write_history("[dim]No sessions found.[/dim]")
self._write_history(" Sessions will appear here after running the agent")
return
session_dirs = sorted(
[d for d in sessions_dir.iterdir() if d.is_dir()],
key=lambda d: d.name,
reverse=True, # Most recent first
)
if not session_dirs:
self._write_history("[dim]No sessions found.[/dim]")
return
self._write_history(f"[dim]Found {len(session_dirs)} session(s)[/dim]\n")
for session_dir in session_dirs[:10]: # Show last 10 sessions
session_id = session_dir.name
state_file = session_dir / "state.json"
if not state_file.exists():
continue
# Read session state
try:
import json
with open(state_file) as f:
state = json.load(f)
status = state.get("status", "unknown").upper()
# Status with color
if status == "COMPLETED":
status_colored = f"[green]{status}[/green]"
elif status == "FAILED":
status_colored = f"[red]{status}[/red]"
elif status == "PAUSED":
status_colored = f"[yellow]{status}[/yellow]"
elif status == "CANCELLED":
status_colored = f"[dim yellow]{status}[/dim yellow]"
else:
status_colored = f"[dim]{status}[/dim]"
# Check for checkpoints
checkpoint_dir = session_dir / "checkpoints"
checkpoint_count = 0
if checkpoint_dir.exists():
checkpoint_files = list(checkpoint_dir.glob("cp_*.json"))
checkpoint_count = len(checkpoint_files)
# Session line
self._write_history(f"📋 [bold]{session_id}[/bold]")
self._write_history(f" Status: {status_colored} Checkpoints: {checkpoint_count}")
if checkpoint_count > 0:
self._write_history(f" [dim]Resume: /resume {session_id}[/dim]")
self._write_history("") # Blank line
except Exception as e:
self._write_history(f" [dim red]Error reading: {e}[/dim red]")
async def _show_session_details(self, storage_path: Path, session_id: str) -> None:
"""Show detailed information about a specific session."""
self._write_history(f"[bold cyan]Session Details:[/bold cyan] {session_id}\n")
session_dir = storage_path / "sessions" / session_id
if not session_dir.exists():
self._write_history("[bold red]Error:[/bold red] Session not found")
self._write_history(f" Path: {session_dir}")
self._write_history(" Tip: Use [bold]/sessions[/bold] to see available sessions")
return
state_file = session_dir / "state.json"
if not state_file.exists():
self._write_history("[bold red]Error:[/bold red] Session state not found")
return
try:
import json
with open(state_file) as f:
state = json.load(f)
# Basic info
status = state.get("status", "unknown").upper()
if status == "COMPLETED":
status_colored = f"[green]{status}[/green]"
elif status == "FAILED":
status_colored = f"[red]{status}[/red]"
elif status == "PAUSED":
status_colored = f"[yellow]{status}[/yellow]"
elif status == "CANCELLED":
status_colored = f"[dim yellow]{status}[/dim yellow]"
else:
status_colored = status
self._write_history(f"Status: {status_colored}")
if "started_at" in state:
self._write_history(f"Started: {state['started_at']}")
if "completed_at" in state:
self._write_history(f"Completed: {state['completed_at']}")
# Execution path
if "execution_path" in state and state["execution_path"]:
self._write_history("\n[bold]Execution Path:[/bold]")
for node_id in state["execution_path"]:
self._write_history(f"{node_id}")
# Checkpoints
checkpoint_dir = session_dir / "checkpoints"
if checkpoint_dir.exists():
checkpoint_files = sorted(checkpoint_dir.glob("cp_*.json"))
if checkpoint_files:
self._write_history(
f"\n[bold]Available Checkpoints:[/bold] ({len(checkpoint_files)})"
)
# Load and show checkpoints
for i, cp_file in enumerate(checkpoint_files[-5:], 1): # Last 5
try:
with open(cp_file) as f:
cp_data = json.load(f)
cp_id = cp_data.get("checkpoint_id", cp_file.stem)
cp_type = cp_data.get("checkpoint_type", "unknown")
current_node = cp_data.get("current_node", "unknown")
is_clean = cp_data.get("is_clean", False)
clean_marker = "" if is_clean else ""
self._write_history(f" {i}. {clean_marker} [cyan]{cp_id}[/cyan]")
self._write_history(f" Type: {cp_type}, Node: {current_node}")
except Exception:
pass
# Quick actions
if checkpoint_dir.exists() and list(checkpoint_dir.glob("cp_*.json")):
self._write_history("\n[bold]Quick Actions:[/bold]")
self._write_history(
f" [dim]/resume {session_id}[/dim] - Resume from latest checkpoint"
)
except Exception as e:
self._write_history(f"[bold red]Error:[/bold red] {e}")
import traceback
self._write_history(f"[dim]{traceback.format_exc()}[/dim]")
async def _cmd_resume(self, session_id: str) -> None:
"""Resume a session from its last state (session state, not checkpoint)."""
try:
storage_path = self.runtime._storage.base_path
session_dir = storage_path / "sessions" / session_id
# Verify session exists
if not session_dir.exists():
self._write_history(f"[bold red]Error:[/bold red] Session not found: {session_id}")
self._write_history(" Use [bold]/sessions[/bold] to see available sessions")
return
# Load session state
state_file = session_dir / "state.json"
if not state_file.exists():
self._write_history("[bold red]Error:[/bold red] Session state not found")
return
import json
with open(state_file) as f:
state = json.load(f)
# Resume from session state (not checkpoint)
progress = state.get("progress", {})
paused_at = progress.get("paused_at") or progress.get("resume_from")
if paused_at:
# Has paused_at - resume from there
resume_session_state = {
"paused_at": paused_at,
"memory": state.get("memory", {}),
"execution_path": progress.get("path", []),
"node_visit_counts": progress.get("node_visit_counts", {}),
}
resume_info = f"From node: [cyan]{paused_at}[/cyan]"
else:
# No paused_at - just retry with same input
resume_session_state = {}
resume_info = "Retrying with same input"
# Display resume info
self._write_history(f"[bold cyan]🔄 Resuming session[/bold cyan] {session_id}")
self._write_history(f" {resume_info}")
if paused_at:
self._write_history(" [dim](Using session state, not checkpoint)[/dim]")
# Check if already executing
if self._current_exec_id is not None:
self._write_history(
"[bold yellow]Warning:[/bold yellow] An execution is already running"
)
self._write_history(" Wait for it to complete or use /pause first")
return
# Get original input data from session state
input_data = state.get("input_data", {})
# Show indicator
indicator = self.query_one("#processing-indicator", Label)
indicator.update("Resuming from session state...")
indicator.display = True
# Update placeholder
chat_input = self.query_one("#chat-input", Input)
chat_input.placeholder = "Commands: /pause, /sessions (agent resuming...)"
# Trigger execution with resume state
try:
entry_points = self.runtime.get_entry_points()
if not entry_points:
self._write_history("[bold red]Error:[/bold red] No entry points available")
return
# Submit execution with resume state and original input data
future = asyncio.run_coroutine_threadsafe(
self.runtime.trigger(
entry_points[0].id,
input_data=input_data,
session_state=resume_session_state,
),
self._agent_loop,
)
exec_id = await asyncio.wrap_future(future)
self._current_exec_id = exec_id
self._write_history(
f"[green]✓[/green] Resume started (execution: {exec_id[:12]}...)"
)
self._write_history(" Agent is continuing from where it stopped...")
except Exception as e:
self._write_history(f"[bold red]Error starting resume:[/bold red] {e}")
indicator.display = False
chat_input.placeholder = "Enter input for agent..."
except Exception as e:
self._write_history(f"[bold red]Error:[/bold red] {e}")
import traceback
self._write_history(f"[dim]{traceback.format_exc()}[/dim]")
async def _cmd_recover(self, session_id: str, checkpoint_id: str) -> None:
"""Recover a session from a specific checkpoint (time-travel debugging)."""
try:
storage_path = self.runtime._storage.base_path
session_dir = storage_path / "sessions" / session_id
# Verify session exists
if not session_dir.exists():
self._write_history(f"[bold red]Error:[/bold red] Session not found: {session_id}")
self._write_history(" Use [bold]/sessions[/bold] to see available sessions")
return
# Verify checkpoint exists
checkpoint_file = session_dir / "checkpoints" / f"{checkpoint_id}.json"
if not checkpoint_file.exists():
self._write_history(
f"[bold red]Error:[/bold red] Checkpoint not found: {checkpoint_id}"
)
self._write_history(
f" Use [bold]/sessions {session_id}[/bold] to see available checkpoints"
)
return
# Display recover info
self._write_history(f"[bold cyan]⏪ Recovering session[/bold cyan] {session_id}")
self._write_history(f" From checkpoint: [cyan]{checkpoint_id}[/cyan]")
self._write_history(
" [dim](Checkpoint-based recovery for time-travel debugging)[/dim]"
)
# Check if already executing
if self._current_exec_id is not None:
self._write_history(
"[bold yellow]Warning:[/bold yellow] An execution is already running"
)
self._write_history(" Wait for it to complete or use /pause first")
return
# Create session_state for checkpoint recovery
recover_session_state = {
"resume_from_checkpoint": checkpoint_id,
}
# Show indicator
indicator = self.query_one("#processing-indicator", Label)
indicator.update("Recovering from checkpoint...")
indicator.display = True
# Update placeholder
chat_input = self.query_one("#chat-input", Input)
chat_input.placeholder = "Commands: /pause, /sessions (agent recovering...)"
# Trigger execution with checkpoint recovery
try:
entry_points = self.runtime.get_entry_points()
if not entry_points:
self._write_history("[bold red]Error:[/bold red] No entry points available")
return
# Submit execution with checkpoint recovery state
future = asyncio.run_coroutine_threadsafe(
self.runtime.trigger(
entry_points[0].id,
input_data={},
session_state=recover_session_state,
),
self._agent_loop,
)
exec_id = await asyncio.wrap_future(future)
self._current_exec_id = exec_id
self._write_history(
f"[green]✓[/green] Recovery started (execution: {exec_id[:12]}...)"
)
self._write_history(" Agent is continuing from checkpoint...")
except Exception as e:
self._write_history(f"[bold red]Error starting recovery:[/bold red] {e}")
indicator.display = False
chat_input.placeholder = "Enter input for agent..."
except Exception as e:
self._write_history(f"[bold red]Error:[/bold red] {e}")
import traceback
self._write_history(f"[dim]{traceback.format_exc()}[/dim]")
async def _cmd_pause(self) -> None:
"""Immediately pause execution by cancelling task (same as Ctrl+Z)."""
# Check if there's a current execution
if not self._current_exec_id:
self._write_history("[bold yellow]No active execution to pause[/bold yellow]")
self._write_history(" Start an execution first, then use /pause during execution")
return
# Find and cancel the execution task - executor will catch and save state
task_cancelled = False
for stream in self.runtime._streams.values():
exec_id = self._current_exec_id
task = stream._execution_tasks.get(exec_id)
if task and not task.done():
task.cancel()
task_cancelled = True
self._write_history("[bold green]⏸ Execution paused - state saved[/bold green]")
self._write_history(" Resume later with: [bold]/resume[/bold]")
break
if not task_cancelled:
self._write_history("[bold yellow]Execution already completed[/bold yellow]")
def on_mount(self) -> None:
"""Add welcome message when widget mounts."""
"""Add welcome message and check for resumable sessions."""
history = self.query_one("#chat-history", RichLog)
history.write("[bold cyan]Chat REPL Ready[/bold cyan] — Type your input below\n")
history.write(
"[bold cyan]Chat REPL Ready[/bold cyan] — "
"Type your input or use [bold]/help[/bold] for commands\n"
)
# Auto-trigger resume/recover if CLI args provided
if self._resume_session:
if self._resume_checkpoint:
# Use /recover for checkpoint-based recovery
history.write(
"\n[bold cyan]🔄 Auto-recovering from checkpoint "
"(--resume-session + --checkpoint)[/bold cyan]"
)
self.call_later(self._cmd_recover, self._resume_session, self._resume_checkpoint)
else:
# Use /resume for session state resume
history.write(
"\n[bold cyan]🔄 Auto-resuming session (--resume-session)[/bold cyan]"
)
self.call_later(self._cmd_resume, self._resume_session)
return # Skip normal startup messages
# Check for resumable sessions
self._check_and_show_resumable_sessions()
history.write(
"[dim]Quick start: /sessions to see previous sessions, "
"/pause to pause execution[/dim]\n"
)
def _check_and_show_resumable_sessions(self) -> None:
"""Check for non-terminated sessions and prompt user."""
try:
storage_path = self.runtime._storage.base_path
sessions_dir = storage_path / "sessions"
if not sessions_dir.exists():
return
# Find non-terminated sessions (paused, failed, cancelled, active)
resumable = []
session_dirs = sorted(
[d for d in sessions_dir.iterdir() if d.is_dir()],
key=lambda d: d.name,
reverse=True, # Most recent first
)
import json
for session_dir in session_dirs[:5]: # Check last 5 sessions
state_file = session_dir / "state.json"
if not state_file.exists():
continue
try:
with open(state_file) as f:
state = json.load(f)
status = state.get("status", "").lower()
# Non-terminated statuses
if status in ["paused", "failed", "cancelled", "active"]:
resumable.append(
{
"session_id": session_dir.name,
"status": status.upper(),
}
)
except Exception:
continue
if resumable:
self._write_history("\n[bold yellow]⚠ Non-terminated sessions found:[/bold yellow]")
for i, session in enumerate(resumable[:3], 1): # Show top 3
status = session["status"]
session_id = session["session_id"]
# Color code status
if status == "PAUSED":
status_colored = f"[yellow]{status}[/yellow]"
elif status == "FAILED":
status_colored = f"[red]{status}[/red]"
elif status == "CANCELLED":
status_colored = f"[dim yellow]{status}[/dim yellow]"
else:
status_colored = f"[dim]{status}[/dim]"
self._write_history(f" {i}. {session_id[:32]}... [{status_colored}]")
self._write_history("\n[bold cyan]What would you like to do?[/bold cyan]")
self._write_history(" • Type [bold]/resume[/bold] to continue the latest session")
self._write_history(
f" • Type [bold]/resume {resumable[0]['session_id']}[/bold] "
"for specific session"
)
self._write_history(" • Or just type your input to start a new session\n")
except Exception:
# Silently fail - don't block TUI startup
pass
async def on_input_submitted(self, message: Input.Submitted) -> None:
"""Handle input submission — either start new execution or inject input."""
@@ -132,15 +719,21 @@ class ChatRepl(Vertical):
if not user_input:
return
# Handle commands (starting with /) - ALWAYS process commands first
# Commands work during execution, during client-facing input, anytime
if user_input.startswith("/"):
await self._handle_command(user_input)
message.input.value = ""
return
# Client-facing input: route to the waiting node
if self._waiting_for_input and self._input_node_id:
self._write_history(f"[bold green]You:[/bold green] {user_input}")
message.input.value = ""
# Disable input while agent processes the response
# Keep input enabled for commands (but change placeholder)
chat_input = self.query_one("#chat-input", Input)
chat_input.disabled = True
chat_input.placeholder = "Enter input for agent..."
chat_input.placeholder = "Commands: /pause, /sessions (agent processing...)"
self._waiting_for_input = False
indicator = self.query_one("#processing-indicator", Label)
@@ -193,9 +786,9 @@ class ChatRepl(Vertical):
indicator.update("Thinking...")
indicator.display = True
# Disable input while the agent is working
# Keep input enabled for commands during execution
chat_input = self.query_one("#chat-input", Input)
chat_input.disabled = True
chat_input.placeholder = "Commands available: /pause, /sessions, /help"
# Submit execution to the dedicated agent loop so blocking
# runtime code (LLM, MCP tools) never touches Textual's loop.