Merge upstream/main: resolve conflicts with Apollo integration
- Keep both APOLLO_CREDENTIALS and AIRTABLE_CREDENTIALS
- Keep both apollo_tool and airtable_tool imports (alphabetical)

Co-authored-by: Cursor <cursoragent@cursor.com>
@@ -495,11 +495,96 @@ max_node_visits=3 # Prevent getting stuck
- Confirm it calls set_output eventually
```

#### Template 6: Checkpoint Recovery (Post-Fix Resume)

```markdown
## Recovery Strategy: Resume from Last Clean Checkpoint

**Situation:** You've fixed the issue, but the failed session is stuck mid-execution.

**Solution:** Resume execution from a checkpoint before the failure.

### Option A: Auto-Resume from Latest Checkpoint (Recommended)

Use CLI arguments to auto-resume when launching the TUI:

```bash
PYTHONPATH=core:exports python -m {agent_name} --tui \
  --resume-session {session_id}
```

This will:
- Load session state from `state.json`
- Continue from where it paused/failed
- Apply your fixes immediately

### Option B: Resume from Specific Checkpoint (Time-Travel)

If you need to go back to an earlier point:

```bash
PYTHONPATH=core:exports python -m {agent_name} --tui \
  --resume-session {session_id} \
  --checkpoint {checkpoint_id}
```

Example:

```bash
PYTHONPATH=core:exports python -m deep_research_agent --tui \
  --resume-session session_20260208_143022_abc12345 \
  --checkpoint cp_node_complete_intake_143030
```

### Option C: Use TUI Commands

Alternatively, launch the TUI normally and use commands:

```bash
# Launch TUI
PYTHONPATH=core:exports python -m {agent_name} --tui

# In TUI, use commands:
/resume {session_id}                   # Resume from session state
/recover {session_id} {checkpoint_id}  # Recover from specific checkpoint
```

### When to Use Each Option:

**Use `/resume` (or `--resume-session`) when:**
- You fixed credentials and want to retry
- The agent paused and you want to continue
- The agent failed and you want to retry from the last state

**Use `/recover` (or `--resume-session` + `--checkpoint`) when:**
- You need to go back to an earlier checkpoint
- You want to try a different path from a specific point
- Debugging requires time-travel to an earlier state

### Find Available Checkpoints:

```bash
# In TUI:
/sessions {session_id}

# This shows all checkpoints with timestamps:
Available Checkpoints: (3)
  1. cp_node_complete_intake_143030
  2. cp_node_complete_research_143115
  3. cp_pause_research_143130
```

**Verification:**
- Use `--resume-session` to test your fix immediately
- No need to re-run from the beginning
- Session continues with your code changes applied
```

**Selecting the right template:**
- Match the issue category from Stage 4
- Customize with specific details from Stage 5
- Include actual error messages and code snippets
- Provide file paths and line numbers when possible
- **Always include recovery commands** (Template 6) after providing fix recommendations

---

@@ -197,16 +197,18 @@ exports/agent_name/

### What This Phase Does

Creates comprehensive test suite:
- Constraint tests (verify hard requirements)
- Success criteria tests (measure goal achievement)
- Edge case tests (handle failures gracefully)
- Integration tests (end-to-end workflows)

### What This Phase Does

Guides the creation and execution of a comprehensive test suite:
- Constraint tests
- Success criteria tests
- Edge case tests
- Integration tests

### Process

1. **Analyze agent** - Read goal, constraints, success criteria
2. **Generate tests** - Create pytest files in `exports/agent_name/tests/`
2. **Generate tests** - The calling agent writes pytest files in `exports/agent_name/tests/` using hive-test guidelines and templates
3. **User approval** - Review and approve each test
4. **Run evaluation** - Execute tests and collect results
5. **Debug failures** - Identify and fix issues

+36
-12
@@ -8,26 +8,49 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Initial project structure
- React frontend (honeycomb) with Vite and TypeScript
- Node.js backend (hive) with Express and TypeScript
- Docker Compose configuration for local development
- Configuration system via `config.yaml`
- GitHub Actions CI/CD workflows
- Comprehensive documentation

### Changed
- N/A

### Fixed

### Security

## [0.4.2] - 2026-02-08

### Added
- Resumable sessions: agents now automatically save state and can resume after interruptions
- `/resume` command in TUI to resume latest paused/failed session
- `/resume <session_id>` command to resume specific sessions
- `/sessions` command to list all sessions for current agent
- `--resume-session` CLI flag for automatic session resumption on startup
- `--checkpoint <checkpoint_id>` CLI flag for checkpoint-based recovery
- Ctrl+Z now immediately pauses execution with full state capture
- `/pause` command for immediate pause during execution
- Session state persistence: memory, execution path, node positions, visit counts
- Unified session storage at `~/.hive/agents/{agent_name}/sessions/`
- Automatic memory restoration on resume with full conversation history

### Changed
- TUI quit now pauses execution and saves state instead of cancelling
- Pause operations now use immediate task cancellation instead of waiting for node boundaries
- Session cleanup timeout increased from 0.5s to 5s to ensure proper state saving
- Session status now tracked as: active, paused, completed, failed, cancelled

### Deprecated
- N/A
- Pause nodes (use client-facing EventLoopNodes instead)
- `request_pause()` method (replaced with immediate task cancellation)

### Removed
- N/A

### Fixed
- tools: Fixed web_scrape tool attempting to parse non-HTML content (PDF, JSON) as HTML (#487)
- Memory persistence: ExecutionResult.session_state["memory"] now populated at all exit points
- Resume now starts at correct paused_at node instead of intake node
- Visit count double-counting on resume (paused node count now properly adjusted)
- Session selection now picks most recent session instead of oldest
- Quit state save failures due to insufficient timeout
- Ctrl+Z pause implementation (was only showing notification without pausing)
- Empty memory on resume by ensuring session_state["memory"] is properly populated

### Security
- N/A

@@ -37,5 +60,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Initial release

[Unreleased]: https://github.com/adenhq/hive/compare/v0.1.0...HEAD
[Unreleased]: https://github.com/adenhq/hive/compare/v0.4.2...HEAD
[0.4.2]: https://github.com/adenhq/hive/compare/v0.4.0...v0.4.2
[0.1.0]: https://github.com/adenhq/hive/releases/tag/v0.1.0

+1
-1
@@ -1,6 +1,6 @@

# Contributing to Aden Agent Framework

Thank you for your interest in contributing to the Aden Agent Framework! This document provides guidelines and information for contributors. We’re especially looking for help building tools, integrations([check #2805](https://github.com/adenhq/hive/issues/2805)), and example agents for the framework. If you’re interested in extending its functionality, this is the perfect place to start.
Thank you for your interest in contributing to the Aden Agent Framework! This document provides guidelines and information for contributors. We’re especially looking for help building tools, integrations ([check #2805](https://github.com/adenhq/hive/issues/2805)), and example agents for the framework. If you’re interested in extending its functionality, this is the perfect place to start.

## Code of Conduct

@@ -75,6 +75,7 @@ Use Hive when you need:

- **[Changelog](https://github.com/adenhq/hive/releases)** - Latest updates and releases
- **[Roadmap](docs/roadmap.md)** - Upcoming features and plans
- **[Report Issues](https://github.com/adenhq/hive/issues)** - Bug reports and feature requests
- **[Contributing](CONTRIBUTING.md)** - How to contribute and submit PRs

## Quick Start

@@ -143,19 +143,34 @@ class AdenCredentialResponse:

    def from_dict(
        cls, data: dict[str, Any], integration_id: str | None = None
    ) -> AdenCredentialResponse:
        """Create from API response dictionary."""
        """Create from API response dictionary or normalized credential dict."""

        expires_at = None
        if data.get("expires_at"):
            expires_at = datetime.fromisoformat(data["expires_at"].replace("Z", "+00:00"))

        resolved_integration_id = (
            integration_id
            or data.get("integration_id")
            or data.get("alias")
            or data.get("provider", "")
        )

        resolved_integration_type = data.get("integration_type") or data.get("provider", "")
        metadata = data.get("metadata")
        if metadata is None and data.get("email"):
            metadata = {"email": data.get("email")}
        if metadata is None:
            metadata = {}

        return cls(
            integration_id=integration_id or data.get("alias", data.get("provider", "")),
            integration_type=data.get("provider", ""),
            integration_id=resolved_integration_id,
            integration_type=resolved_integration_type,
            access_token=data["access_token"],
            token_type=data.get("token_type", "Bearer"),
            expires_at=expires_at,
            scopes=data.get("scopes", []),
            metadata={"email": data.get("email")} if data.get("email") else {},
            metadata=metadata,
        )
@@ -0,0 +1,85 @@

"""
Checkpoint Configuration - Controls checkpoint behavior during execution.
"""

from dataclasses import dataclass


@dataclass
class CheckpointConfig:
    """
    Configuration for checkpoint behavior during graph execution.

    Controls when checkpoints are created, how they're stored,
    and when they're pruned.
    """

    # Enable/disable checkpointing
    enabled: bool = True

    # When to checkpoint
    checkpoint_on_node_start: bool = True
    checkpoint_on_node_complete: bool = True

    # Pruning (time-based)
    checkpoint_max_age_days: int = 7  # Prune checkpoints older than 1 week
    prune_every_n_nodes: int = 10  # Check for pruning every N nodes

    # Performance
    async_checkpoint: bool = True  # Don't block execution on checkpoint writes

    # What to include in checkpoints
    include_full_memory: bool = True
    include_metrics: bool = True

    def should_checkpoint_node_start(self) -> bool:
        """Check if should checkpoint before node execution."""
        return self.enabled and self.checkpoint_on_node_start

    def should_checkpoint_node_complete(self) -> bool:
        """Check if should checkpoint after node execution."""
        return self.enabled and self.checkpoint_on_node_complete

    def should_prune_checkpoints(self, nodes_executed: int) -> bool:
        """
        Check if should prune checkpoints based on execution progress.

        Args:
            nodes_executed: Number of nodes executed so far

        Returns:
            True if should check for old checkpoints and prune them
        """
        return (
            self.enabled
            and self.prune_every_n_nodes > 0
            and nodes_executed % self.prune_every_n_nodes == 0
        )


# Default configuration for most agents
DEFAULT_CHECKPOINT_CONFIG = CheckpointConfig(
    enabled=True,
    checkpoint_on_node_start=True,
    checkpoint_on_node_complete=True,
    checkpoint_max_age_days=7,
    prune_every_n_nodes=10,
    async_checkpoint=True,
)


# Minimal configuration (only checkpoint at node completion)
MINIMAL_CHECKPOINT_CONFIG = CheckpointConfig(
    enabled=True,
    checkpoint_on_node_start=False,
    checkpoint_on_node_complete=True,
    checkpoint_max_age_days=7,
    prune_every_n_nodes=20,
    async_checkpoint=True,
)


# Disabled configuration (no checkpointing)
DISABLED_CHECKPOINT_CONFIG = CheckpointConfig(
    enabled=False,
)
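
To make the pruning cadence concrete, here is a trimmed re-statement of the class (only the fields the cadence needs; the full version above also carries `include_full_memory` and `include_metrics`). With `prune_every_n_nodes=10`, pruning is checked exactly when the node count hits a multiple of 10:

```python
from dataclasses import dataclass


@dataclass
class CheckpointConfig:
    """Trimmed sketch: just the fields that drive pruning."""
    enabled: bool = True
    prune_every_n_nodes: int = 10

    def should_prune_checkpoints(self, nodes_executed: int) -> bool:
        # Prune check fires on every N-th node; 0 disables the cadence.
        return (
            self.enabled
            and self.prune_every_n_nodes > 0
            and nodes_executed % self.prune_every_n_nodes == 0
        )


cfg = CheckpointConfig(prune_every_n_nodes=10)
prune_points = [n for n in range(1, 31) if cfg.should_prune_checkpoints(n)]
print(prune_points)  # [10, 20, 30]
```

Setting `prune_every_n_nodes=0` (or `enabled=False`) turns the cadence off entirely, which is why the guard checks `> 0` before taking the modulo.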
@@ -1763,7 +1763,19 @@ class EventLoopNode(NodeProtocol):

        conversation: NodeConversation,
        iteration: int,
    ) -> bool:
        """Check if pause has been requested. Returns True if paused."""
        """
        Check if pause has been requested. Returns True if paused.

        Note: This check happens BEFORE starting iteration N, after completing N-1.
        If paused, the node exits having completed {iteration} iterations (0 to iteration-1).
        """
        # Check executor-level pause event (for /pause command, Ctrl+Z)
        if ctx.pause_event and ctx.pause_event.is_set():
            completed = iteration  # 0-indexed: iteration=3 means 3 iterations completed (0,1,2)
            logger.info(f"⏸ Pausing after {completed} iteration(s) completed (executor-level)")
            return True

        # Check context-level pause flags (legacy/alternative methods)
        pause_requested = ctx.input_data.get("pause_requested", False)
        if not pause_requested:
            try:

@@ -1771,8 +1783,10 @@ class EventLoopNode(NodeProtocol):

            except (PermissionError, KeyError):
                pause_requested = False
        if pause_requested:
            logger.info(f"Pause requested at iteration {iteration}")
            completed = iteration
            logger.info(f"⏸ Pausing after {completed} iteration(s) completed (context-level)")
            return True

        return False

    # -------------------------------------------------------------------
@@ -17,6 +17,7 @@ from dataclasses import dataclass, field

from pathlib import Path
from typing import Any

from framework.graph.checkpoint_config import CheckpointConfig
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.goal import Goal
from framework.graph.node import (

@@ -33,6 +34,8 @@ from framework.graph.output_cleaner import CleansingConfig, OutputCleaner

from framework.graph.validator import OutputValidator
from framework.llm.provider import LLMProvider, Tool
from framework.runtime.core import Runtime
from framework.schemas.checkpoint import Checkpoint
from framework.storage.checkpoint_store import CheckpointStore


@dataclass

@@ -179,6 +182,9 @@ class GraphExecutor:

        self.enable_parallel_execution = enable_parallel_execution
        self._parallel_config = parallel_config or ParallelExecutionConfig()

        # Pause/resume control
        self._pause_requested = asyncio.Event()

    def _validate_tools(self, graph: GraphSpec) -> list[str]:
        """
        Validate that all tools declared by nodes are available.

@@ -208,6 +214,7 @@ class GraphExecutor:

        goal: Goal,
        input_data: dict[str, Any] | None = None,
        session_state: dict[str, Any] | None = None,
        checkpoint_config: "CheckpointConfig | None" = None,
    ) -> ExecutionResult:
        """
        Execute a graph for a goal.

@@ -246,6 +253,12 @@

        # Initialize execution state
        memory = SharedMemory()

        # Initialize checkpoint store if checkpointing is enabled
        checkpoint_store: CheckpointStore | None = None
        if checkpoint_config and checkpoint_config.enabled and self._storage_path:
            checkpoint_store = CheckpointStore(self._storage_path)
            self.logger.info("✓ Checkpointing enabled")

        # Restore session state if provided
        if session_state and "memory" in session_state:
            memory_data = session_state["memory"]

@@ -273,8 +286,110 @@

        node_visit_counts: dict[str, int] = {}  # Track visits for feedback loops
        _is_retry = False  # True when looping back for a retry (not a new visit)

        # Restore node_visit_counts from session state if available
        if session_state and "node_visit_counts" in session_state:
            node_visit_counts = dict(session_state["node_visit_counts"])
            if node_visit_counts:
                self.logger.info(f"📥 Restored node visit counts: {node_visit_counts}")

            # If resuming at a specific node (paused_at), that node was counted
            # but never completed, so decrement its count
            paused_at = session_state.get("paused_at")
            if (
                paused_at
                and paused_at in node_visit_counts
                and node_visit_counts[paused_at] > 0
            ):
                old_count = node_visit_counts[paused_at]
                node_visit_counts[paused_at] -= 1
                self.logger.info(
                    f"📥 Decremented visit count for paused node '{paused_at}': "
                    f"{old_count} -> {node_visit_counts[paused_at]}"
                )

        # Determine entry point (may differ if resuming)
        current_node_id = graph.get_entry_point(session_state)
        # Check if resuming from checkpoint
        if session_state and session_state.get("resume_from_checkpoint") and checkpoint_store:
            checkpoint_id = session_state["resume_from_checkpoint"]
            try:
                checkpoint = await checkpoint_store.load_checkpoint(checkpoint_id)

                if checkpoint:
                    self.logger.info(
                        f"🔄 Resuming from checkpoint: {checkpoint_id} "
                        f"(node: {checkpoint.current_node})"
                    )

                    # Restore memory from checkpoint
                    for key, value in checkpoint.shared_memory.items():
                        memory.write(key, value, validate=False)

                    # Start from checkpoint's next node or current node
                    current_node_id = (
                        checkpoint.next_node or checkpoint.current_node or graph.entry_node
                    )

                    # Restore execution path
                    path.extend(checkpoint.execution_path)

                    self.logger.info(
                        f"📥 Restored memory with {len(checkpoint.shared_memory)} keys, "
                        f"resuming at node: {current_node_id}"
                    )
                else:
                    self.logger.warning(
                        f"Checkpoint {checkpoint_id} not found, resuming from normal entry point"
                    )
                    # Check if resuming from paused_at (fallback to session state)
                    paused_at = session_state.get("paused_at") if session_state else None
                    if paused_at and graph.get_node(paused_at) is not None:
                        current_node_id = paused_at
                        self.logger.info(f"🔄 Resuming from paused node: {paused_at}")
                    else:
                        current_node_id = graph.get_entry_point(session_state)

            except Exception as e:
                self.logger.error(
                    f"Failed to load checkpoint {checkpoint_id}: {e}, "
                    f"resuming from normal entry point"
                )
                # Check if resuming from paused_at (fallback to session state)
                paused_at = session_state.get("paused_at") if session_state else None
                if paused_at and graph.get_node(paused_at) is not None:
                    current_node_id = paused_at
                    self.logger.info(f"🔄 Resuming from paused node: {paused_at}")
                else:
                    current_node_id = graph.get_entry_point(session_state)
        else:
            # Check if resuming from paused_at (session state resume)
            paused_at = session_state.get("paused_at") if session_state else None
            node_ids = [n.id for n in graph.nodes]
            self.logger.info(f"🔍 Debug: paused_at={paused_at}, available node IDs={node_ids}")

            if paused_at and graph.get_node(paused_at) is not None:
                # Resume from paused_at node directly (works for any node, not just pause_nodes)
                current_node_id = paused_at

                # Restore execution path from session state if available
                if session_state:
                    execution_path = session_state.get("execution_path", [])
                    if execution_path:
                        path.extend(execution_path)
                        self.logger.info(
                            f"🔄 Resuming from paused node: {paused_at} "
                            f"(restored path: {execution_path})"
                        )
                    else:
                        self.logger.info(f"🔄 Resuming from paused node: {paused_at}")
                else:
                    self.logger.info(f"🔄 Resuming from paused node: {paused_at}")
            else:
                # Fall back to normal entry point logic
                self.logger.warning(
                    f"⚠ paused_at={paused_at} is not a valid node, falling back to entry point"
                )
                current_node_id = graph.get_entry_point(session_state)

        steps = 0

        if session_state and current_node_id != graph.entry_node:
@@ -313,6 +428,45 @@ class GraphExecutor:

        while steps < graph.max_steps:
            steps += 1

            # Check for pause request
            if self._pause_requested.is_set():
                self.logger.info("⏸ Pause detected - stopping at node boundary")

                # Create session state for pause
                saved_memory = memory.read_all()
                pause_session_state: dict[str, Any] = {
                    "memory": saved_memory,  # Include memory for resume
                    "execution_path": list(path),
                    "node_visit_counts": dict(node_visit_counts),
                }

                # Create a pause checkpoint
                if checkpoint_store:
                    pause_checkpoint = self._create_checkpoint(
                        checkpoint_type="pause",
                        current_node=current_node_id,
                        execution_path=path,
                        memory=memory,
                        next_node=current_node_id,
                        is_clean=True,
                    )
                    await checkpoint_store.save_checkpoint(pause_checkpoint)
                    pause_session_state["latest_checkpoint_id"] = pause_checkpoint.checkpoint_id
                    pause_session_state["resume_from_checkpoint"] = (
                        pause_checkpoint.checkpoint_id
                    )

                # Return with paused status
                return ExecutionResult(
                    success=False,
                    output=saved_memory,
                    path=path,
                    paused_at=current_node_id,
                    error="Execution paused by user request",
                    session_state=pause_session_state,
                    node_visit_counts=dict(node_visit_counts),
                )

            # Get current node
            node_spec = graph.get_node(current_node_id)
            if node_spec is None:

@@ -391,6 +545,27 @@

                    description=f"Validation errors for {current_node_id}: {validation_errors}",
                )

            # CHECKPOINT: node_start
            if (
                checkpoint_store
                and checkpoint_config
                and checkpoint_config.should_checkpoint_node_start()
            ):
                checkpoint = self._create_checkpoint(
                    checkpoint_type="node_start",
                    current_node=node_spec.id,
                    execution_path=list(path),
                    memory=memory,
                    is_clean=(sum(node_retry_counts.values()) == 0),
                )

                if checkpoint_config.async_checkpoint:
                    # Non-blocking checkpoint save
                    asyncio.create_task(checkpoint_store.save_checkpoint(checkpoint))
                else:
                    # Blocking checkpoint save
                    await checkpoint_store.save_checkpoint(checkpoint)

            # Emit node-started event (skip event_loop nodes — they emit their own)
            if self._event_bus and node_spec.node_type != "event_loop":
                await self._event_bus.emit_node_loop_started(

@@ -464,6 +639,13 @@

                        if len(value_str) > 200:
                            value_str = value_str[:200] + "..."
                        self.logger.info(f"  {key}: {value_str}")

                # Write node outputs to memory BEFORE edge evaluation
                # This enables direct key access in conditional expressions (e.g., "score > 80")
                # Without this, conditional edges can only use output['key'] syntax
                if result.output:
                    for key, value in result.output.items():
                        memory.write(key, value, validate=False)
            else:
                self.logger.error(f"  ✗ Failed: {result.error}")

@@ -557,13 +739,21 @@

                        execution_quality="failed",
                    )

                # Save memory for potential resume
                saved_memory = memory.read_all()
                failure_session_state = {
                    "memory": saved_memory,
                    "execution_path": list(path),
                    "node_visit_counts": dict(node_visit_counts),
                }

                return ExecutionResult(
                    success=False,
                    error=(
                        f"Node '{node_spec.name}' failed after "
                        f"{max_retries} attempts: {result.error}"
                    ),
                    output=memory.read_all(),
                    output=saved_memory,
                    steps_executed=steps,
                    total_tokens=total_tokens,
                    total_latency_ms=total_latency,

@@ -574,6 +764,7 @@

                    had_partial_failures=len(nodes_failed) > 0,
                    execution_quality="failed",
                    node_visit_counts=dict(node_visit_counts),
                    session_state=failure_session_state,
                )

            # Check if we just executed a pause node - if so, save state and return

@@ -696,6 +887,39 @@

                    break
                next_spec = graph.get_node(next_node)
                self.logger.info(f"  → Next: {next_spec.name if next_spec else next_node}")

            # CHECKPOINT: node_complete (after determining next node)
            if (
                checkpoint_store
                and checkpoint_config
                and checkpoint_config.should_checkpoint_node_complete()
            ):
                checkpoint = self._create_checkpoint(
                    checkpoint_type="node_complete",
                    current_node=node_spec.id,
                    execution_path=list(path),
                    memory=memory,
                    next_node=next_node,
                    is_clean=(sum(node_retry_counts.values()) == 0),
                )

                if checkpoint_config.async_checkpoint:
                    asyncio.create_task(checkpoint_store.save_checkpoint(checkpoint))
                else:
                    await checkpoint_store.save_checkpoint(checkpoint)

            # Periodic checkpoint pruning
            if (
                checkpoint_store
                and checkpoint_config
                and checkpoint_config.should_prune_checkpoints(len(path))
            ):
                asyncio.create_task(
                    checkpoint_store.prune_checkpoints(
                        max_age_days=checkpoint_config.checkpoint_max_age_days
                    )
                )

            current_node_id = next_node

            # Update input_data for next node
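
The `async_checkpoint` branch above trades durability ordering for throughput: `asyncio.create_task` schedules the save and lets node execution continue, while the blocking branch finishes the write first. A toy demonstration (the save function and log are stand-ins, not the framework's `CheckpointStore` API):

```python
import asyncio


async def save_checkpoint(name: str, log: list[str]) -> None:
    await asyncio.sleep(0)  # stand-in for an actual disk write
    log.append(name)


async def demo(async_checkpoint: bool) -> list[str]:
    log: list[str] = []
    if async_checkpoint:
        # Fire-and-forget: node work proceeds before the save lands.
        task = asyncio.create_task(save_checkpoint("cp1", log))
        log.append("node-work")
        await task  # awaited here only so the demo is deterministic
    else:
        # Blocking: the save completes before node work continues.
        await save_checkpoint("cp1", log)
        log.append("node-work")
    return log


print(asyncio.run(demo(True)))   # ['node-work', 'cp1']
print(asyncio.run(demo(False)))  # ['cp1', 'node-work']
```

One consequence of the non-blocking path: if the process dies right after the node completes, the latest checkpoint write may not yet be on disk, which is the usual cost of `async_checkpoint=True`.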
@@ -753,6 +977,50 @@ class GraphExecutor:

                node_visit_counts=dict(node_visit_counts),
            )

        except asyncio.CancelledError:
            # Handle cancellation (e.g., TUI quit) - save as paused instead of failed
            self.logger.info("⏸ Execution cancelled - saving state for resume")

            # Save memory and state for resume
            saved_memory = memory.read_all()
            session_state_out: dict[str, Any] = {
                "memory": saved_memory,
                "execution_path": list(path),
                "node_visit_counts": dict(node_visit_counts),
            }

            # Calculate quality metrics
            total_retries_count = sum(node_retry_counts.values())
            nodes_failed = [nid for nid, count in node_retry_counts.items() if count > 0]
            exec_quality = "degraded" if total_retries_count > 0 else "clean"

            if self.runtime_logger:
                await self.runtime_logger.end_run(
                    status="paused",
                    duration_ms=total_latency,
                    node_path=path,
                    execution_quality=exec_quality,
                )

            # Return with paused status
            return ExecutionResult(
                success=False,
                error="Execution paused by user",
                output=saved_memory,
                steps_executed=steps,
                total_tokens=total_tokens,
                total_latency_ms=total_latency,
                path=path,
                paused_at=current_node_id,  # Save where we were
                session_state=session_state_out,
                total_retries=total_retries_count,
                nodes_with_failures=nodes_failed,
                retry_details=dict(node_retry_counts),
                had_partial_failures=len(nodes_failed) > 0,
                execution_quality=exec_quality,
                node_visit_counts=dict(node_visit_counts),
            )

        except Exception as e:
            import traceback
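
The `CancelledError` handler above is what turns a TUI quit into a resumable pause: cancellation is caught, state is captured, and a "paused" result is returned instead of letting the task fail. A minimal sketch of that pattern (hypothetical names; the real handler also saves checkpoints and metrics):

```python
import asyncio


async def run_with_pause_on_cancel() -> dict:
    """Catch cancellation and report the run as paused, keeping its state."""
    state: dict = {"memory": {}, "path": []}
    try:
        state["path"].append("intake")
        await asyncio.sleep(10)  # long-running node work that gets cancelled
        state["path"].append("report")
    except asyncio.CancelledError:
        # Swallow the cancellation and return a paused result with saved state.
        return {"status": "paused", "session_state": state}
    return {"status": "completed", "session_state": state}


async def main() -> dict:
    task = asyncio.create_task(run_with_pause_on_cancel())
    await asyncio.sleep(0)  # let the task start its first node
    task.cancel()           # simulate the TUI quitting
    return await task


print(asyncio.run(main())["status"])  # paused
```

Note that suppressing `CancelledError` like this is deliberate here (the caller wants a result, not an exception), but it means the coroutine must be the outermost handler; swallowing cancellation deeper in a task tree can interfere with normal shutdown.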
@@ -790,9 +1058,40 @@ class GraphExecutor:

                execution_quality="failed",
            )

            # Save memory and state for potential resume
            saved_memory = memory.read_all()
            session_state_out: dict[str, Any] = {
                "memory": saved_memory,
                "execution_path": list(path),
                "node_visit_counts": dict(node_visit_counts),
            }

            # Mark latest checkpoint for resume on failure
            if checkpoint_store:
                try:
                    checkpoints = await checkpoint_store.list_checkpoints()
                    if checkpoints:
                        # Find latest clean checkpoint
                        index = await checkpoint_store.load_index()
                        if index:
                            latest_clean = index.get_latest_clean_checkpoint()
                            if latest_clean:
                                session_state_out["resume_from_checkpoint"] = (
                                    latest_clean.checkpoint_id
                                )
                                session_state_out["latest_checkpoint_id"] = (
                                    latest_clean.checkpoint_id
                                )
                                self.logger.info(
                                    f"💾 Marked checkpoint for resume: {latest_clean.checkpoint_id}"
                                )
                except Exception as checkpoint_err:
                    self.logger.warning(f"Failed to mark checkpoint for resume: {checkpoint_err}")

            return ExecutionResult(
                success=False,
                error=str(e),
                output=saved_memory,
                steps_executed=steps,
                path=path,
                total_retries=total_retries_count,

@@ -801,6 +1100,7 @@

                had_partial_failures=len(nodes_failed) > 0,
                execution_quality="failed",
                node_visit_counts=dict(node_visit_counts),
                session_state=session_state_out,
            )

        finally:
@@ -841,6 +1141,7 @@ class GraphExecutor:

                goal=goal,  # Pass Goal object for LLM-powered routers
                max_tokens=max_tokens,
                runtime_logger=self.runtime_logger,
                pause_event=self._pause_requested,  # Pass pause event for granular control
            )

    # Valid node types - no ambiguous "llm" type allowed

@@ -1353,3 +1654,50 @@

    def register_function(self, node_id: str, func: Callable) -> None:
        """Register a function as a node."""
        self.node_registry[node_id] = FunctionNode(func)

    def request_pause(self) -> None:
        """
        Request graceful pause of the current execution.

        The execution will pause at the next node boundary after the current
        node completes. A checkpoint will be saved at the pause point, allowing
        the execution to be resumed later.

        This method is safe to call from any thread.
        """
        self._pause_requested.set()
        self.logger.info("⏸ Pause requested - will pause at next node boundary")

    def _create_checkpoint(
        self,
        checkpoint_type: str,
        current_node: str,
        execution_path: list[str],
        memory: SharedMemory,
        next_node: str | None = None,
        is_clean: bool = True,
    ) -> Checkpoint:
        """
        Create a checkpoint from current execution state.

        Args:
            checkpoint_type: Type of checkpoint (node_start, node_complete)
            current_node: Current node ID
            execution_path: Nodes executed so far
            memory: SharedMemory instance
            next_node: Next node to execute (for node_complete checkpoints)
            is_clean: Whether execution was clean up to this point

        Returns:
            New Checkpoint instance
        """

        return Checkpoint.create(
            checkpoint_type=checkpoint_type,
            session_id=self._storage_path.name if self._storage_path else "unknown",
            current_node=current_node,
            execution_path=execution_path,
            shared_memory=memory.read_all(),
            next_node=next_node,
            is_clean=is_clean,
        )
|
||||
|
||||
@@ -480,6 +480,9 @@ class NodeContext:
|
||||
# Runtime logging (optional)
|
||||
runtime_logger: Any = None # RuntimeLogger | None — uses Any to avoid import
|
||||
|
||||
# Pause control (optional) - asyncio.Event for pause requests
|
||||
pause_event: Any = None # asyncio.Event | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeResult:
|
||||
|
||||
@@ -63,6 +63,18 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
default=None,
|
||||
help="LLM model to use (any LiteLLM-compatible name)",
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--resume-session",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Resume from a specific session ID",
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Resume from a specific checkpoint (requires --resume-session)",
|
||||
)
|
||||
run_parser.set_defaults(func=cmd_run)
|
||||
|
||||
# info command
|
||||
@@ -196,6 +208,129 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
    )
    tui_parser.set_defaults(func=cmd_tui)

    # sessions command group (checkpoint/resume management)
    sessions_parser = subparsers.add_parser(
        "sessions",
        help="Manage agent sessions",
        description="List, inspect, and manage agent execution sessions.",
    )
    sessions_subparsers = sessions_parser.add_subparsers(
        dest="sessions_cmd",
        help="Session management commands",
    )

    # sessions list
    sessions_list_parser = sessions_subparsers.add_parser(
        "list",
        help="List agent sessions",
        description="List all sessions for an agent.",
    )
    sessions_list_parser.add_argument(
        "agent_path",
        type=str,
        help="Path to agent folder",
    )
    sessions_list_parser.add_argument(
        "--status",
        choices=["all", "active", "failed", "completed", "paused"],
        default="all",
        help="Filter by session status (default: all)",
    )
    sessions_list_parser.add_argument(
        "--has-checkpoints",
        action="store_true",
        help="Show only sessions with checkpoints",
    )
    sessions_list_parser.set_defaults(func=cmd_sessions_list)

    # sessions show
    sessions_show_parser = sessions_subparsers.add_parser(
        "show",
        help="Show session details",
        description="Display detailed information about a specific session.",
    )
    sessions_show_parser.add_argument(
        "agent_path",
        type=str,
        help="Path to agent folder",
    )
    sessions_show_parser.add_argument(
        "session_id",
        type=str,
        help="Session ID to inspect",
    )
    sessions_show_parser.add_argument(
        "--json",
        action="store_true",
        help="Output as JSON",
    )
    sessions_show_parser.set_defaults(func=cmd_sessions_show)

    # sessions checkpoints
    sessions_checkpoints_parser = sessions_subparsers.add_parser(
        "checkpoints",
        help="List session checkpoints",
        description="List all checkpoints for a session.",
    )
    sessions_checkpoints_parser.add_argument(
        "agent_path",
        type=str,
        help="Path to agent folder",
    )
    sessions_checkpoints_parser.add_argument(
        "session_id",
        type=str,
        help="Session ID",
    )
    sessions_checkpoints_parser.set_defaults(func=cmd_sessions_checkpoints)

    # pause command
    pause_parser = subparsers.add_parser(
        "pause",
        help="Pause running session",
        description="Request graceful pause of a running agent session.",
    )
    pause_parser.add_argument(
        "agent_path",
        type=str,
        help="Path to agent folder",
    )
    pause_parser.add_argument(
        "session_id",
        type=str,
        help="Session ID to pause",
    )
    pause_parser.set_defaults(func=cmd_pause)

    # resume command
    resume_parser = subparsers.add_parser(
        "resume",
        help="Resume session from checkpoint",
        description="Resume a paused or failed session from a checkpoint.",
    )
    resume_parser.add_argument(
        "agent_path",
        type=str,
        help="Path to agent folder",
    )
    resume_parser.add_argument(
        "session_id",
        type=str,
        help="Session ID to resume",
    )
    resume_parser.add_argument(
        "--checkpoint",
        "-c",
        type=str,
        help="Specific checkpoint ID to resume from (default: latest)",
    )
    resume_parser.add_argument(
        "--tui",
        action="store_true",
        help="Resume in TUI dashboard mode",
    )
    resume_parser.set_defaults(func=cmd_resume)


def cmd_run(args: argparse.Namespace) -> int:
    """Run an exported agent."""
@@ -253,7 +388,11 @@ def cmd_run(args: argparse.Namespace) -> int:
        if runner._agent_runtime and not runner._agent_runtime.is_running:
            await runner._agent_runtime.start()

        app = AdenTUI(runner._agent_runtime)
        app = AdenTUI(
            runner._agent_runtime,
            resume_session=getattr(args, "resume_session", None),
            resume_checkpoint=getattr(args, "checkpoint", None),
        )

        # TUI handles execution via ChatRepl — user submits input,
        # ChatRepl calls runtime.trigger_and_wait(). No auto-launch.
@@ -1432,3 +1571,53 @@ def _interactive_multi(agents_dir: Path) -> int:

    orchestrator.cleanup()
    return 0


def cmd_sessions_list(args: argparse.Namespace) -> int:
    """List agent sessions."""
    print("⚠ Sessions list command not yet implemented")
    print("This will be available once checkpoint infrastructure is complete.")
    print(f"\nAgent: {args.agent_path}")
    print(f"Status filter: {args.status}")
    print(f"Has checkpoints: {args.has_checkpoints}")
    return 1


def cmd_sessions_show(args: argparse.Namespace) -> int:
    """Show detailed session information."""
    print("⚠ Session show command not yet implemented")
    print("This will be available once checkpoint infrastructure is complete.")
    print(f"\nAgent: {args.agent_path}")
    print(f"Session: {args.session_id}")
    return 1


def cmd_sessions_checkpoints(args: argparse.Namespace) -> int:
    """List checkpoints for a session."""
    print("⚠ Session checkpoints command not yet implemented")
    print("This will be available once checkpoint infrastructure is complete.")
    print(f"\nAgent: {args.agent_path}")
    print(f"Session: {args.session_id}")
    return 1


def cmd_pause(args: argparse.Namespace) -> int:
    """Pause a running session."""
    print("⚠ Pause command not yet implemented")
    print("This will be available once executor pause integration is complete.")
    print(f"\nAgent: {args.agent_path}")
    print(f"Session: {args.session_id}")
    return 1


def cmd_resume(args: argparse.Namespace) -> int:
    """Resume a session from checkpoint."""
    print("⚠ Resume command not yet implemented")
    print("This will be available once checkpoint resume integration is complete.")
    print(f"\nAgent: {args.agent_path}")
    print(f"Session: {args.session_id}")
    if args.checkpoint:
        print(f"Checkpoint: {args.checkpoint}")
    if args.tui:
        print("Mode: TUI")
    return 1

@@ -741,6 +741,17 @@ class AgentRunner:
        # Create AgentRuntime with all entry points
        log_store = RuntimeLogStore(base_path=self._storage_path / "runtime_logs")

        # Enable checkpointing by default for resumable sessions
        from framework.graph.checkpoint_config import CheckpointConfig

        checkpoint_config = CheckpointConfig(
            enabled=True,
            checkpoint_on_node_start=False,  # Only checkpoint after nodes complete
            checkpoint_on_node_complete=True,
            checkpoint_max_age_days=7,
            async_checkpoint=True,  # Non-blocking
        )

        self._agent_runtime = create_agent_runtime(
            graph=self.graph,
            goal=self.goal,
@@ -750,6 +761,7 @@ class AgentRunner:
            tools=tools,
            tool_executor=tool_executor,
            runtime_log_store=log_store,
            checkpoint_config=checkpoint_config,
        )

    async def run(

@@ -0,0 +1,842 @@
# Resumable Sessions Design

## Problem Statement

Currently, when an agent encounters a failure during execution (e.g., credential validation, API errors, tool failures), the entire session is lost. This creates a poor user experience, especially when:

1. The agent has completed significant work before the failure
2. The failure is recoverable (e.g., adding missing credentials)
3. The user wants to retry from the exact failure point without redoing work

## Design Goals

1. **Crash Recovery**: Sessions can resume after process crashes or errors
2. **Partial Completion**: Preserve work done by nodes that completed successfully
3. **Flexible Resume Points**: Resume from exact failure point or previous checkpoints
4. **State Consistency**: Guarantee consistent SharedMemory and conversation state
5. **Minimal Overhead**: Checkpointing shouldn't significantly impact performance
6. **User Control**: Users can inspect, modify, and resume sessions explicitly

## Architecture

### 1. Checkpoint System

#### Checkpoint Types

**Automatic Checkpoints** (saved automatically by framework):
- `node_start`: Before each node begins execution
- `node_complete`: After each node successfully completes
- `edge_transition`: Before traversing to next node
- `loop_iteration`: At each iteration in EventLoopNode (optional)

**Manual Checkpoints** (triggered by agent designer):
- `safe_point`: Explicitly marked safe points in graph
- `user_checkpoint`: Before awaiting user input in client-facing nodes

#### Checkpoint Data Structure

```python
@dataclass
class Checkpoint:
    """Single checkpoint in execution timeline."""

    # Identity
    checkpoint_id: str  # Format: checkpoint_{timestamp}_{uuid_short}
    session_id: str
    checkpoint_type: str  # "node_start", "node_complete", etc.

    # Timestamps
    created_at: str  # ISO 8601

    # Execution state
    current_node: str | None
    next_node: str | None  # For edge_transition checkpoints
    execution_path: list[str]  # Nodes executed so far

    # Memory state (snapshot)
    shared_memory: dict[str, Any]  # Full SharedMemory._data

    # Per-node conversation state references
    # (actual conversations stored separately, reference by node_id)
    conversation_states: dict[str, str]  # {node_id: conversation_checkpoint_id}

    # Output accumulator state
    accumulated_outputs: dict[str, Any]

    # Execution metrics (for resuming quality tracking)
    metrics_snapshot: dict[str, Any]

    # Metadata
    is_clean: bool  # True if no failures/retries before this checkpoint
    can_resume_from: bool  # False if checkpoint is in unstable state
    description: str  # Human-readable checkpoint description
```

#### Storage Structure

```
~/.hive/agents/{agent_name}/
└── sessions/
    └── session_YYYYMMDD_HHMMSS_{uuid}/
        ├── state.json              # Session state (existing)
        ├── checkpoints/
        │   ├── index.json          # Checkpoint index/manifest
        │   ├── checkpoint_1.json   # Individual checkpoints
        │   ├── checkpoint_2.json
        │   └── checkpoint_N.json
        ├── conversations/          # Per-node conversation state (existing)
        │   ├── node_id_1/
        │   │   ├── parts/
        │   │   ├── meta.json
        │   │   └── cursor.json
        │   └── node_id_2/...
        ├── data/                   # Spillover artifacts (existing)
        └── logs/                   # L1/L2/L3 logs (existing)
```

**Checkpoint Index Format** (`checkpoints/index.json`):
```json
{
  "session_id": "session_20260208_143022_abc12345",
  "checkpoints": [
    {
      "checkpoint_id": "checkpoint_20260208_143030_xyz123",
      "type": "node_complete",
      "created_at": "2026-02-08T14:30:30.123Z",
      "current_node": "collector",
      "is_clean": true,
      "can_resume_from": true,
      "description": "Completed collector node successfully"
    },
    {
      "checkpoint_id": "checkpoint_20260208_143045_abc789",
      "type": "node_start",
      "created_at": "2026-02-08T14:30:45.456Z",
      "current_node": "analyzer",
      "is_clean": true,
      "can_resume_from": true,
      "description": "Starting analyzer node"
    }
  ],
  "latest_checkpoint_id": "checkpoint_20260208_143045_abc789",
  "total_checkpoints": 2
}
```

### 2. Resume Mechanism

#### Resume Flow

```python
# High-level resume flow
async def resume_session(
    session_id: str,
    checkpoint_id: str | None = None,  # None = resume from latest
    modifications: dict[str, Any] | None = None,  # Override memory values
) -> ExecutionResult:
    """
    Resume a session from a checkpoint.

    Args:
        session_id: Session to resume
        checkpoint_id: Specific checkpoint (None = latest)
        modifications: Optional memory/state modifications before resume

    Returns:
        ExecutionResult with resumed execution
    """
    # 1. Load session state
    session_state = await session_store.read_state(session_id)

    # 2. Verify session is resumable
    if not session_state.is_resumable:
        raise ValueError(f"Session {session_id} is not resumable")

    # 3. Load checkpoint
    checkpoint = await checkpoint_store.load_checkpoint(
        session_id,
        checkpoint_id or session_state.progress.resume_from
    )

    # 4. Restore state
    # - Restore SharedMemory from checkpoint.shared_memory
    # - Restore per-node conversations from checkpoint.conversation_states
    # - Restore output accumulator from checkpoint.accumulated_outputs
    # - Apply modifications if provided

    # 5. Resume execution from checkpoint.next_node or checkpoint.current_node
    result = await executor.execute(
        graph=graph,
        goal=goal,
        memory=restored_memory,
        entry_point=checkpoint.next_node or checkpoint.current_node,
        session_state=restored_session_state,
    )

    # 6. Update session state with resumed execution
    await session_store.write_state(session_id, updated_state)

    return result
```

#### Checkpoint Restoration

```python
@dataclass
class CheckpointStore:
    """Manages checkpoint storage and retrieval."""

    async def save_checkpoint(
        self,
        session_id: str,
        checkpoint: Checkpoint,
    ) -> None:
        """Save a checkpoint atomically."""
        # 1. Write checkpoint file: checkpoints/checkpoint_{id}.json
        # 2. Update index: checkpoints/index.json
        # 3. Use atomic write for crash safety

    async def load_checkpoint(
        self,
        session_id: str,
        checkpoint_id: str | None = None,
    ) -> Checkpoint | None:
        """Load a checkpoint by ID or latest."""
        # 1. Read checkpoint index
        # 2. Find checkpoint by ID (or latest if None)
        # 3. Load and deserialize checkpoint file

    async def list_checkpoints(
        self,
        session_id: str,
        checkpoint_type: str | None = None,
        is_clean: bool | None = None,
    ) -> list[Checkpoint]:
        """List all checkpoints for a session with optional filters."""

    async def delete_checkpoint(
        self,
        session_id: str,
        checkpoint_id: str,
    ) -> bool:
        """Delete a specific checkpoint."""

    async def prune_checkpoints(
        self,
        session_id: str,
        keep_count: int = 10,
        keep_clean_only: bool = False,
    ) -> int:
        """Prune old checkpoints, keeping most recent N."""
```
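
`save_checkpoint()` calls for atomic writes. A standard way to get crash safety on a local filesystem is write-to-temp-then-rename; a sketch of that pattern (not the framework's actual implementation):

```python
import json
import os
import tempfile
from pathlib import Path


def atomic_write_json(path: Path, payload: dict) -> None:
    """Write JSON so readers never observe a half-written file."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # Write to a temp file in the same directory, then rename into place.
    # os.replace is atomic when source and target are on the same filesystem.
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(payload, f)
            f.flush()
            os.fsync(f.fileno())  # Force bytes to disk before the rename
        os.replace(tmp, path)
    except BaseException:
        os.unlink(tmp)
        raise
```

Writing the temp file into the checkpoint directory itself (rather than the system temp dir) keeps the rename on one filesystem, which is what makes `os.replace` atomic.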

### 3. GraphExecutor Integration

#### Modified Execution Loop

```python
# In GraphExecutor.execute()

async def execute(
    self,
    graph: GraphSpec,
    goal: Goal,
    memory: SharedMemory | None = None,
    entry_point: str = "start",
    session_state: dict[str, Any] | None = None,
    checkpoint_config: CheckpointConfig | None = None,
) -> ExecutionResult:
    """
    Execute graph with checkpointing support.

    New parameters:
        checkpoint_config: Configuration for checkpointing behavior
    """

    # Initialize checkpoint store
    checkpoint_store = CheckpointStore(storage_path / "checkpoints")

    # Restore from checkpoint if session_state indicates resume
    if session_state and session_state.get("resume_from"):
        checkpoint = await checkpoint_store.load_checkpoint(
            session_id,
            session_state["resume_from"]
        )
        memory = self._restore_memory_from_checkpoint(checkpoint)
        entry_point = checkpoint.next_node or checkpoint.current_node

    current_node = entry_point

    while current_node:
        # CHECKPOINT: node_start
        if checkpoint_config and checkpoint_config.checkpoint_on_node_start:
            await self._save_checkpoint(
                checkpoint_store,
                checkpoint_type="node_start",
                current_node=current_node,
                memory=memory,
                # ... other state
            )

        try:
            # Execute node
            result = await self._execute_node(current_node, memory, context)

            # CHECKPOINT: node_complete
            if checkpoint_config and checkpoint_config.checkpoint_on_node_complete:
                await self._save_checkpoint(
                    checkpoint_store,
                    checkpoint_type="node_complete",
                    current_node=current_node,
                    memory=memory,
                    # ... other state
                )

        except Exception as e:
            # On failure, mark current checkpoint as resume point
            await self._mark_failure_checkpoint(
                checkpoint_store,
                current_node=current_node,
                error=str(e),
            )
            raise

        # Find next edge
        next_node = self._find_next_node(current_node, result, memory)

        # CHECKPOINT: edge_transition
        if next_node and checkpoint_config and checkpoint_config.checkpoint_on_edge:
            await self._save_checkpoint(
                checkpoint_store,
                checkpoint_type="edge_transition",
                current_node=current_node,
                next_node=next_node,
                memory=memory,
                # ... other state
            )

        current_node = next_node
```

### 4. EventLoopNode Integration

#### Conversation State Checkpointing

EventLoopNode already has conversation persistence via `ConversationStore`. For resumability:

```python
class EventLoopNode:
    async def execute(self, ctx: NodeContext) -> NodeResult:
        """Execute with checkpoint support."""

        # Try to restore from checkpoint
        if ctx.checkpoint_id:
            conversation = await self._restore_conversation(ctx.checkpoint_id)
            output_accumulator = await OutputAccumulator.restore(self.store)
        else:
            # Fresh start
            conversation = await self._initialize_conversation(ctx)
            output_accumulator = OutputAccumulator(store=self.store)

        # Event loop with periodic checkpointing
        iteration = 0
        while iteration < self.config.max_iterations:

            # Optional: checkpoint every N iterations
            if self.config.checkpoint_every_n_iterations:
                if iteration % self.config.checkpoint_every_n_iterations == 0:
                    await self._save_loop_checkpoint(
                        conversation,
                        output_accumulator,
                        iteration,
                    )

            # ... rest of event loop

            iteration += 1
```

**Note**: EventLoopNode conversation state is already persisted to disk after each turn via `ConversationStore`, so it's naturally resumable. We just need to:
1. Track which conversation checkpoint to restore from
2. Ensure output accumulator state is also restored

### 5. User-Facing API

#### MCP Tools for Resume

```python
# In tools/src/aden_tools/tools/session_management/

@tool
async def list_resumable_sessions(
    agent_work_dir: str,
    status: str = "failed",  # "failed", "paused", "cancelled"
    limit: int = 20,
) -> dict:
    """
    List sessions that can be resumed.

    Returns:
        {
            "sessions": [
                {
                    "session_id": "session_20260208_143022_abc12345",
                    "status": "failed",
                    "error": "Missing API key: OPENAI_API_KEY",
                    "failed_at_node": "analyzer",
                    "last_checkpoint": "checkpoint_20260208_143045_abc789",
                    "created_at": "2026-02-08T14:30:22Z",
                    "updated_at": "2026-02-08T14:30:45Z"
                }
            ],
            "total": 1
        }
    """

@tool
async def list_session_checkpoints(
    agent_work_dir: str,
    session_id: str,
    checkpoint_type: str = "",  # Filter by type
    clean_only: bool = False,  # Only show clean checkpoints
) -> dict:
    """
    List all checkpoints for a session.

    Returns:
        {
            "session_id": "session_20260208_143022_abc12345",
            "checkpoints": [
                {
                    "checkpoint_id": "checkpoint_20260208_143030_xyz123",
                    "type": "node_complete",
                    "created_at": "2026-02-08T14:30:30Z",
                    "current_node": "collector",
                    "is_clean": true,
                    "can_resume_from": true,
                    "description": "Completed collector node successfully"
                },
                ...
            ]
        }
    """

@tool
async def inspect_checkpoint(
    agent_work_dir: str,
    session_id: str,
    checkpoint_id: str,
    include_memory: bool = False,  # Include full memory state
) -> dict:
    """
    Inspect a checkpoint's detailed state.

    Returns:
        {
            "checkpoint_id": "checkpoint_20260208_143030_xyz123",
            "type": "node_complete",
            "current_node": "collector",
            "execution_path": ["start", "collector"],
            "accumulated_outputs": {
                "twitter_handles": ["@user1", "@user2"]
            },
            "memory": {...},  # If include_memory=True
            "metrics_snapshot": {
                "total_retries": 2,
                "nodes_with_failures": []
            }
        }
    """

@tool
async def resume_session(
    agent_work_dir: str,
    session_id: str,
    checkpoint_id: str = "",  # Empty = latest checkpoint
    memory_modifications: str = "{}",  # JSON string of memory overrides
) -> dict:
    """
    Resume a session from a checkpoint.

    Args:
        agent_work_dir: Path to agent workspace
        session_id: Session to resume
        checkpoint_id: Specific checkpoint (empty = latest)
        memory_modifications: JSON object with memory key overrides

    Returns:
        {
            "session_id": "session_20260208_143022_abc12345",
            "resumed_from": "checkpoint_20260208_143045_abc789",
            "status": "active",  # Now actively running
            "message": "Session resumed successfully from checkpoint_20260208_143045_abc789"
        }
    """
```

#### CLI Commands

```bash
# List resumable sessions
hive sessions list --agent twitter_outreach --status failed

# Show checkpoints for a session
hive sessions checkpoints session_20260208_143022_abc12345

# Inspect a checkpoint
hive sessions inspect session_20260208_143022_abc12345 checkpoint_20260208_143045_abc789

# Resume a session
hive sessions resume session_20260208_143022_abc12345

# Resume from specific checkpoint
hive sessions resume session_20260208_143022_abc12345 --checkpoint checkpoint_20260208_143030_xyz123

# Resume with memory modifications (e.g., after adding credentials)
hive sessions resume session_20260208_143022_abc12345 --set api_key=sk-...
```

### 6. Configuration

#### CheckpointConfig

```python
@dataclass
class CheckpointConfig:
    """Configuration for checkpoint behavior."""

    # When to checkpoint
    checkpoint_on_node_start: bool = True
    checkpoint_on_node_complete: bool = True
    checkpoint_on_edge: bool = False  # Usually redundant with node_start
    checkpoint_on_loop_iteration: bool = False  # Can be expensive
    checkpoint_every_n_iterations: int = 0  # 0 = disabled

    # Pruning
    max_checkpoints_per_session: int = 100
    prune_after_node_count: int = 10  # Prune every N nodes
    keep_clean_checkpoints_only: bool = False

    # Performance
    async_checkpoint: bool = True  # Don't block execution on checkpoint writes

    # What to include
    include_conversation_snapshots: bool = True
    include_full_memory: bool = True
```

#### Agent-Level Configuration

```python
# In agent.py or config.py

class MyAgent(Agent):
    def get_checkpoint_config(self) -> CheckpointConfig:
        """Override to customize checkpoint behavior."""
        return CheckpointConfig(
            checkpoint_on_node_start=True,
            checkpoint_on_node_complete=True,
            checkpoint_every_n_iterations=5,  # Checkpoint every 5 iterations in loops
            max_checkpoints_per_session=50,
        )
```

## Implementation Plan

### Phase 1: Core Checkpoint Infrastructure (Week 1)

1. **Create checkpoint schemas**
   - `Checkpoint` dataclass
   - `CheckpointIndex` for manifest
   - Serialization/deserialization

2. **Implement CheckpointStore**
   - `save_checkpoint()` with atomic writes
   - `load_checkpoint()` with deserialization
   - `list_checkpoints()` with filtering
   - `prune_checkpoints()` for cleanup

3. **Update SessionState schema**
   - Add `resume_from_checkpoint_id` field
   - Add `checkpoints_enabled` flag

### Phase 2: GraphExecutor Integration (Week 2)

1. **Modify GraphExecutor**
   - Add `CheckpointConfig` parameter
   - Implement checkpoint saving at node boundaries
   - Implement checkpoint restoration logic
   - Handle memory state snapshots

2. **Update execution loop**
   - Checkpoint before node execution
   - Checkpoint after successful completion
   - Mark failure checkpoints on errors

### Phase 3: EventLoopNode Integration (Week 3)

1. **Enhance conversation restoration**
   - Link checkpoints to conversation states
   - Ensure OutputAccumulator is checkpointed
   - Test loop resumption from middle of execution

2. **Add optional loop iteration checkpoints**
   - Configurable iteration frequency
   - Balance between granularity and performance

### Phase 4: User-Facing Features (Week 4)

1. **Implement MCP tools**
   - `list_resumable_sessions`
   - `list_session_checkpoints`
   - `inspect_checkpoint`
   - `resume_session`

2. **Add CLI commands**
   - `hive sessions list`
   - `hive sessions checkpoints`
   - `hive sessions inspect`
   - `hive sessions resume`

3. **Update TUI**
   - Show resumable sessions in UI
   - Allow resume from TUI interface

### Phase 5: Testing & Documentation (Week 5)

1. **Write comprehensive tests**
   - Unit tests for CheckpointStore
   - Integration tests for resume flow
   - Edge case testing (concurrent checkpoints, corruption, etc.)

2. **Performance testing**
   - Measure checkpoint overhead
   - Optimize async checkpoint writing
   - Test with large memory states

3. **Documentation**
   - Update skills with resume patterns
   - Document checkpoint configuration
   - Add troubleshooting guide

## Performance Considerations

### Checkpoint Overhead

**Estimated overhead per checkpoint**:
- Memory serialization: ~5-10ms for typical state (< 1MB)
- File I/O: ~10-20ms for atomic write
- Total: ~15-30ms per checkpoint

**Mitigation strategies**:
1. **Async checkpointing**: Don't block execution on writes
2. **Selective checkpointing**: Only checkpoint at important boundaries
3. **Incremental checkpoints**: Store deltas instead of full state (future)
4. **Compression**: Compress large memory states before writing
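
The async-checkpointing strategy can be sketched with plain `asyncio` tasks. `AsyncCheckpointWriter` below is a hypothetical helper, not framework API; the key points are that `submit()` returns immediately and that pending writes can be drained before declaring a pause point:

```python
import asyncio
import json
from pathlib import Path


class AsyncCheckpointWriter:
    """Fire-and-forget checkpoint writes so the executor never blocks on I/O."""

    def __init__(self) -> None:
        self._pending: set[asyncio.Task] = set()

    def submit(self, path: Path, payload: dict) -> None:
        # Schedule the write and return immediately; keep a reference so
        # the task is not garbage-collected before it finishes.
        task = asyncio.get_running_loop().create_task(self._write(path, payload))
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)

    async def _write(self, path: Path, payload: dict) -> None:
        data = json.dumps(payload)
        # Run the blocking file write on a worker thread
        await asyncio.to_thread(path.write_text, data)

    async def drain(self) -> None:
        """Await outstanding writes (e.g. before reporting a clean pause)."""
        if self._pending:
            await asyncio.gather(*list(self._pending))
```

A real implementation would also need to order writes per session so a slow older checkpoint cannot overwrite a newer one.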

### Storage Size

**Typical checkpoint size**:
- Small memory state (< 100KB): ~50-100KB per checkpoint
- Medium memory state (< 1MB): ~500KB-1MB per checkpoint
- Large memory state (> 1MB): ~1-5MB per checkpoint

**Mitigation strategies**:
1. **Pruning**: Keep only N most recent checkpoints
2. **Clean-only retention**: Only keep checkpoints from clean execution
3. **Compression**: Use gzip for checkpoint files
4. **Archiving**: Move old checkpoints to archive storage
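
The gzip strategy is a few lines with the standard library; a sketch with illustrative helper names (the `.json.gz` convention is an assumption, not framework behavior):

```python
import gzip
import json
from pathlib import Path


def write_checkpoint_gz(path: Path, payload: dict) -> int:
    """Write a gzip-compressed checkpoint; returns compressed size in bytes."""
    raw = json.dumps(payload).encode("utf-8")
    compressed = gzip.compress(raw)
    path.write_bytes(compressed)
    return len(compressed)


def read_checkpoint_gz(path: Path) -> dict:
    """Decompress and parse a checkpoint written by write_checkpoint_gz."""
    return json.loads(gzip.decompress(path.read_bytes()).decode("utf-8"))
```

Checkpoint payloads tend to be repetitive JSON (long execution paths, duplicated keys), so gzip typically shrinks them substantially; the trade-off is a little extra CPU per save and load.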

## Error Handling

### Checkpoint Save Failures

**Scenarios**:

- Disk full
- Permission errors
- Serialization failures
- Concurrent writes

**Handling**:

```python
try:
    await checkpoint_store.save_checkpoint(session_id, checkpoint)
except CheckpointSaveError as e:
    # Log a warning but don't fail execution
    logger.warning(f"Failed to save checkpoint: {e}")
    # Continue execution without the checkpoint
```

### Checkpoint Load Failures

**Scenarios**:

- Checkpoint file corrupted
- Checkpoint format incompatible
- Referenced conversation state missing

**Handling**:

```python
try:
    checkpoint = await checkpoint_store.load_checkpoint(session_id, checkpoint_id)
except CheckpointLoadError:
    # Fall back to the most recent checkpoint that still loads
    checkpoints = await checkpoint_store.list_checkpoints(session_id)
    for cp in reversed(checkpoints):
        try:
            checkpoint = await checkpoint_store.load_checkpoint(session_id, cp.checkpoint_id)
            logger.info(f"Fell back to checkpoint {cp.checkpoint_id}")
            break
        except CheckpointLoadError:
            continue
    else:
        raise ValueError(f"No valid checkpoints found for session {session_id}")
```

### Resume Failures

**Scenarios**:

- Checkpoint state inconsistent with current graph
- Node no longer exists in updated agent code
- Memory keys missing required values

**Handling**:

1. **Validation**: Verify checkpoint compatibility before resuming
2. **Graceful degradation**: Resume from an earlier checkpoint if possible
3. **User notification**: Clear error messages about why the resume failed
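The validation step can be a pure function that doubles as the source of user-facing messages for step 3. A sketch, with field names assumed to match the `Checkpoint` schema (`current_node`, `shared_memory`); the checkpoint argument is duck-typed:

```python
def validate_checkpoint_compat(checkpoint, graph_node_ids, required_memory_keys):
    """Return human-readable reasons a resume would fail (empty list = compatible)."""
    problems = []
    # Node no longer exists in updated agent code
    if checkpoint.current_node and checkpoint.current_node not in graph_node_ids:
        problems.append(f"Node '{checkpoint.current_node}' no longer exists in the graph")
    # Memory keys missing required values
    missing = set(required_memory_keys) - set(checkpoint.shared_memory)
    if missing:
        problems.append(f"Memory keys missing required values: {sorted(missing)}")
    return problems
```

An empty return means the checkpoint looks resumable against the current graph; a non-empty one can be surfaced verbatim to the user, or trigger the fall-back to an earlier checkpoint.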

## Migration Path

### Backward Compatibility

**Existing sessions** (without checkpoints):

- Can still be executed normally
- Checkpoint system is opt-in per agent
- No breaking changes to existing APIs

**Enabling checkpoints**:

```python
# Option 1: Agent-level default
class MyAgent(Agent):
    checkpoint_config = CheckpointConfig(
        checkpoint_on_node_complete=True,
    )

# Option 2: Runtime override
runtime = create_agent_runtime(
    agent=my_agent,
    checkpoint_config=CheckpointConfig(...),
)

# Option 3: Per-execution
result = await executor.execute(
    graph=graph,
    goal=goal,
    checkpoint_config=CheckpointConfig(...),
)
```

### Gradual Rollout

1. **Phase 1**: Core infrastructure, no user-facing features
2. **Phase 2**: Opt-in for specific agents via config
3. **Phase 3**: User-facing MCP tools and CLI
4. **Phase 4**: Enable by default for all new agents
5. **Phase 5**: TUI integration

## Future Enhancements

### 1. Incremental Checkpoints

Instead of full state snapshots, store only deltas:

```python
@dataclass
class IncrementalCheckpoint:
    """Checkpoint with only changed state."""

    base_checkpoint_id: str        # Parent checkpoint
    memory_delta: dict[str, Any]   # Only changed keys
    added_outputs: dict[str, Any]  # Only new outputs
```
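How `memory_delta` would be computed at save time is left open above; one possible sketch:

```python
from typing import Any


def memory_delta(base: dict[str, Any], current: dict[str, Any]) -> dict[str, Any]:
    """Keys added or updated since the base checkpoint.

    Deleted keys would need a separate tombstone list; omitted in this sketch.
    """
    return {k: v for k, v in current.items() if k not in base or base[k] != v}
```

Restoring state would then mean replaying deltas on top of the base checkpoint, oldest first.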

### 2. Distributed Checkpointing

For long-running agents, checkpoint to cloud storage:

```python
checkpoint_config = CheckpointConfig(
    storage_backend="s3",  # or "gcs", "azure"
    storage_url="s3://my-bucket/checkpoints/",
)
```

### 3. Checkpoint Compression

Compress large memory states:

```python
checkpoint_config = CheckpointConfig(
    compress=True,
    compression_threshold_bytes=100_000,  # Compress if > 100KB
)
```
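A possible shape for the compression path itself, using stdlib `gzip` and the threshold above (the flag-handling convention is an assumption, e.g. a `.json.gz` suffix on disk):

```python
import gzip


def maybe_compress(payload: bytes, threshold_bytes: int = 100_000) -> tuple[bytes, bool]:
    """Gzip a serialized checkpoint only when it exceeds the threshold.

    The returned flag tells the loader how to decode the file.
    """
    if len(payload) < threshold_bytes:
        return payload, False
    return gzip.compress(payload), True
```

Small checkpoints skip the round-trip entirely, so the common case pays no CPU cost.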

### 4. Smart Checkpoint Selection

Use heuristics to decide when to checkpoint:

```python
class SmartCheckpointStrategy:
    def should_checkpoint(self, context: ExecutionContext) -> bool:
        # Checkpoint after expensive nodes
        if context.node_latency_ms > 30_000:
            return True

        # Checkpoint before risky operations
        if context.node_id in ["api_call", "external_tool"]:
            return True

        # Checkpoint after significant memory changes
        if context.memory_delta_size > 10:
            return True

        return False
```

## Security Considerations

### 1. Sensitive Data in Checkpoints

**Problem**: Checkpoints may contain sensitive data (API keys, credentials, PII)

**Mitigation**:

```python
@dataclass
class CheckpointConfig:
    # Exclude sensitive keys from checkpoints
    exclude_memory_keys: list[str] = field(default_factory=lambda: [
        "api_key",
        "credentials",
        "access_token",
    ])

    # Encrypt checkpoint files
    encrypt_checkpoints: bool = True
    encryption_key_source: str = "keychain"  # or "env_var", "file"
```
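The `exclude_memory_keys` filter could be applied just before serialization. A sketch; substring matching is an added assumption here (the config above implies exact keys), chosen so that e.g. `"openai_api_key"` is caught by the `"api_key"` pattern:

```python
from typing import Any


def redact_memory(memory: dict[str, Any], exclude_keys: list[str]) -> dict[str, Any]:
    """Drop sensitive keys before the memory snapshot is serialized."""
    patterns = [p.lower() for p in exclude_keys]
    return {
        key: value
        for key, value in memory.items()
        if not any(p in key.lower() for p in patterns)
    }
```

Note that redaction is lossy: a session resumed from a redacted checkpoint must be able to re-fetch the excluded values (e.g. from the environment or a keychain).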

### 2. Checkpoint Tampering

**Problem**: Malicious modification of checkpoint files

**Mitigation**:

```python
@dataclass
class Checkpoint:
    # Add a cryptographic signature
    signature: str  # HMAC of checkpoint content

    def verify_signature(self, secret_key: str) -> bool:
        """Verify the checkpoint hasn't been tampered with."""
        ...
```
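The HMAC itself is standard-library territory; a sketch of the sign/verify pair using `hmac` with SHA-256 (how the content bytes are canonicalized before signing is left open above):

```python
import hashlib
import hmac


def sign_content(content: bytes, secret_key: bytes) -> str:
    """HMAC-SHA256 over the serialized checkpoint content."""
    return hmac.new(secret_key, content, hashlib.sha256).hexdigest()


def verify_content(content: bytes, signature: str, secret_key: bytes) -> bool:
    """Constant-time comparison, resistant to timing attacks."""
    return hmac.compare_digest(sign_content(content, secret_key), signature)
```

`hmac.compare_digest` rather than `==` matters here: a naive comparison short-circuits on the first mismatched byte and leaks timing information.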

## References

- [RUNTIME_LOGGING.md](./RUNTIME_LOGGING.md) - Current logging system
- [session_state.py](../schemas/session_state.py) - Session state schema
- [session_store.py](../storage/session_store.py) - Session storage
- [executor.py](../graph/executor.py) - Graph executor
- [event_loop_node.py](../graph/event_loop_node.py) - EventLoop implementation
@@ -12,6 +12,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

from framework.graph.checkpoint_config import CheckpointConfig
from framework.graph.executor import ExecutionResult
from framework.runtime.event_bus import EventBus
from framework.runtime.execution_stream import EntryPointSpec, ExecutionStream
@@ -102,6 +103,7 @@ class AgentRuntime:
        tool_executor: Callable | None = None,
        config: AgentRuntimeConfig | None = None,
        runtime_log_store: Any = None,
        checkpoint_config: CheckpointConfig | None = None,
    ):
        """
        Initialize agent runtime.
@@ -115,11 +117,13 @@ class AgentRuntime:
            tool_executor: Function to execute tools
            config: Optional runtime configuration
            runtime_log_store: Optional RuntimeLogStore for per-execution logging
            checkpoint_config: Optional checkpoint configuration for resumable sessions
        """
        self.graph = graph
        self.goal = goal
        self._config = config or AgentRuntimeConfig()
        self._runtime_log_store = runtime_log_store
        self._checkpoint_config = checkpoint_config

        # Initialize storage
        storage_path_obj = Path(storage_path) if isinstance(storage_path, str) else storage_path
@@ -222,6 +226,7 @@ class AgentRuntime:
            result_retention_ttl_seconds=self._config.execution_result_ttl_seconds,
            runtime_log_store=self._runtime_log_store,
            session_store=self._session_store,
            checkpoint_config=self._checkpoint_config,
        )
        await stream.start()
        self._streams[ep_id] = stream
@@ -460,6 +465,7 @@ def create_agent_runtime(
    config: AgentRuntimeConfig | None = None,
    runtime_log_store: Any = None,
    enable_logging: bool = True,
    checkpoint_config: CheckpointConfig | None = None,
) -> AgentRuntime:
    """
    Create and configure an AgentRuntime with entry points.
@@ -480,6 +486,8 @@ def create_agent_runtime(
            If None and enable_logging=True, creates one automatically.
        enable_logging: Whether to enable runtime logging (default: True).
            Set to False to disable logging entirely.
        checkpoint_config: Optional checkpoint configuration for resumable sessions.
            If None, uses default checkpointing behavior.

    Returns:
        Configured AgentRuntime (not yet started)
@@ -500,6 +508,7 @@ def create_agent_runtime(
        tool_executor=tool_executor,
        config=config,
        runtime_log_store=runtime_log_store,
        checkpoint_config=checkpoint_config,
    )

    for spec in entry_points:
@@ -17,6 +17,7 @@ from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Any

from framework.graph.checkpoint_config import CheckpointConfig
from framework.graph.executor import ExecutionResult, GraphExecutor
from framework.runtime.shared_state import IsolationLevel, SharedStateManager
from framework.runtime.stream_runtime import StreamRuntime, StreamRuntimeAdapter
@@ -115,6 +116,7 @@ class ExecutionStream:
        result_retention_ttl_seconds: float | None = None,
        runtime_log_store: Any = None,
        session_store: "SessionStore | None" = None,
        checkpoint_config: CheckpointConfig | None = None,
    ):
        """
        Initialize execution stream.
@@ -133,6 +135,7 @@ class ExecutionStream:
            tool_executor: Function to execute tools
            runtime_log_store: Optional RuntimeLogStore for per-execution logging
            session_store: Optional SessionStore for unified session storage
            checkpoint_config: Optional checkpoint configuration for resumable sessions
        """
        self.stream_id = stream_id
        self.entry_spec = entry_spec
@@ -148,6 +151,7 @@ class ExecutionStream:
        self._result_retention_max = result_retention_max
        self._result_retention_ttl_seconds = result_retention_ttl_seconds
        self._runtime_log_store = runtime_log_store
        self._checkpoint_config = checkpoint_config
        self._session_store = session_store

        # Create stream-scoped runtime
@@ -400,6 +404,7 @@ class ExecutionStream:
            goal=self.goal,
            input_data=ctx.input_data,
            session_state=ctx.session_state,
            checkpoint_config=self._checkpoint_config,
        )

        # Clean up executor reference
@@ -437,8 +442,42 @@ class ExecutionStream:
            logger.debug(f"Execution {execution_id} completed: success={result.success}")

        except asyncio.CancelledError:
            ctx.status = "cancelled"
            raise
            # Execution was cancelled.
            # The executor catches CancelledError and returns a paused result,
            # but if cancellation happened before the executor started, we won't have a result.
            logger.info(f"Execution {execution_id} cancelled")

            # Check if we have a result (executor completed and returned)
            try:
                _ = result  # Check if the result variable exists
                has_result = True
            except NameError:
                has_result = False
                result = ExecutionResult(
                    success=False,
                    error="Execution cancelled",
                )

            # Update context status based on the result
            if has_result and result.paused_at:
                ctx.status = "paused"
                ctx.completed_at = datetime.now()
            else:
                ctx.status = "cancelled"

            # Clean up executor reference
            self._active_executors.pop(execution_id, None)

            # Store result with retention
            self._record_execution_result(execution_id, result)

            # Write session state
            if has_result and result.paused_at:
                await self._write_session_state(execution_id, ctx, result=result)
            else:
                await self._write_session_state(execution_id, ctx, error="Execution cancelled")

            # Don't re-raise - we've handled it and saved state

        except Exception as e:
            ctx.status = "failed"
@@ -511,7 +550,11 @@ class ExecutionStream:
        else:
            status = SessionStatus.FAILED
    elif error:
        status = SessionStatus.FAILED
        # Check if this is a cancellation
        if ctx.status == "cancelled" or "cancelled" in error.lower():
            status = SessionStatus.CANCELLED
        else:
            status = SessionStatus.FAILED
    else:
        status = SessionStatus.ACTIVE

@@ -0,0 +1,178 @@
"""
Checkpoint Schema - Execution state snapshots for resumability.

Checkpoints capture the execution state at strategic points (node boundaries,
iterations) to enable crash recovery and resume-from-failure scenarios.
"""

from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field


class Checkpoint(BaseModel):
    """
    Single checkpoint in the execution timeline.

    Captures complete execution state at a specific point to enable
    resuming from that exact point after failures or pauses.
    """

    # Identity
    checkpoint_id: str  # Format: cp_{type}_{node_id}_{timestamp}
    checkpoint_type: str  # "node_start" | "node_complete" | "loop_iteration"
    session_id: str

    # Timestamps
    created_at: str  # ISO 8601 format

    # Execution state
    current_node: str | None = None
    next_node: str | None = None  # For edge_transition checkpoints
    execution_path: list[str] = Field(default_factory=list)  # Nodes executed so far

    # State snapshots
    shared_memory: dict[str, Any] = Field(default_factory=dict)  # Full SharedMemory._data
    accumulated_outputs: dict[str, Any] = Field(default_factory=dict)  # Outputs accumulated so far

    # Execution metrics (for resuming quality tracking)
    metrics_snapshot: dict[str, Any] = Field(default_factory=dict)

    # Metadata
    is_clean: bool = True  # True if no failures/retries before this checkpoint
    description: str = ""  # Human-readable checkpoint description

    model_config = {"extra": "allow"}

    @classmethod
    def create(
        cls,
        checkpoint_type: str,
        session_id: str,
        current_node: str,
        execution_path: list[str],
        shared_memory: dict[str, Any],
        next_node: str | None = None,
        accumulated_outputs: dict[str, Any] | None = None,
        metrics_snapshot: dict[str, Any] | None = None,
        is_clean: bool = True,
        description: str = "",
    ) -> "Checkpoint":
        """
        Create a new checkpoint with a generated ID and timestamp.

        Args:
            checkpoint_type: Type of checkpoint (node_start, node_complete, etc.)
            session_id: Session this checkpoint belongs to
            current_node: Node ID at checkpoint time
            execution_path: List of node IDs executed so far
            shared_memory: Full memory state snapshot
            next_node: Next node to execute (for node_complete checkpoints)
            accumulated_outputs: Outputs accumulated so far
            metrics_snapshot: Execution metrics at checkpoint time
            is_clean: Whether execution was clean up to this point
            description: Human-readable description

        Returns:
            New Checkpoint instance
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        checkpoint_id = f"cp_{checkpoint_type}_{current_node}_{timestamp}"

        if not description:
            description = f"{checkpoint_type.replace('_', ' ').title()}: {current_node}"

        return cls(
            checkpoint_id=checkpoint_id,
            checkpoint_type=checkpoint_type,
            session_id=session_id,
            created_at=datetime.now().isoformat(),
            current_node=current_node,
            next_node=next_node,
            execution_path=execution_path,
            shared_memory=shared_memory,
            accumulated_outputs=accumulated_outputs or {},
            metrics_snapshot=metrics_snapshot or {},
            is_clean=is_clean,
            description=description,
        )


class CheckpointSummary(BaseModel):
    """
    Lightweight checkpoint metadata for index listings.

    Used in the checkpoint index to provide fast scanning without
    loading full checkpoint data.
    """

    checkpoint_id: str
    checkpoint_type: str
    created_at: str
    current_node: str | None = None
    next_node: str | None = None
    is_clean: bool = True
    description: str = ""

    model_config = {"extra": "allow"}

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "CheckpointSummary":
        """Create a summary from a full checkpoint."""
        return cls(
            checkpoint_id=checkpoint.checkpoint_id,
            checkpoint_type=checkpoint.checkpoint_type,
            created_at=checkpoint.created_at,
            current_node=checkpoint.current_node,
            next_node=checkpoint.next_node,
            is_clean=checkpoint.is_clean,
            description=checkpoint.description,
        )


class CheckpointIndex(BaseModel):
    """
    Manifest of all checkpoints for a session.

    Provides fast lookup and filtering without loading
    full checkpoint files.
    """

    session_id: str
    checkpoints: list[CheckpointSummary] = Field(default_factory=list)
    latest_checkpoint_id: str | None = None
    total_checkpoints: int = 0

    model_config = {"extra": "allow"}

    def add_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Add a checkpoint to the index."""
        summary = CheckpointSummary.from_checkpoint(checkpoint)
        self.checkpoints.append(summary)
        self.latest_checkpoint_id = checkpoint.checkpoint_id
        self.total_checkpoints = len(self.checkpoints)

    def get_checkpoint_summary(self, checkpoint_id: str) -> CheckpointSummary | None:
        """Get a checkpoint summary by ID."""
        for summary in self.checkpoints:
            if summary.checkpoint_id == checkpoint_id:
                return summary
        return None

    def filter_by_type(self, checkpoint_type: str) -> list[CheckpointSummary]:
        """Filter checkpoints by type."""
        return [cp for cp in self.checkpoints if cp.checkpoint_type == checkpoint_type]

    def filter_by_node(self, node_id: str) -> list[CheckpointSummary]:
        """Filter checkpoints by current_node."""
        return [cp for cp in self.checkpoints if cp.current_node == node_id]

    def get_clean_checkpoints(self) -> list[CheckpointSummary]:
        """Get all clean checkpoints (no failures before them)."""
        return [cp for cp in self.checkpoints if cp.is_clean]

    def get_latest_clean_checkpoint(self) -> CheckpointSummary | None:
        """Get the most recent clean checkpoint."""
        clean = self.get_clean_checkpoints()
        return clean[-1] if clean else None

@@ -91,10 +91,11 @@ class SessionState(BaseModel):

    Version History:
    - v1.0: Initial schema (2026-02-06)
    - v1.1: Added checkpoint support (2026-02-08)
    """

    # Schema version for forward/backward compatibility
    schema_version: str = "1.0"
    schema_version: str = "1.1"

    # Identity
    session_id: str  # Format: session_YYYYMMDD_HHMMSS_{uuid_8char}
@@ -136,6 +137,10 @@ class SessionState(BaseModel):
    # Isolation level (from ExecutionContext)
    isolation_level: str = "shared"

    # Checkpointing (for crash recovery and resume-from-failure)
    checkpoint_enabled: bool = False
    latest_checkpoint_id: str | None = None

    model_config = {"extra": "allow"}

    @computed_field
@@ -154,6 +159,14 @@ class SessionState(BaseModel):
        """Can this session be resumed?"""
        return self.status == SessionStatus.PAUSED and self.progress.resume_from is not None

    @computed_field
    @property
    def is_resumable_from_checkpoint(self) -> bool:
        """Can this session be resumed from a checkpoint?"""
        # ANY session with checkpoints can be resumed (not just failed ones).
        # This enables: pause/resume, iterative execution, continuation after completion.
        return self.checkpoint_enabled and self.latest_checkpoint_id is not None

    @classmethod
    def from_execution_result(
        cls,

@@ -0,0 +1,325 @@
"""
Checkpoint Store - Manages checkpoint storage with atomic writes.

Handles saving, loading, listing, and pruning of execution checkpoints
for session resumability.
"""

import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path

from framework.schemas.checkpoint import Checkpoint, CheckpointIndex, CheckpointSummary
from framework.utils.io import atomic_write

logger = logging.getLogger(__name__)


class CheckpointStore:
    """
    Manages checkpoint storage with atomic writes.

    Stores checkpoints in a session's checkpoints/ directory with
    an index for fast lookup and filtering.

    Directory structure:
        checkpoints/
            index.json  # Checkpoint manifest
            cp_{type}_{node}_{timestamp}.json  # Individual checkpoints
    """

    def __init__(self, base_path: Path):
        """
        Initialize checkpoint store.

        Args:
            base_path: Session directory (e.g., ~/.hive/agents/agent_name/sessions/session_ID/)
        """
        self.base_path = Path(base_path)
        self.checkpoints_dir = self.base_path / "checkpoints"
        self.index_path = self.checkpoints_dir / "index.json"
        self._index_lock = asyncio.Lock()

    async def save_checkpoint(self, checkpoint: Checkpoint) -> None:
        """
        Atomically save checkpoint and update index.

        Uses temp file + rename for crash safety. Updates the index
        after the checkpoint is persisted.

        Args:
            checkpoint: Checkpoint to save

        Raises:
            OSError: If file write fails
        """

        def _write():
            # Ensure directory exists
            self.checkpoints_dir.mkdir(parents=True, exist_ok=True)

            # Write checkpoint file atomically
            checkpoint_path = self.checkpoints_dir / f"{checkpoint.checkpoint_id}.json"
            with atomic_write(checkpoint_path) as f:
                f.write(checkpoint.model_dump_json(indent=2))

            logger.debug(f"Saved checkpoint {checkpoint.checkpoint_id}")

        # Write checkpoint file (blocking I/O in thread)
        await asyncio.to_thread(_write)

        # Update index (with lock to prevent concurrent modifications)
        async with self._index_lock:
            await self._update_index_add(checkpoint)

    async def load_checkpoint(
        self,
        checkpoint_id: str | None = None,
    ) -> Checkpoint | None:
        """
        Load a checkpoint by ID, or the latest.

        Args:
            checkpoint_id: Checkpoint ID to load, or None for latest

        Returns:
            Checkpoint object, or None if not found
        """

        def _read(checkpoint_id: str) -> Checkpoint | None:
            checkpoint_path = self.checkpoints_dir / f"{checkpoint_id}.json"

            if not checkpoint_path.exists():
                logger.warning(f"Checkpoint file not found: {checkpoint_path}")
                return None

            try:
                return Checkpoint.model_validate_json(checkpoint_path.read_text())
            except Exception as e:
                logger.error(f"Failed to load checkpoint {checkpoint_id}: {e}")
                return None

        # Load index to get checkpoint ID if not provided
        if checkpoint_id is None:
            index = await self.load_index()
            if not index or not index.latest_checkpoint_id:
                logger.warning("No checkpoints found in index")
                return None
            checkpoint_id = index.latest_checkpoint_id

        return await asyncio.to_thread(_read, checkpoint_id)

    async def load_index(self) -> CheckpointIndex | None:
        """
        Load the checkpoint index.

        Returns:
            CheckpointIndex, or None if not found
        """

        def _read() -> CheckpointIndex | None:
            if not self.index_path.exists():
                return None

            try:
                return CheckpointIndex.model_validate_json(self.index_path.read_text())
            except Exception as e:
                logger.error(f"Failed to load checkpoint index: {e}")
                return None

        return await asyncio.to_thread(_read)

    async def list_checkpoints(
        self,
        checkpoint_type: str | None = None,
        is_clean: bool | None = None,
    ) -> list[CheckpointSummary]:
        """
        List checkpoints with optional filters.

        Args:
            checkpoint_type: Filter by type (node_start, node_complete)
            is_clean: Filter by clean status

        Returns:
            List of CheckpointSummary objects
        """
        index = await self.load_index()
        if not index:
            return []

        checkpoints = index.checkpoints

        # Apply filters
        if checkpoint_type:
            checkpoints = [cp for cp in checkpoints if cp.checkpoint_type == checkpoint_type]

        if is_clean is not None:
            checkpoints = [cp for cp in checkpoints if cp.is_clean == is_clean]

        return checkpoints

    async def delete_checkpoint(self, checkpoint_id: str) -> bool:
        """
        Delete a specific checkpoint.

        Args:
            checkpoint_id: Checkpoint ID to delete

        Returns:
            True if deleted, False if not found
        """

        def _delete(checkpoint_id: str) -> bool:
            checkpoint_path = self.checkpoints_dir / f"{checkpoint_id}.json"

            if not checkpoint_path.exists():
                logger.warning(f"Checkpoint file not found: {checkpoint_path}")
                return False

            try:
                checkpoint_path.unlink()
                logger.info(f"Deleted checkpoint {checkpoint_id}")
                return True
            except Exception as e:
                logger.error(f"Failed to delete checkpoint {checkpoint_id}: {e}")
                return False

        # Delete checkpoint file
        deleted = await asyncio.to_thread(_delete, checkpoint_id)

        if deleted:
            # Update index (with lock)
            async with self._index_lock:
                await self._update_index_remove(checkpoint_id)

        return deleted

    async def prune_checkpoints(
        self,
        max_age_days: int = 7,
    ) -> int:
        """
        Prune checkpoints older than max_age_days.

        Args:
            max_age_days: Maximum age in days (default 7)

        Returns:
            Number of checkpoints deleted
        """
        index = await self.load_index()
        if not index or not index.checkpoints:
            return 0

        # Calculate cutoff datetime
        cutoff = datetime.now() - timedelta(days=max_age_days)

        # Find old checkpoints
        old_checkpoints = []
        for cp in index.checkpoints:
            try:
                created = datetime.fromisoformat(cp.created_at)
                if created < cutoff:
                    old_checkpoints.append(cp.checkpoint_id)
            except Exception as e:
                logger.warning(f"Failed to parse timestamp for {cp.checkpoint_id}: {e}")

        # Delete old checkpoints
        deleted_count = 0
        for checkpoint_id in old_checkpoints:
            if await self.delete_checkpoint(checkpoint_id):
                deleted_count += 1

        if deleted_count > 0:
            logger.info(f"Pruned {deleted_count} checkpoints older than {max_age_days} days")

        return deleted_count

    async def checkpoint_exists(self, checkpoint_id: str) -> bool:
        """
        Check if a checkpoint exists.

        Args:
            checkpoint_id: Checkpoint ID

        Returns:
            True if the checkpoint exists
        """

        def _check(checkpoint_id: str) -> bool:
            checkpoint_path = self.checkpoints_dir / f"{checkpoint_id}.json"
            return checkpoint_path.exists()

        return await asyncio.to_thread(_check, checkpoint_id)

    async def _update_index_add(self, checkpoint: Checkpoint) -> None:
        """
        Update the index after adding a checkpoint.

        Should be called with _index_lock held.

        Args:
            checkpoint: Checkpoint that was added
        """

        def _write(index: CheckpointIndex):
            # Ensure directory exists
            self.checkpoints_dir.mkdir(parents=True, exist_ok=True)

            # Write index atomically
            with atomic_write(self.index_path) as f:
                f.write(index.model_dump_json(indent=2))

        # Load or create index
        index = await self.load_index()
        if not index:
            index = CheckpointIndex(
                session_id=checkpoint.session_id,
                checkpoints=[],
            )

        # Add checkpoint to index
        index.add_checkpoint(checkpoint)

        # Write updated index
        await asyncio.to_thread(_write, index)

        logger.debug(f"Updated index with checkpoint {checkpoint.checkpoint_id}")

    async def _update_index_remove(self, checkpoint_id: str) -> None:
        """
        Update the index after removing a checkpoint.

        Should be called with _index_lock held.

        Args:
            checkpoint_id: Checkpoint ID that was removed
        """

        def _write(index: CheckpointIndex):
            with atomic_write(self.index_path) as f:
                f.write(index.model_dump_json(indent=2))

        # Load index
        index = await self.load_index()
        if not index:
            return

        # Remove checkpoint from index
        index.checkpoints = [cp for cp in index.checkpoints if cp.checkpoint_id != checkpoint_id]

        # Update totals
        index.total_checkpoints = len(index.checkpoints)

        # Update latest_checkpoint_id if we removed the latest
        if index.latest_checkpoint_id == checkpoint_id:
            index.latest_checkpoint_id = (
                index.checkpoints[-1].checkpoint_id if index.checkpoints else None
            )

        # Write updated index
        await asyncio.to_thread(_write, index)

        logger.debug(f"Removed checkpoint {checkpoint_id} from index")

@@ -6,7 +6,7 @@ import time
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Container, Horizontal, Vertical
|
||||
from textual.widgets import Footer, Label
|
||||
from textual.widgets import Footer, Input, Label
|
||||
|
||||
from framework.runtime.agent_runtime import AgentRuntime
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
@@ -208,17 +208,24 @@ class AdenTUI(App):
|
||||
Binding("ctrl+c", "ctrl_c", "Interrupt", show=False, priority=True),
|
||||
Binding("super+c", "ctrl_c", "Copy", show=False, priority=True),
|
||||
Binding("ctrl+s", "screenshot", "Screenshot (SVG)", show=True, priority=True),
|
||||
Binding("ctrl+z", "pause_execution", "Pause", show=True, priority=True),
|
||||
Binding("ctrl+r", "show_sessions", "Sessions", show=True, priority=True),
|
||||
Binding("tab", "focus_next", "Next Panel", show=True),
|
||||
Binding("shift+tab", "focus_previous", "Previous Panel", show=False),
|
||||
]
|
||||
|
||||
def __init__(self, runtime: AgentRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
runtime: AgentRuntime,
|
||||
resume_session: str | None = None,
|
||||
resume_checkpoint: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.runtime = runtime
|
||||
self.log_pane = LogPane()
|
||||
self.graph_view = GraphOverview(runtime)
|
||||
self.chat_repl = ChatRepl(runtime)
|
||||
self.chat_repl = ChatRepl(runtime, resume_session, resume_checkpoint)
|
||||
self.status_bar = StatusBar(graph_id=runtime.graph.id)
|
||||
self.is_ready = False
|
||||
|
||||
@@ -528,9 +535,99 @@ class AdenTUI(App):
        except Exception as e:
            self.notify(f"Screenshot failed: {e}", severity="error", timeout=5)

    def action_pause_execution(self) -> None:
        """Immediately pause execution by cancelling task (bound to Ctrl+Z)."""
        try:
            chat_repl = self.query_one(ChatRepl)
            if not chat_repl._current_exec_id:
                self.notify(
                    "No active execution to pause",
                    severity="information",
                    timeout=3,
                )
                return

            # Find and cancel the execution task - executor will catch and save state
            task_cancelled = False
            for stream in self.runtime._streams.values():
                exec_id = chat_repl._current_exec_id
                task = stream._execution_tasks.get(exec_id)
                if task and not task.done():
                    task.cancel()
                    task_cancelled = True
                    self.notify(
                        "⏸ Execution paused - state saved",
                        severity="information",
                        timeout=3,
                    )
                    break

            if not task_cancelled:
                self.notify(
                    "Execution already completed",
                    severity="information",
                    timeout=2,
                )
        except Exception as e:
            self.notify(
                f"Error pausing execution: {e}",
                severity="error",
                timeout=5,
            )

    def action_show_sessions(self) -> None:
        """Show sessions list (bound to Ctrl+R)."""
        # Send /sessions command to chat input
        try:
            chat_repl = self.query_one(ChatRepl)
            chat_input = chat_repl.query_one("#chat-input", Input)
            chat_input.value = "/sessions"
            # Trigger submission
            self.notify(
                "💡 Type /sessions in the chat to see all sessions",
                severity="information",
                timeout=3,
            )
        except Exception:
            self.notify(
                "Use /sessions command to see all sessions",
                severity="information",
                timeout=3,
            )

    async def on_unmount(self) -> None:
        """Cleanup on app shutdown."""
        """Cleanup on app shutdown - cancel execution which will save state."""
        self.is_ready = False

        # Cancel any active execution - the executor will catch CancelledError
        # and save current state as paused (no waiting needed!)
        try:
            import asyncio

            chat_repl = self.query_one(ChatRepl)
            if chat_repl._current_exec_id:
                # Find the stream with this execution
                for stream in self.runtime._streams.values():
                    exec_id = chat_repl._current_exec_id
                    task = stream._execution_tasks.get(exec_id)
                    if task and not task.done():
                        # Cancel the task - executor will catch and save state
                        task.cancel()
                        try:
                            # Wait for executor to save state (may take a few seconds)
                            # Longer timeout for quit to ensure state is properly saved
                            await asyncio.wait_for(task, timeout=5.0)
                        except (TimeoutError, asyncio.CancelledError):
                            # Expected - task was cancelled
                            # If timeout, state may not be fully saved
                            pass
                        except Exception:
                            # Ignore other exceptions during cleanup
                            pass
                        break
        except Exception:
            pass

        try:
            if hasattr(self, "_subscription_id"):
                self.runtime.unsubscribe_from_events(self._subscription_id)

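The pause and shutdown paths above both rely on the same asyncio pattern: cancel the execution task and let the executor catch `CancelledError` to persist state. A minimal standalone sketch of that pattern (the `worker` and `saved` names are illustrative, not from the framework):

```python
import asyncio

saved = False  # set when the worker persists state on cancellation


async def worker() -> None:
    global saved
    try:
        await asyncio.sleep(60)  # stand-in for a long-running agent execution
    except asyncio.CancelledError:
        saved = True  # executor-style behavior: save state, then re-raise
        raise


async def main() -> None:
    task = asyncio.create_task(worker())
    await asyncio.sleep(0)  # let the worker start
    task.cancel()  # same mechanism as action_pause_execution / on_unmount
    try:
        # Give the task a bounded window to finish its cleanup.
        await asyncio.wait_for(task, timeout=5.0)
    except (asyncio.TimeoutError, asyncio.CancelledError):
        pass  # expected: the task was cancelled


asyncio.run(main())
print(saved)
```

If the cleanup outlasts the timeout (as the on_unmount comment warns), state may not be fully saved; the bounded `wait_for` trades completeness for a prompt quit.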
@@ -17,6 +17,7 @@ Client-facing input:
import asyncio
import re
import threading
from pathlib import Path
from typing import Any

from textual.app import ComposeResult
@@ -69,13 +70,20 @@ class ChatRepl(Vertical):
    }
    """

    def __init__(self, runtime: AgentRuntime):
    def __init__(
        self,
        runtime: AgentRuntime,
        resume_session: str | None = None,
        resume_checkpoint: str | None = None,
    ):
        super().__init__()
        self.runtime = runtime
        self._current_exec_id: str | None = None
        self._streaming_snapshot: str = ""
        self._waiting_for_input: bool = False
        self._input_node_id: str | None = None
        self._resume_session = resume_session
        self._resume_checkpoint = resume_checkpoint

        # Dedicated event loop for agent execution.
        # Keeps blocking runtime code (LLM calls, MCP tools) off
@@ -121,10 +129,589 @@ class ChatRepl(Vertical):
        if was_at_bottom:
            history.scroll_end(animate=False)

    async def _handle_command(self, command: str) -> None:
        """Handle slash commands for session and checkpoint operations."""
        parts = command.split(maxsplit=2)
        cmd = parts[0].lower()

        if cmd == "/help":
            self._write_history("""[bold cyan]Available Commands:[/bold cyan]
[bold]/sessions[/bold] - List all sessions for this agent
[bold]/sessions[/bold] <session_id> - Show session details and checkpoints
[bold]/resume[/bold] - Resume latest paused/failed session
[bold]/resume[/bold] <session_id> - Resume session from where it stopped
[bold]/recover[/bold] <session_id> <cp_id> - Recover from specific checkpoint
[bold]/pause[/bold] - Pause current execution (Ctrl+Z)
[bold]/help[/bold] - Show this help message

[dim]Examples:[/dim]
/sessions [dim]# List all sessions[/dim]
/sessions session_20260208_143022 [dim]# Show session details[/dim]
/resume [dim]# Resume latest session (from state)[/dim]
/resume session_20260208_143022 [dim]# Resume specific session (from state)[/dim]
/recover session_20260208_143022 cp_xxx [dim]# Recover from specific checkpoint[/dim]
/pause [dim]# Pause (or Ctrl+Z)[/dim]
""")
        elif cmd == "/sessions":
            session_id = parts[1].strip() if len(parts) > 1 else None
            await self._cmd_sessions(session_id)
        elif cmd == "/resume":
            # Resume from session state (not checkpoint-based)
            if len(parts) < 2:
                session_id = await self._find_latest_resumable_session()
                if not session_id:
                    self._write_history("[bold red]No resumable sessions found[/bold red]")
                    self._write_history(" Tip: Use [bold]/sessions[/bold] to see all sessions")
                    return
            else:
                session_id = parts[1].strip()
            await self._cmd_resume(session_id)
        elif cmd == "/recover":
            # Recover from specific checkpoint
            if len(parts) < 3:
                self._write_history(
                    "[bold red]Error:[/bold red] /recover requires session_id and checkpoint_id"
                )
                self._write_history(" Usage: [bold]/recover <session_id> <checkpoint_id>[/bold]")
                self._write_history(
                    " Tip: Use [bold]/sessions <session_id>[/bold] to see checkpoints"
                )
                return
            session_id = parts[1].strip()
            checkpoint_id = parts[2].strip()
            await self._cmd_recover(session_id, checkpoint_id)
        elif cmd == "/pause":
            await self._cmd_pause()
        else:
            self._write_history(
                f"[bold red]Unknown command:[/bold red] {cmd}\n"
                "Type [bold]/help[/bold] for available commands"
            )

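The dispatcher above splits with `maxsplit=2`, so `/recover` gets exactly three tokens: the command, the session id, and everything after the second space as the checkpoint argument. A quick standalone check (the command string is a hypothetical example):

```python
# Hypothetical command string; shows the token shape _handle_command relies on.
command = "/recover session_20260208_143022 cp_node_complete_intake_143030"
parts = command.split(maxsplit=2)
cmd = parts[0].lower()
print(cmd)       # /recover
print(parts[1])  # session_20260208_143022
print(parts[2])  # cp_node_complete_intake_143030
```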
    async def _cmd_sessions(self, session_id: str | None) -> None:
        """List sessions or show details of a specific session."""
        try:
            # Get storage path from runtime
            storage_path = self.runtime._storage.base_path

            if session_id:
                # Show details of specific session including checkpoints
                await self._show_session_details(storage_path, session_id)
            else:
                # List all sessions
                await self._list_sessions(storage_path)
        except Exception as e:
            self._write_history(f"[bold red]Error:[/bold red] {e}")
            self._write_history(" Could not access session data")

    async def _find_latest_resumable_session(self) -> str | None:
        """Find the most recent paused or failed session."""
        try:
            storage_path = self.runtime._storage.base_path
            sessions_dir = storage_path / "sessions"

            if not sessions_dir.exists():
                return None

            # Get all sessions, most recent first
            session_dirs = sorted(
                [d for d in sessions_dir.iterdir() if d.is_dir()],
                key=lambda d: d.name,
                reverse=True,
            )

            # Find first paused, failed, or cancelled session
            import json

            for session_dir in session_dirs:
                state_file = session_dir / "state.json"
                if not state_file.exists():
                    continue

                with open(state_file) as f:
                    state = json.load(f)

                status = state.get("status", "").lower()

                # Check if resumable (any non-completed status)
                if status in ["paused", "failed", "cancelled", "active"]:
                    return session_dir.name

            return None
        except Exception:
            return None

    async def _list_sessions(self, storage_path: Path) -> None:
        """List all sessions for the agent."""
        self._write_history("[bold cyan]Available Sessions:[/bold cyan]")

        # Find all session directories
        sessions_dir = storage_path / "sessions"
        if not sessions_dir.exists():
            self._write_history("[dim]No sessions found.[/dim]")
            self._write_history(" Sessions will appear here after running the agent")
            return

        session_dirs = sorted(
            [d for d in sessions_dir.iterdir() if d.is_dir()],
            key=lambda d: d.name,
            reverse=True,  # Most recent first
        )

        if not session_dirs:
            self._write_history("[dim]No sessions found.[/dim]")
            return

        self._write_history(f"[dim]Found {len(session_dirs)} session(s)[/dim]\n")

        for session_dir in session_dirs[:10]:  # Show last 10 sessions
            session_id = session_dir.name
            state_file = session_dir / "state.json"

            if not state_file.exists():
                continue

            # Read session state
            try:
                import json

                with open(state_file) as f:
                    state = json.load(f)

                status = state.get("status", "unknown").upper()

                # Status with color
                if status == "COMPLETED":
                    status_colored = f"[green]{status}[/green]"
                elif status == "FAILED":
                    status_colored = f"[red]{status}[/red]"
                elif status == "PAUSED":
                    status_colored = f"[yellow]{status}[/yellow]"
                elif status == "CANCELLED":
                    status_colored = f"[dim yellow]{status}[/dim yellow]"
                else:
                    status_colored = f"[dim]{status}[/dim]"

                # Check for checkpoints
                checkpoint_dir = session_dir / "checkpoints"
                checkpoint_count = 0
                if checkpoint_dir.exists():
                    checkpoint_files = list(checkpoint_dir.glob("cp_*.json"))
                    checkpoint_count = len(checkpoint_files)

                # Session line
                self._write_history(f"📋 [bold]{session_id}[/bold]")
                self._write_history(f" Status: {status_colored} Checkpoints: {checkpoint_count}")

                if checkpoint_count > 0:
                    self._write_history(f" [dim]Resume: /resume {session_id}[/dim]")

                self._write_history("")  # Blank line

            except Exception as e:
                self._write_history(f" [dim red]Error reading: {e}[/dim red]")

    async def _show_session_details(self, storage_path: Path, session_id: str) -> None:
        """Show detailed information about a specific session."""
        self._write_history(f"[bold cyan]Session Details:[/bold cyan] {session_id}\n")

        session_dir = storage_path / "sessions" / session_id
        if not session_dir.exists():
            self._write_history("[bold red]Error:[/bold red] Session not found")
            self._write_history(f" Path: {session_dir}")
            self._write_history(" Tip: Use [bold]/sessions[/bold] to see available sessions")
            return

        state_file = session_dir / "state.json"
        if not state_file.exists():
            self._write_history("[bold red]Error:[/bold red] Session state not found")
            return

        try:
            import json

            with open(state_file) as f:
                state = json.load(f)

            # Basic info
            status = state.get("status", "unknown").upper()
            if status == "COMPLETED":
                status_colored = f"[green]{status}[/green]"
            elif status == "FAILED":
                status_colored = f"[red]{status}[/red]"
            elif status == "PAUSED":
                status_colored = f"[yellow]{status}[/yellow]"
            elif status == "CANCELLED":
                status_colored = f"[dim yellow]{status}[/dim yellow]"
            else:
                status_colored = status

            self._write_history(f"Status: {status_colored}")

            if "started_at" in state:
                self._write_history(f"Started: {state['started_at']}")
            if "completed_at" in state:
                self._write_history(f"Completed: {state['completed_at']}")

            # Execution path
            if "execution_path" in state and state["execution_path"]:
                self._write_history("\n[bold]Execution Path:[/bold]")
                for node_id in state["execution_path"]:
                    self._write_history(f" ✓ {node_id}")

            # Checkpoints
            checkpoint_dir = session_dir / "checkpoints"
            if checkpoint_dir.exists():
                checkpoint_files = sorted(checkpoint_dir.glob("cp_*.json"))
                if checkpoint_files:
                    self._write_history(
                        f"\n[bold]Available Checkpoints:[/bold] ({len(checkpoint_files)})"
                    )

                    # Load and show checkpoints
                    for i, cp_file in enumerate(checkpoint_files[-5:], 1):  # Last 5
                        try:
                            with open(cp_file) as f:
                                cp_data = json.load(f)

                            cp_id = cp_data.get("checkpoint_id", cp_file.stem)
                            cp_type = cp_data.get("checkpoint_type", "unknown")
                            current_node = cp_data.get("current_node", "unknown")
                            is_clean = cp_data.get("is_clean", False)

                            clean_marker = "✓" if is_clean else "⚠"
                            self._write_history(f" {i}. {clean_marker} [cyan]{cp_id}[/cyan]")
                            self._write_history(f" Type: {cp_type}, Node: {current_node}")
                        except Exception:
                            pass

            # Quick actions
            if checkpoint_dir.exists() and list(checkpoint_dir.glob("cp_*.json")):
                self._write_history("\n[bold]Quick Actions:[/bold]")
                self._write_history(
                    f" [dim]/resume {session_id}[/dim] - Resume from latest checkpoint"
                )

        except Exception as e:
            self._write_history(f"[bold red]Error:[/bold red] {e}")
            import traceback

            self._write_history(f"[dim]{traceback.format_exc()}[/dim]")

    async def _cmd_resume(self, session_id: str) -> None:
        """Resume a session from its last state (session state, not checkpoint)."""
        try:
            storage_path = self.runtime._storage.base_path
            session_dir = storage_path / "sessions" / session_id

            # Verify session exists
            if not session_dir.exists():
                self._write_history(f"[bold red]Error:[/bold red] Session not found: {session_id}")
                self._write_history(" Use [bold]/sessions[/bold] to see available sessions")
                return

            # Load session state
            state_file = session_dir / "state.json"
            if not state_file.exists():
                self._write_history("[bold red]Error:[/bold red] Session state not found")
                return

            import json

            with open(state_file) as f:
                state = json.load(f)

            # Resume from session state (not checkpoint)
            progress = state.get("progress", {})
            paused_at = progress.get("paused_at") or progress.get("resume_from")

            if paused_at:
                # Has paused_at - resume from there
                resume_session_state = {
                    "paused_at": paused_at,
                    "memory": state.get("memory", {}),
                    "execution_path": progress.get("path", []),
                    "node_visit_counts": progress.get("node_visit_counts", {}),
                }
                resume_info = f"From node: [cyan]{paused_at}[/cyan]"
            else:
                # No paused_at - just retry with same input
                resume_session_state = {}
                resume_info = "Retrying with same input"

            # Display resume info
            self._write_history(f"[bold cyan]🔄 Resuming session[/bold cyan] {session_id}")
            self._write_history(f" {resume_info}")
            if paused_at:
                self._write_history(" [dim](Using session state, not checkpoint)[/dim]")

            # Check if already executing
            if self._current_exec_id is not None:
                self._write_history(
                    "[bold yellow]Warning:[/bold yellow] An execution is already running"
                )
                self._write_history(" Wait for it to complete or use /pause first")
                return

            # Get original input data from session state
            input_data = state.get("input_data", {})

            # Show indicator
            indicator = self.query_one("#processing-indicator", Label)
            indicator.update("Resuming from session state...")
            indicator.display = True

            # Update placeholder
            chat_input = self.query_one("#chat-input", Input)
            chat_input.placeholder = "Commands: /pause, /sessions (agent resuming...)"

            # Trigger execution with resume state
            try:
                entry_points = self.runtime.get_entry_points()
                if not entry_points:
                    self._write_history("[bold red]Error:[/bold red] No entry points available")
                    return

                # Submit execution with resume state and original input data
                future = asyncio.run_coroutine_threadsafe(
                    self.runtime.trigger(
                        entry_points[0].id,
                        input_data=input_data,
                        session_state=resume_session_state,
                    ),
                    self._agent_loop,
                )
                exec_id = await asyncio.wrap_future(future)
                self._current_exec_id = exec_id

                self._write_history(
                    f"[green]✓[/green] Resume started (execution: {exec_id[:12]}...)"
                )
                self._write_history(" Agent is continuing from where it stopped...")

            except Exception as e:
                self._write_history(f"[bold red]Error starting resume:[/bold red] {e}")
                indicator.display = False
                chat_input.placeholder = "Enter input for agent..."

        except Exception as e:
            self._write_history(f"[bold red]Error:[/bold red] {e}")
            import traceback

            self._write_history(f"[dim]{traceback.format_exc()}[/dim]")

    async def _cmd_recover(self, session_id: str, checkpoint_id: str) -> None:
        """Recover a session from a specific checkpoint (time-travel debugging)."""
        try:
            storage_path = self.runtime._storage.base_path
            session_dir = storage_path / "sessions" / session_id

            # Verify session exists
            if not session_dir.exists():
                self._write_history(f"[bold red]Error:[/bold red] Session not found: {session_id}")
                self._write_history(" Use [bold]/sessions[/bold] to see available sessions")
                return

            # Verify checkpoint exists
            checkpoint_file = session_dir / "checkpoints" / f"{checkpoint_id}.json"
            if not checkpoint_file.exists():
                self._write_history(
                    f"[bold red]Error:[/bold red] Checkpoint not found: {checkpoint_id}"
                )
                self._write_history(
                    f" Use [bold]/sessions {session_id}[/bold] to see available checkpoints"
                )
                return

            # Display recover info
            self._write_history(f"[bold cyan]⏪ Recovering session[/bold cyan] {session_id}")
            self._write_history(f" From checkpoint: [cyan]{checkpoint_id}[/cyan]")
            self._write_history(
                " [dim](Checkpoint-based recovery for time-travel debugging)[/dim]"
            )

            # Check if already executing
            if self._current_exec_id is not None:
                self._write_history(
                    "[bold yellow]Warning:[/bold yellow] An execution is already running"
                )
                self._write_history(" Wait for it to complete or use /pause first")
                return

            # Create session_state for checkpoint recovery
            recover_session_state = {
                "resume_from_checkpoint": checkpoint_id,
            }

            # Show indicator
            indicator = self.query_one("#processing-indicator", Label)
            indicator.update("Recovering from checkpoint...")
            indicator.display = True

            # Update placeholder
            chat_input = self.query_one("#chat-input", Input)
            chat_input.placeholder = "Commands: /pause, /sessions (agent recovering...)"

            # Trigger execution with checkpoint recovery
            try:
                entry_points = self.runtime.get_entry_points()
                if not entry_points:
                    self._write_history("[bold red]Error:[/bold red] No entry points available")
                    return

                # Submit execution with checkpoint recovery state
                future = asyncio.run_coroutine_threadsafe(
                    self.runtime.trigger(
                        entry_points[0].id,
                        input_data={},
                        session_state=recover_session_state,
                    ),
                    self._agent_loop,
                )
                exec_id = await asyncio.wrap_future(future)
                self._current_exec_id = exec_id

                self._write_history(
                    f"[green]✓[/green] Recovery started (execution: {exec_id[:12]}...)"
                )
                self._write_history(" Agent is continuing from checkpoint...")

            except Exception as e:
                self._write_history(f"[bold red]Error starting recovery:[/bold red] {e}")
                indicator.display = False
                chat_input.placeholder = "Enter input for agent..."

        except Exception as e:
            self._write_history(f"[bold red]Error:[/bold red] {e}")
            import traceback

            self._write_history(f"[dim]{traceback.format_exc()}[/dim]")

    async def _cmd_pause(self) -> None:
        """Immediately pause execution by cancelling task (same as Ctrl+Z)."""
        # Check if there's a current execution
        if not self._current_exec_id:
            self._write_history("[bold yellow]No active execution to pause[/bold yellow]")
            self._write_history(" Start an execution first, then use /pause during execution")
            return

        # Find and cancel the execution task - executor will catch and save state
        task_cancelled = False
        for stream in self.runtime._streams.values():
            exec_id = self._current_exec_id
            task = stream._execution_tasks.get(exec_id)
            if task and not task.done():
                task.cancel()
                task_cancelled = True
                self._write_history("[bold green]⏸ Execution paused - state saved[/bold green]")
                self._write_history(" Resume later with: [bold]/resume[/bold]")
                break

        if not task_cancelled:
            self._write_history("[bold yellow]Execution already completed[/bold yellow]")

    def on_mount(self) -> None:
        """Add welcome message when widget mounts."""
        """Add welcome message and check for resumable sessions."""
        history = self.query_one("#chat-history", RichLog)
        history.write("[bold cyan]Chat REPL Ready[/bold cyan] — Type your input below\n")
        history.write(
            "[bold cyan]Chat REPL Ready[/bold cyan] — "
            "Type your input or use [bold]/help[/bold] for commands\n"
        )

        # Auto-trigger resume/recover if CLI args provided
        if self._resume_session:
            if self._resume_checkpoint:
                # Use /recover for checkpoint-based recovery
                history.write(
                    "\n[bold cyan]🔄 Auto-recovering from checkpoint "
                    "(--resume-session + --checkpoint)[/bold cyan]"
                )
                self.call_later(self._cmd_recover, self._resume_session, self._resume_checkpoint)
            else:
                # Use /resume for session state resume
                history.write(
                    "\n[bold cyan]🔄 Auto-resuming session (--resume-session)[/bold cyan]"
                )
                self.call_later(self._cmd_resume, self._resume_session)
            return  # Skip normal startup messages

        # Check for resumable sessions
        self._check_and_show_resumable_sessions()

        history.write(
            "[dim]Quick start: /sessions to see previous sessions, "
            "/pause to pause execution[/dim]\n"
        )

    def _check_and_show_resumable_sessions(self) -> None:
        """Check for non-terminated sessions and prompt user."""
        try:
            storage_path = self.runtime._storage.base_path
            sessions_dir = storage_path / "sessions"

            if not sessions_dir.exists():
                return

            # Find non-terminated sessions (paused, failed, cancelled, active)
            resumable = []
            session_dirs = sorted(
                [d for d in sessions_dir.iterdir() if d.is_dir()],
                key=lambda d: d.name,
                reverse=True,  # Most recent first
            )

            import json

            for session_dir in session_dirs[:5]:  # Check last 5 sessions
                state_file = session_dir / "state.json"
                if not state_file.exists():
                    continue

                try:
                    with open(state_file) as f:
                        state = json.load(f)

                    status = state.get("status", "").lower()
                    # Non-terminated statuses
                    if status in ["paused", "failed", "cancelled", "active"]:
                        resumable.append(
                            {
                                "session_id": session_dir.name,
                                "status": status.upper(),
                            }
                        )
                except Exception:
                    continue

            if resumable:
                self._write_history("\n[bold yellow]⚠ Non-terminated sessions found:[/bold yellow]")
                for i, session in enumerate(resumable[:3], 1):  # Show top 3
                    status = session["status"]
                    session_id = session["session_id"]

                    # Color code status
                    if status == "PAUSED":
                        status_colored = f"[yellow]{status}[/yellow]"
                    elif status == "FAILED":
                        status_colored = f"[red]{status}[/red]"
                    elif status == "CANCELLED":
                        status_colored = f"[dim yellow]{status}[/dim yellow]"
                    else:
                        status_colored = f"[dim]{status}[/dim]"

                    self._write_history(f" {i}. {session_id[:32]}... [{status_colored}]")

                self._write_history("\n[bold cyan]What would you like to do?[/bold cyan]")
                self._write_history(" • Type [bold]/resume[/bold] to continue the latest session")
                self._write_history(
                    f" • Type [bold]/resume {resumable[0]['session_id']}[/bold] "
                    "for specific session"
                )
                self._write_history(" • Or just type your input to start a new session\n")

        except Exception:
            # Silently fail - don't block TUI startup
            pass

    async def on_input_submitted(self, message: Input.Submitted) -> None:
        """Handle input submission — either start new execution or inject input."""
@@ -132,15 +719,21 @@ class ChatRepl(Vertical):
        if not user_input:
            return

        # Handle commands (starting with /) - ALWAYS process commands first
        # Commands work during execution, during client-facing input, anytime
        if user_input.startswith("/"):
            await self._handle_command(user_input)
            message.input.value = ""
            return

        # Client-facing input: route to the waiting node
        if self._waiting_for_input and self._input_node_id:
            self._write_history(f"[bold green]You:[/bold green] {user_input}")
            message.input.value = ""

            # Disable input while agent processes the response
            # Keep input enabled for commands (but change placeholder)
            chat_input = self.query_one("#chat-input", Input)
            chat_input.disabled = True
            chat_input.placeholder = "Enter input for agent..."
            chat_input.placeholder = "Commands: /pause, /sessions (agent processing...)"
            self._waiting_for_input = False

            indicator = self.query_one("#processing-indicator", Label)
@@ -193,9 +786,9 @@ class ChatRepl(Vertical):
        indicator.update("Thinking...")
        indicator.display = True

        # Disable input while the agent is working
        # Keep input enabled for commands during execution
        chat_input = self.query_one("#chat-input", Input)
        chat_input.disabled = True
        chat_input.placeholder = "Commands available: /pause, /sessions, /help"

        # Submit execution to the dedicated agent loop so blocking
        # runtime code (LLM, MCP tools) never touches Textual's loop.

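The hand-off the comment above describes — submitting work to a dedicated agent loop and awaiting it from the UI loop — can be sketched standalone with `asyncio.run_coroutine_threadsafe` and `asyncio.wrap_future` (the `trigger` coroutine and its return value are illustrative, not the framework's actual API):

```python
import asyncio
import threading

# Dedicated loop on a background thread, mirroring ChatRepl's agent loop.
agent_loop = asyncio.new_event_loop()
threading.Thread(target=agent_loop.run_forever, daemon=True).start()


async def trigger(entry_point: str) -> str:
    # Stand-in for AgentRuntime.trigger; returns an execution id.
    await asyncio.sleep(0)
    return f"exec-{entry_point}"


async def ui_side() -> str:
    # Submit to the agent loop from the UI loop, then await the
    # concurrent.futures.Future without blocking the UI loop.
    future = asyncio.run_coroutine_threadsafe(trigger("intake"), agent_loop)
    return await asyncio.wrap_future(future)


exec_id = asyncio.run(ui_side())
agent_loop.call_soon_threadsafe(agent_loop.stop)
print(exec_id)  # exec-intake
```

This keeps blocking runtime work off the UI loop while still letting UI-side coroutines await results naturally.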
@@ -1,6 +1,6 @@
[project]
name = "framework"
version = "0.1.0"
version = "0.4.2"
description = "Goal-driven agent runtime with Builder-friendly observability"
readme = "README.md"
requires-python = ">=3.11"

@@ -0,0 +1,344 @@
"""
Regression tests for conditional edge direct key access (Issue #3599).

Verifies that node outputs are written to memory before edge evaluation,
enabling direct key access in conditional expressions (e.g., 'score > 80')
instead of requiring output['score'] > 80 syntax.
"""

import pytest

from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.executor import GraphExecutor
from framework.graph.goal import Goal
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
from framework.runtime.core import Runtime


class SimpleRuntime(Runtime):
    """Minimal runtime for testing."""

    def start_run(self, **kwargs):
        return "test-run"

    def end_run(self, **kwargs):
        pass

    def report_problem(self, **kwargs):
        pass

    def decide(self, **kwargs):
        return "test-decision"

    def record_outcome(self, **kwargs):
        pass

    def set_node(self, **kwargs):
        pass


class ScoreNode(NodeProtocol):
    """Node that outputs a score value."""

    async def execute(self, ctx: NodeContext) -> NodeResult:
        return NodeResult(success=True, output={"score": 85})


class HighScoreNode(NodeProtocol):
    """Consumer node for high scores."""

    async def execute(self, ctx: NodeContext) -> NodeResult:
        return NodeResult(success=True, output={"result": "high_score_path"})


class MultiKeyNode(NodeProtocol):
    """Node that outputs multiple keys."""

    async def execute(self, ctx: NodeContext) -> NodeResult:
        return NodeResult(success=True, output={"x": 100, "y": 50})


class ConsumerNode(NodeProtocol):
    """Generic consumer node."""

    async def execute(self, ctx: NodeContext) -> NodeResult:
        return NodeResult(success=True, output={"processed": True})


@pytest.mark.asyncio
async def test_direct_key_access_in_conditional_edge():
    """
    Verify direct key access works in conditional edges (e.g., 'score > 80').

    This is the core regression test for issue #3599. Before the fix,
    node outputs were only written to memory during input mapping (after
    edge evaluation), causing NameError when edges tried to access keys directly.
    """
    goal = Goal(
        id="test-direct-key",
        name="Test Direct Key Access",
        description="Test that direct key access works in conditional edges",
    )

    nodes = [
        NodeSpec(
            id="score_node",
            name="ScoreNode",
            description="Outputs a score",
            node_type="function",
            output_keys=["score"],
        ),
        NodeSpec(
            id="high_score_node",
            name="HighScoreNode",
            description="Handles high scores",
            node_type="function",
            input_keys=["score"],
            output_keys=["result"],
        ),
    ]

    # Edge with DIRECT key access: 'score > 80' (not 'output["score"] > 80')
    edges = [
        EdgeSpec(
            id="score_to_high",
            source="score_node",
            target="high_score_node",
            condition=EdgeCondition.CONDITIONAL,
            condition_expr="score > 80",  # Direct key access
        )
    ]

    graph = GraphSpec(
        id="test-graph",
        goal_id="test-direct-key",
        entry_node="score_node",
        nodes=nodes,
        edges=edges,
        terminal_nodes=["high_score_node"],
    )

    runtime = SimpleRuntime(storage_path="/tmp/test")
    executor = GraphExecutor(runtime=runtime)
    executor.register_node("score_node", ScoreNode())
    executor.register_node("high_score_node", HighScoreNode())

    result = await executor.execute(graph, goal, {})

    # Verify the edge was followed (high_score_node executed)
    assert result.success, "Execution should succeed"
    assert "high_score_node" in result.path, (
        f"Expected high_score_node in path. "
        f"Condition 'score > 80' should evaluate to True (score=85). "
        f"Path: {result.path}"
    )


@pytest.mark.asyncio
|
||||
async def test_backward_compatibility_output_syntax():
|
||||
"""
|
||||
Verify backward compatibility: output['key'] syntax still works.
|
||||
|
||||
The fix should not break existing code that uses the explicit
|
||||
output dictionary syntax in conditional expressions.
|
||||
"""
|
||||
goal = Goal(
|
||||
id="test-backward-compat",
|
||||
name="Test Backward Compatibility",
|
||||
description="Test that output['key'] syntax still works",
|
||||
)
|
||||
|
||||
nodes = [
|
||||
NodeSpec(
|
||||
id="score_node",
|
||||
name="ScoreNode",
|
||||
description="Outputs a score",
|
||||
node_type="function",
|
||||
output_keys=["score"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="consumer_node",
|
||||
name="ConsumerNode",
|
||||
description="Consumer",
|
||||
node_type="function",
|
||||
input_keys=["score"],
|
||||
output_keys=["processed"],
|
||||
),
|
||||
]
|
||||
|
||||
# Edge with OLD syntax: output['score'] > 80
|
||||
edges = [
|
||||
EdgeSpec(
|
||||
id="score_to_consumer",
|
||||
source="score_node",
|
||||
target="consumer_node",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="output['score'] > 80", # Old explicit syntax
|
||||
)
|
||||
]
|
||||
|
||||
graph = GraphSpec(
|
||||
id="test-graph-compat",
|
||||
goal_id="test-backward-compat",
|
||||
entry_node="score_node",
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
terminal_nodes=["consumer_node"],
|
||||
)
|
||||
|
||||
runtime = SimpleRuntime(storage_path="/tmp/test")
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
executor.register_node("score_node", ScoreNode())
|
||||
executor.register_node("consumer_node", ConsumerNode())
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
# Verify backward compatibility maintained
|
||||
assert result.success, "Execution should succeed"
|
||||
assert "consumer_node" in result.path, (
|
||||
f"Expected consumer_node in path. "
|
||||
f"Old syntax output['score'] > 80 should still work. "
|
||||
f"Path: {result.path}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_keys_in_expression():
|
||||
"""
|
||||
Verify multiple direct keys work in complex expressions.
|
||||
|
||||
Tests that expressions like 'x > y and y < 100' work correctly
|
||||
when both x and y are written to memory before edge evaluation.
|
||||
"""
|
||||
goal = Goal(
|
||||
id="test-multi-key",
|
||||
name="Test Multiple Keys",
|
||||
description="Test multiple keys in conditional expression",
|
||||
)
|
||||
|
||||
nodes = [
|
||||
NodeSpec(
|
||||
id="multi_key_node",
|
||||
name="MultiKeyNode",
|
||||
description="Outputs multiple keys",
|
||||
node_type="function",
|
||||
output_keys=["x", "y"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="consumer_node",
|
||||
name="ConsumerNode",
|
||||
description="Consumer",
|
||||
node_type="function",
|
||||
input_keys=["x", "y"],
|
||||
output_keys=["processed"],
|
||||
),
|
||||
]
|
||||
|
||||
# Complex expression with multiple direct keys
|
||||
edges = [
|
||||
EdgeSpec(
|
||||
id="multi_to_consumer",
|
||||
source="multi_key_node",
|
||||
target="consumer_node",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="x > y and y < 100", # Multiple keys
|
||||
)
|
||||
]
|
||||
|
||||
graph = GraphSpec(
|
||||
id="test-graph-multi",
|
||||
goal_id="test-multi-key",
|
||||
entry_node="multi_key_node",
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
terminal_nodes=["consumer_node"],
|
||||
)
|
||||
|
||||
runtime = SimpleRuntime(storage_path="/tmp/test")
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
executor.register_node("multi_key_node", MultiKeyNode())
|
||||
executor.register_node("consumer_node", ConsumerNode())
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
# Verify multiple keys work correctly
|
||||
assert result.success, "Execution should succeed"
|
||||
assert "consumer_node" in result.path, (
|
||||
f"Expected consumer_node in path. "
|
||||
f"Condition 'x > y and y < 100' should be True (x=100, y=50). "
|
||||
f"Path: {result.path}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negative_case_condition_false():
|
||||
"""
|
||||
Verify conditions correctly evaluate to False when not met.
|
||||
|
||||
Tests that when a condition fails, the edge is NOT followed
|
||||
and execution doesn't proceed to the target node.
|
||||
"""
|
||||
goal = Goal(
|
||||
id="test-negative",
|
||||
name="Test Negative Case",
|
||||
description="Test condition evaluates to False correctly",
|
||||
)
|
||||
|
||||
class LowScoreNode(NodeProtocol):
|
||||
"""Node that outputs a LOW score."""
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
return NodeResult(success=True, output={"score": 30})
|
||||
|
||||
nodes = [
|
||||
NodeSpec(
|
||||
id="low_score_node",
|
||||
name="LowScoreNode",
|
||||
description="Outputs low score",
|
||||
node_type="function",
|
||||
output_keys=["score"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="high_score_handler",
|
||||
name="HighScoreHandler",
|
||||
description="Should NOT execute",
|
||||
node_type="function",
|
||||
input_keys=["score"],
|
||||
output_keys=["result"],
|
||||
),
|
||||
]
|
||||
|
||||
# Condition should be FALSE (30 is not > 80)
|
||||
edges = [
|
||||
EdgeSpec(
|
||||
id="low_to_high",
|
||||
source="low_score_node",
|
||||
target="high_score_handler",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="score > 80", # Should be False
|
||||
)
|
||||
]
|
||||
|
||||
graph = GraphSpec(
|
||||
id="test-graph-negative",
|
||||
goal_id="test-negative",
|
||||
entry_node="low_score_node",
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
terminal_nodes=["high_score_handler"],
|
||||
)
|
||||
|
||||
runtime = SimpleRuntime(storage_path="/tmp/test")
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
executor.register_node("low_score_node", LowScoreNode())
|
||||
executor.register_node("high_score_handler", HighScoreNode())
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
# Verify condition correctly evaluated to False
|
||||
assert result.success, "Execution should succeed"
|
||||
assert "high_score_handler" not in result.path, (
|
||||
f"high_score_handler should NOT be in path. "
|
||||
f"Condition 'score > 80' should be False (score=30). "
|
||||
f"Path: {result.path}"
|
||||
)
|
||||
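The mechanism these tests exercise — node outputs written to graph memory before edge evaluation, then the condition expression resolved against that memory — can be sketched as follows. This is a hypothetical `evaluate_condition` helper for illustration, not the executor's actual API; exposing memory both as direct names and under `output` is one way to support both syntaxes at once:

```python
def evaluate_condition(expr: str, memory: dict) -> bool:
    """Evaluate an edge condition against graph memory.

    Direct keys ('score > 80') resolve as bare names from memory;
    the legacy 'output[...]' syntax keeps working because the same
    memory dict is also exposed under the name 'output'.
    """
    namespace = {"output": memory, **memory}
    # Empty __builtins__ so the expression cannot call arbitrary code
    return bool(eval(expr, {"__builtins__": {}}, namespace))


memory = {"score": 85}
evaluate_condition("score > 80", memory)            # direct key access
evaluate_condition("output['score'] > 80", memory)  # legacy explicit syntax
```

Before the fix described in the first test's docstring, the bare name `score` would be absent from the namespace at edge-evaluation time, producing exactly the `NameError` the regression test guards against.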
@@ -93,12 +93,12 @@ hive/
│   └── pyproject.toml       # Package metadata
│
├── tools/                   # MCP Tools Package
│   ├── mcp_server.py        # MCP server entry point
│   └── src/aden_tools/      # Tools for agent capabilities
│       ├── tools/           # Individual tool implementations
│       │   ├── web_search_tool/
│       │   ├── web_scrape_tool/
│       │   └── file_system_toolkits/
│       └── mcp_server.py    # HTTP MCP server
│   └── tools/               # Individual tool implementations
│       ├── web_search_tool/
│       ├── web_scrape_tool/
│       └── file_system_toolkits/
│
├── exports/                 # Agent Packages (user-generated, not in repo)
│   └── your_agent/          # Your agents created via /hive

@@ -0,0 +1,22 @@
# Deep Research Agent

A template agent designed to perform comprehensive research on a specific topic and generate a structured report.

## Usage

Run the agent using the following command:

### Linux / Mac
```bash
PYTHONPATH=core:examples/templates python -m deep_research_agent run --mock --topic "Artificial Intelligence"
```

### Windows
```powershell
$env:PYTHONPATH="core;examples\templates"
python -m deep_research_agent run --mock --topic "Artificial Intelligence"
```

## Options

- `-t, --topic`: The research topic (required).
- `--mock`: Run without calling real LLM APIs (simulated execution).
- `--help`: Show all available options.
@@ -36,6 +36,8 @@ Credential categories:
- llm.py: LLM provider credentials (anthropic, openai, etc.)
- search.py: Search tool credentials (brave_search, google_search, etc.)
- email.py: Email provider credentials (resend, google/gmail)
- apollo.py: Apollo.io API credentials
- airtable.py: Airtable bases and records credentials
- github.py: GitHub API credentials
- hubspot.py: HubSpot CRM credentials
- slack.py: Slack workspace credentials
@@ -49,6 +51,7 @@ To add a new credential:
3. If new category, import and merge it in this __init__.py
"""

from .apollo import APOLLO_CREDENTIALS
from .airtable import AIRTABLE_CREDENTIALS
from .base import CredentialError, CredentialSpec
from .browser import get_aden_auth_url, get_aden_setup_url, open_browser
@@ -72,6 +75,7 @@ CREDENTIAL_SPECS = {
    **LLM_CREDENTIALS,
    **SEARCH_CREDENTIALS,
    **EMAIL_CREDENTIALS,
    **APOLLO_CREDENTIALS,
    **GITHUB_CREDENTIALS,
    **HUBSPOT_CREDENTIALS,
    **SLACK_CREDENTIALS,
@@ -106,5 +110,6 @@ __all__ = [
    "GITHUB_CREDENTIALS",
    "HUBSPOT_CREDENTIALS",
    "SLACK_CREDENTIALS",
    "APOLLO_CREDENTIALS",
    "AIRTABLE_CREDENTIALS",
]

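Since each credential category is merged into `CREDENTIAL_SPECS` with `**` unpacking, key collisions are resolved silently: the later unpacking wins. A minimal illustration with toy specs (plain strings standing in for the real `CredentialSpec` objects):

```python
# Toy stand-ins for the real category dicts
APOLLO_CREDENTIALS = {"apollo": "apollo-spec"}
AIRTABLE_CREDENTIALS = {"airtable": "airtable-spec"}

CREDENTIAL_SPECS = {
    **APOLLO_CREDENTIALS,
    **AIRTABLE_CREDENTIALS,
}

# On a duplicate credential id, the later unpacking silently wins,
# so category modules must keep their ids unique.
clash = {**{"apollo": "old"}, **{"apollo": "new"}}
```

This is why step 3 above ("import and merge it in this __init__.py") is safe to do mechanically, provided each new category introduces fresh credential ids.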
@@ -0,0 +1,43 @@
"""
Apollo.io tool credentials.

Contains credentials for Apollo.io API integration.
"""

from .base import CredentialSpec

APOLLO_CREDENTIALS = {
    "apollo": CredentialSpec(
        env_var="APOLLO_API_KEY",
        tools=[
            "apollo_enrich_person",
            "apollo_enrich_company",
            "apollo_search_people",
            "apollo_search_companies",
        ],
        required=True,
        startup_required=False,
        help_url="https://apolloio.github.io/apollo-api-docs/",
        description="Apollo.io API key for contact and company data enrichment",
        # Auth method support
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""To get an Apollo.io API key:
1. Sign up or log in at https://app.apollo.io/
2. Go to Settings > Integrations > API
3. Click "Connect" to generate your API key
4. Copy the API key

Note: Apollo uses export credits for enrichment:
- Free plan: 10 credits/month
- Basic ($49/user/mo): 1,000 credits/month
- Professional ($79/user/mo): 2,000 credits/month
- Overage: $0.20/credit""",
        # Health check configuration
        health_check_endpoint="https://api.apollo.io/v1/auth/health",
        health_check_method="GET",
        # Credential store mapping
        credential_id="apollo",
        credential_key="api_key",
    ),
}
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
    from aden_tools.credentials import CredentialStoreAdapter

# Import register_tools from each tool module
from .apollo_tool import register_tools as register_apollo
from .airtable_tool import register_tools as register_airtable
from .csv_tool import register_tools as register_csv
from .email_tool import register_tools as register_email
@@ -77,6 +78,7 @@ def register_all_tools(
    # email supports multiple providers (Resend) with auto-detection
    register_email(mcp, credentials=credentials)
    register_hubspot(mcp, credentials=credentials)
    register_apollo(mcp, credentials=credentials)
    register_slack(mcp, credentials=credentials)
    register_airtable(mcp, credentials=credentials)

@@ -114,6 +116,10 @@ def register_all_tools(
        "csv_append",
        "csv_info",
        "csv_sql",
        "apollo_enrich_person",
        "apollo_enrich_company",
        "apollo_search_people",
        "apollo_search_companies",
        "github_list_repos",
        "github_get_repo",
        "github_search_repos",

@@ -0,0 +1,42 @@
# Apollo.io Tool

B2B contact and company data enrichment via the Apollo.io API.

## Tools

| Tool | Description |
|------|-------------|
| `apollo_enrich_person` | Enrich a contact by email, LinkedIn URL, or name+domain |
| `apollo_enrich_company` | Enrich a company by domain |
| `apollo_search_people` | Search contacts with filters (titles, seniorities, locations, etc.) |
| `apollo_search_companies` | Search companies with filters (industries, employee counts, etc.) |

## Authentication

Requires an Apollo.io API key passed via the `APOLLO_API_KEY` environment variable or the credential store.

**How to get an API key:**

1. Sign up or log in at https://app.apollo.io/
2. Go to Settings > Integrations > API
3. Click "Connect" to generate your API key
4. Copy the API key

## Pricing

| Plan | Price | Export Credits/month |
|------|-------|---------------------|
| Free | $0 | 10 |
| Basic | $49/user/mo | 1,000 |
| Professional | $79/user/mo | 2,000 |
| Overage | - | $0.20/credit |

## Error Handling

Returns error dicts for common failure modes:

- `401` - Invalid API key
- `403` - Insufficient credits or permissions
- `404` - Resource not found
- `422` - Invalid parameters
- `429` - Rate limit exceeded
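Because the tools return error dicts instead of raising, callers branch on the presence of an `"error"` key (and on `match_found` for empty enrichment results). A small sketch of caller-side handling, using a hypothetical `handle_result` helper:

```python
def handle_result(result: dict) -> str:
    """Classify a tool response: error dict, empty match, or usable data."""
    if "error" in result:
        # 401/403/404/422/429 all arrive in this shape
        return f"failed: {result['error']}"
    if result.get("match_found") is False:
        # Enrichment succeeded but Apollo had no record
        return "no match"
    return "ok"


handle_result({"error": "Apollo rate limit exceeded. Try again later."})
handle_result({"match_found": False, "message": "No matching person found"})
handle_result({"match_found": True, "person": {"name": "Jane"}})
```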
@@ -0,0 +1,13 @@
"""
Apollo.io Tool - Contact and company data enrichment via Apollo API.

Supports API key authentication for:
- Person enrichment by email or LinkedIn
- Company enrichment by domain
- People search with filters
- Company search with filters
"""

from .apollo_tool import register_tools

__all__ = ["register_tools"]
@@ -0,0 +1,581 @@
|
||||
"""
|
||||
Apollo.io Tool - Contact and company data enrichment via Apollo API.
|
||||
|
||||
Supports:
|
||||
- API key authentication (APOLLO_API_KEY)
|
||||
|
||||
Use Cases:
|
||||
- Enrich contacts by email or LinkedIn URL
|
||||
- Enrich companies by domain
|
||||
- Search for people by titles, seniorities, locations
|
||||
- Search for companies by industries, employee counts, technologies
|
||||
|
||||
API Reference: https://apolloio.github.io/apollo-api-docs/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
APOLLO_API_BASE = "https://api.apollo.io/api/v1"
|
||||
|
||||
|
||||
class _ApolloClient:
|
||||
"""Internal client wrapping Apollo.io API calls."""
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
self._api_key = api_key
|
||||
|
||||
@property
|
||||
def _headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Api-Key": self._api_key,
|
||||
}
|
||||
|
||||
def _handle_response(self, response: httpx.Response) -> dict[str, Any]:
|
||||
"""Handle common HTTP error codes."""
|
||||
if response.status_code == 401:
|
||||
return {"error": "Invalid Apollo API key"}
|
||||
if response.status_code == 403:
|
||||
return {
|
||||
"error": "Insufficient credits or permissions. Check your Apollo plan.",
|
||||
"help": "Apollo uses export credits for enrichment. Visit https://app.apollo.io/#/settings/plans",
|
||||
}
|
||||
if response.status_code == 404:
|
||||
return {"error": "Resource not found"}
|
||||
if response.status_code == 422:
|
||||
try:
|
||||
detail = response.json().get("error", response.text)
|
||||
except Exception:
|
||||
detail = response.text
|
||||
return {"error": f"Invalid parameters: {detail}"}
|
||||
if response.status_code == 429:
|
||||
return {"error": "Apollo rate limit exceeded. Try again later."}
|
||||
if response.status_code >= 400:
|
||||
try:
|
||||
detail = response.json().get("error", response.text)
|
||||
except Exception:
|
||||
detail = response.text
|
||||
return {"error": f"Apollo API error (HTTP {response.status_code}): {detail}"}
|
||||
return response.json()
|
||||
|
||||
def enrich_person(
|
||||
self,
|
||||
email: str | None = None,
|
||||
linkedin_url: str | None = None,
|
||||
first_name: str | None = None,
|
||||
last_name: str | None = None,
|
||||
name: str | None = None,
|
||||
domain: str | None = None,
|
||||
reveal_personal_emails: bool = False,
|
||||
reveal_phone_number: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Enrich a person by email, LinkedIn URL, or name and domain."""
|
||||
body: dict[str, Any] = {
|
||||
"reveal_personal_emails": reveal_personal_emails,
|
||||
"reveal_phone_number": reveal_phone_number,
|
||||
}
|
||||
|
||||
if email:
|
||||
body["email"] = email
|
||||
if linkedin_url:
|
||||
body["linkedin_url"] = linkedin_url
|
||||
if first_name:
|
||||
body["first_name"] = first_name
|
||||
if last_name:
|
||||
body["last_name"] = last_name
|
||||
if name:
|
||||
body["name"] = name
|
||||
if domain:
|
||||
body["domain"] = domain
|
||||
|
||||
response = httpx.post(
|
||||
f"{APOLLO_API_BASE}/people/match",
|
||||
headers=self._headers,
|
||||
params=body if not email and not linkedin_url else None,
|
||||
json=body,
|
||||
timeout=30.0,
|
||||
)
|
||||
result = self._handle_response(response)
|
||||
|
||||
# Handle "not found" gracefully
|
||||
if "error" not in result and result.get("person") is None:
|
||||
return {"match_found": False, "message": "No matching person found"}
|
||||
|
||||
if "error" not in result:
|
||||
person = result.get("person", {})
|
||||
return {
|
||||
"match_found": True,
|
||||
"person": {
|
||||
"id": person.get("id"),
|
||||
"first_name": person.get("first_name"),
|
||||
"last_name": person.get("last_name"),
|
||||
"name": person.get("name"),
|
||||
"title": person.get("title"),
|
||||
"email": person.get("email"),
|
||||
"email_status": person.get("email_status"),
|
||||
"phone_numbers": person.get("phone_numbers", []),
|
||||
"linkedin_url": person.get("linkedin_url"),
|
||||
"twitter_url": person.get("twitter_url"),
|
||||
"city": person.get("city"),
|
||||
"state": person.get("state"),
|
||||
"country": person.get("country"),
|
||||
"organization": {
|
||||
"id": person.get("organization", {}).get("id"),
|
||||
"name": person.get("organization", {}).get("name"),
|
||||
"domain": person.get("organization", {}).get("primary_domain"),
|
||||
"industry": person.get("organization", {}).get("industry"),
|
||||
"employee_count": person.get("organization", {}).get(
|
||||
"estimated_num_employees"
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
return result
|
||||
|
||||
def enrich_company(self, domain: str) -> dict[str, Any]:
|
||||
"""Enrich a company by domain."""
|
||||
body: dict[str, Any] = {
|
||||
"domain": domain,
|
||||
}
|
||||
|
||||
response = httpx.post(
|
||||
f"{APOLLO_API_BASE}/organizations/enrich",
|
||||
headers=self._headers,
|
||||
json=body,
|
||||
timeout=30.0,
|
||||
)
|
||||
result = self._handle_response(response)
|
||||
|
||||
# Handle "not found" gracefully
|
||||
if "error" not in result and result.get("organization") is None:
|
||||
return {"match_found": False, "message": "No matching company found"}
|
||||
|
||||
if "error" not in result:
|
||||
org = result.get("organization", {})
|
||||
return {
|
||||
"match_found": True,
|
||||
"organization": {
|
||||
"id": org.get("id"),
|
||||
"name": org.get("name"),
|
||||
"domain": org.get("primary_domain"),
|
||||
"website_url": org.get("website_url"),
|
||||
"linkedin_url": org.get("linkedin_url"),
|
||||
"twitter_url": org.get("twitter_url"),
|
||||
"facebook_url": org.get("facebook_url"),
|
||||
"industry": org.get("industry"),
|
||||
"keywords": org.get("keywords", []),
|
||||
"employee_count": org.get("estimated_num_employees"),
|
||||
"employee_count_range": org.get("employee_count_range"),
|
||||
"annual_revenue": org.get("annual_revenue"),
|
||||
"annual_revenue_printed": org.get("annual_revenue_printed"),
|
||||
"total_funding": org.get("total_funding"),
|
||||
"total_funding_printed": org.get("total_funding_printed"),
|
||||
"latest_funding_round_date": org.get("latest_funding_round_date"),
|
||||
"latest_funding_stage": org.get("latest_funding_stage"),
|
||||
"founded_year": org.get("founded_year"),
|
||||
"phone": org.get("phone"),
|
||||
"city": org.get("city"),
|
||||
"state": org.get("state"),
|
||||
"country": org.get("country"),
|
||||
"street_address": org.get("street_address"),
|
||||
"technologies": org.get("technologies", []),
|
||||
"short_description": org.get("short_description"),
|
||||
},
|
||||
}
|
||||
return result
|
||||
|
||||
def search_people(
|
||||
self,
|
||||
titles: list[str] | None = None,
|
||||
seniorities: list[str] | None = None,
|
||||
locations: list[str] | None = None,
|
||||
company_sizes: list[str] | None = None,
|
||||
industries: list[str] | None = None,
|
||||
technologies: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
"""Search for people with filters."""
|
||||
body: dict[str, Any] = {
|
||||
"per_page": min(limit, 100),
|
||||
"page": 1,
|
||||
}
|
||||
|
||||
if titles:
|
||||
body["person_titles"] = titles
|
||||
if seniorities:
|
||||
body["person_seniorities"] = seniorities
|
||||
if locations:
|
||||
body["person_locations"] = locations
|
||||
if company_sizes:
|
||||
body["organization_num_employees_ranges"] = company_sizes
|
||||
if industries:
|
||||
body["organization_industry_tag_ids"] = industries
|
||||
if technologies:
|
||||
body["currently_using_any_of_technology_uids"] = technologies
|
||||
|
||||
response = httpx.post(
|
||||
f"{APOLLO_API_BASE}/mixed_people/search",
|
||||
headers=self._headers,
|
||||
json=body,
|
||||
timeout=30.0,
|
||||
)
|
||||
result = self._handle_response(response)
|
||||
|
||||
if "error" not in result:
|
||||
people = result.get("people", [])
|
||||
return {
|
||||
"total": result.get("pagination", {}).get("total_entries", len(people)),
|
||||
"page": result.get("pagination", {}).get("page", 1),
|
||||
"per_page": result.get("pagination", {}).get("per_page", limit),
|
||||
"results": [
|
||||
{
|
||||
"id": p.get("id"),
|
||||
"first_name": p.get("first_name"),
|
||||
"last_name": p.get("last_name"),
|
||||
"name": p.get("name"),
|
||||
"title": p.get("title"),
|
||||
"email": p.get("email"),
|
||||
"email_status": p.get("email_status"),
|
||||
"linkedin_url": p.get("linkedin_url"),
|
||||
"city": p.get("city"),
|
||||
"state": p.get("state"),
|
||||
"country": p.get("country"),
|
||||
"seniority": p.get("seniority"),
|
||||
"organization": {
|
||||
"id": p.get("organization", {}).get("id")
|
||||
if p.get("organization")
|
||||
else None,
|
||||
"name": p.get("organization", {}).get("name")
|
||||
if p.get("organization")
|
||||
else None,
|
||||
"domain": p.get("organization", {}).get("primary_domain")
|
||||
if p.get("organization")
|
||||
else None,
|
||||
},
|
||||
}
|
||||
for p in people
|
||||
],
|
||||
}
|
||||
return result
|
||||
|
||||
def search_companies(
|
||||
self,
|
||||
industries: list[str] | None = None,
|
||||
employee_counts: list[str] | None = None,
|
||||
locations: list[str] | None = None,
|
||||
technologies: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
"""Search for companies with filters."""
|
||||
body: dict[str, Any] = {
|
||||
"per_page": min(limit, 100),
|
||||
"page": 1,
|
||||
}
|
||||
|
||||
if industries:
|
||||
body["organization_industry_tag_ids"] = industries
|
||||
if employee_counts:
|
||||
body["organization_num_employees_ranges"] = employee_counts
|
||||
if locations:
|
||||
body["organization_locations"] = locations
|
||||
if technologies:
|
||||
body["currently_using_any_of_technology_uids"] = technologies
|
||||
|
||||
response = httpx.post(
|
||||
f"{APOLLO_API_BASE}/mixed_companies/search",
|
||||
headers=self._headers,
|
||||
json=body,
|
||||
timeout=30.0,
|
||||
)
|
||||
result = self._handle_response(response)
|
||||
|
||||
if "error" not in result:
|
||||
orgs = result.get("organizations", [])
|
||||
return {
|
||||
"total": result.get("pagination", {}).get("total_entries", len(orgs)),
|
||||
"page": result.get("pagination", {}).get("page", 1),
|
||||
"per_page": result.get("pagination", {}).get("per_page", limit),
|
||||
"results": [
|
||||
{
|
||||
"id": o.get("id"),
|
||||
"name": o.get("name"),
|
||||
"domain": o.get("primary_domain"),
|
||||
"website_url": o.get("website_url"),
|
||||
"linkedin_url": o.get("linkedin_url"),
|
||||
"industry": o.get("industry"),
|
||||
"employee_count": o.get("estimated_num_employees"),
|
||||
"employee_count_range": o.get("employee_count_range"),
|
||||
"annual_revenue_printed": o.get("annual_revenue_printed"),
|
||||
"city": o.get("city"),
|
||||
"state": o.get("state"),
|
||||
"country": o.get("country"),
|
||||
"short_description": o.get("short_description"),
|
||||
}
|
||||
for o in orgs
|
||||
],
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def register_tools(
|
||||
mcp: FastMCP,
|
||||
credentials: CredentialStoreAdapter | None = None,
|
||||
) -> None:
|
||||
"""Register Apollo.io data enrichment tools with the MCP server."""
|
||||
|
||||
def _get_api_key() -> str | None:
|
||||
"""Get Apollo API key from credential manager or environment."""
|
||||
if credentials is not None:
|
||||
api_key = credentials.get("apollo")
|
||||
# Defensive check: ensure we get a string, not a complex object
|
||||
if api_key is not None and not isinstance(api_key, str):
|
||||
raise TypeError(
|
||||
f"Expected string from credentials.get('apollo'), got {type(api_key).__name__}"
|
||||
)
|
||||
return api_key
|
||||
return os.getenv("APOLLO_API_KEY")
|
||||
|
||||
def _get_client() -> _ApolloClient | dict[str, str]:
|
||||
"""Get an Apollo client, or return an error dict if no credentials."""
|
||||
api_key = _get_api_key()
|
||||
if not api_key:
|
||||
return {
|
||||
"error": "Apollo credentials not configured",
|
||||
"help": (
|
||||
"Set APOLLO_API_KEY environment variable "
|
||||
"or configure via credential store. "
|
||||
"Get your API key at https://app.apollo.io/#/settings/integrations/api"
|
||||
),
|
||||
}
|
||||
return _ApolloClient(api_key)
|
||||
|
||||
# --- Person Enrichment ---
|
||||
|
||||
@mcp.tool()
|
||||
def apollo_enrich_person(
|
||||
email: str | None = None,
|
||||
linkedin_url: str | None = None,
|
||||
first_name: str | None = None,
|
||||
last_name: str | None = None,
|
||||
name: str | None = None,
|
||||
domain: str | None = None,
|
||||
reveal_personal_emails: bool = False,
|
||||
reveal_phone_number: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Enrich a person's information by email, LinkedIn URL, or name and domain.
|
||||
|
||||
Args:
|
||||
email: Person's email address
|
||||
linkedin_url: Person's LinkedIn profile URL
|
||||
first_name: Person's first name (use with last_name and domain)
|
||||
last_name: Person's last name (use with first_name and domain)
|
||||
name: Person's full name (use with domain)
|
||||
domain: Person's company domain (e.g., "acme.com")
|
||||
reveal_personal_emails: Whether to reveal personal email addresses (default: False)
|
||||
reveal_phone_number: Whether to reveal phone numbers (default: False)
|
||||
|
||||
Returns:
|
||||
Dict with person details including:
|
||||
- Full name, title
|
||||
- Email and email status
|
||||
- Phone numbers (if revealed)
|
||||
- Location (city, state, country)
|
||||
- LinkedIn/Twitter URLs
|
||||
- Company info (name, industry, size)
|
||||
Or error dict if enrichment fails
|
||||
|
||||
Example:
|
||||
apollo_enrich_person(email="john@acme.com")
|
||||
apollo_enrich_person(name="John Doe", domain="acme.com")
|
||||
"""
|
||||
client = _get_client()
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
|
||||
# Validate that we have enough info to match
|
||||
has_email_or_linkedin = bool(email or linkedin_url)
|
||||
has_name_and_domain = bool((first_name and last_name and domain) or (name and domain))
|
||||
|
||||
if not has_email_or_linkedin and not has_name_and_domain:
|
||||
return {
|
||||
"error": (
|
||||
"Invalid search criteria. Provide either (email), (linkedin_url), "
|
||||
"or (name/first_name+last_name AND domain)."
|
||||
)
|
||||
}
|
||||
try:
|
||||
return client.enrich_person(
|
||||
email=email,
|
||||
linkedin_url=linkedin_url,
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
name=name,
|
||||
domain=domain,
|
||||
reveal_personal_emails=reveal_personal_emails,
|
||||
reveal_phone_number=reveal_phone_number,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
# --- Company Enrichment ---
|
||||
|
||||
@mcp.tool()
|
||||
def apollo_enrich_company(domain: str) -> dict:
|
||||
"""
|
||||
Enrich a company by domain.
|
||||
|
||||
Args:
|
||||
domain: Company domain (e.g., "acme.com")
|
||||
|
||||
Returns:
|
||||
Dict with company firmographics including:
|
||||
- name, domain, website URL
|
||||
- Industry, keywords
|
||||
- Employee count and range
|
||||
- Annual revenue, funding info
|
||||
- Founded year, location
|
||||
- Technologies used
|
||||
Or error dict if enrichment fails
|
||||
|
||||
Example:
|
||||
apollo_enrich_company(domain="openai.com")
|
||||
"""
|
||||
client = _get_client()
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
return client.enrich_company(domain)
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
                return {"error": f"Network error: {e}"}

    # --- People Search ---

    @mcp.tool()
    def apollo_search_people(
        titles: list[str] | None = None,
        seniorities: list[str] | None = None,
        locations: list[str] | None = None,
        company_sizes: list[str] | None = None,
        industries: list[str] | None = None,
        technologies: list[str] | None = None,
        limit: int = 10,
    ) -> dict:
        """
        Search for contacts with filters.

        Args:
            titles: Job titles to search for
                (e.g., ["VP Sales", "Director of Marketing"])
            seniorities: Seniority levels
                (e.g., ["vp", "director", "c_suite", "manager", "senior"])
            locations: Geographic locations
                (e.g., ["San Francisco, CA", "New York, NY"])
            company_sizes: Company employee count ranges
                (e.g., ["1-10", "11-50", "51-200", "201-500", "501-1000", "1001-5000"])
            industries: Industry tags
                (e.g., ["technology", "finance", "healthcare"])
            technologies: Technologies used by company
                (e.g., ["salesforce", "hubspot", "aws"])
            limit: Maximum results (1-100, default 10)

        Returns:
            Dict with:
            - total: Total matching results
            - results: List of matching contacts with email and company info
            Or error dict if search fails

        Example:
            apollo_search_people(
                titles=["VP Sales", "Head of Sales"],
                seniorities=["vp", "director"],
                company_sizes=["51-200", "201-500"],
                limit=25
            )
        """
        client = _get_client()
        if isinstance(client, dict):
            return client
        try:
            return client.search_people(
                titles=titles,
                seniorities=seniorities,
                locations=locations,
                company_sizes=company_sizes,
                industries=industries,
                technologies=technologies,
                limit=limit,
            )
        except httpx.TimeoutException:
            return {"error": "Request timed out"}
        except httpx.RequestError as e:
            return {"error": f"Network error: {e}"}

    # --- Company Search ---

    @mcp.tool()
    def apollo_search_companies(
        industries: list[str] | None = None,
        employee_counts: list[str] | None = None,
        locations: list[str] | None = None,
        technologies: list[str] | None = None,
        limit: int = 10,
    ) -> dict:
        """
        Search for companies with filters.

        Args:
            industries: Industry tags
                (e.g., ["technology", "finance", "healthcare"])
            employee_counts: Employee count ranges
                (e.g., ["1-10", "11-50", "51-200", "201-500", "501-1000"])
            locations: Geographic locations
                (e.g., ["San Francisco, CA", "United States"])
            technologies: Technologies used
                (e.g., ["salesforce", "hubspot", "aws", "kubernetes"])
            limit: Maximum results (1-100, default 10)

        Returns:
            Dict with:
            - total: Total matching results
            - results: List of matching companies with firmographic data
            Or error dict if search fails

        Example:
            apollo_search_companies(
                industries=["technology"],
                employee_counts=["51-200", "201-500"],
                technologies=["kubernetes"],
                limit=20
            )
        """
        client = _get_client()
        if isinstance(client, dict):
            return client
        try:
            return client.search_companies(
                industries=industries,
                employee_counts=employee_counts,
                locations=locations,
                technologies=technologies,
                limit=limit,
            )
        except httpx.TimeoutException:
            return {"error": "Request timed out"}
        except httpx.RequestError as e:
            return {"error": f"Network error: {e}"}
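All four tools follow the same calling convention: they return a plain dict and signal failure with an `"error"` key rather than raising. A minimal caller-side sketch of that convention (the `summarize_search` helper and the sample payload are illustrative, not part of this diff):

```python
def summarize_search(result: dict) -> str:
    """Render a search-tool result, honoring the error-dict convention."""
    if "error" in result:
        return f"Search failed: {result['error']}"
    names = [r.get("name", "?") for r in result.get("results", [])]
    return f"{result.get('total', 0)} matches; first page: {', '.join(names)}"


# Illustrative payload shaped like an apollo_search_companies response:
print(summarize_search({"total": 50, "results": [{"name": "Tech Startup"}]}))
# -> 50 matches; first page: Tech Startup
print(summarize_search({"error": "Request timed out"}))
# -> Search failed: Request timed out
```

Checking for the `"error"` key before touching `results` keeps agent code from crashing on rate limits or network failures.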
@@ -0,0 +1,675 @@
"""
Tests for Apollo.io data enrichment tool.

Covers:
- _ApolloClient methods (enrich_person, enrich_company, search_people, search_companies)
- Error handling (401, 403, 404, 422, 429, 500, timeout)
- Credential retrieval (CredentialStoreAdapter vs env var)
- All 4 MCP tool functions
- "Not found" graceful handling
"""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import httpx
import pytest

from aden_tools.tools.apollo_tool.apollo_tool import (
    APOLLO_API_BASE,
    _ApolloClient,
    register_tools,
)

# --- _ApolloClient tests ---


class TestApolloClient:
    def setup_method(self):
        self.client = _ApolloClient("test-api-key")

    def test_headers(self):
        headers = self.client._headers
        assert headers["Content-Type"] == "application/json"
        assert headers["Accept"] == "application/json"
        # API key is passed in X-Api-Key header
        assert headers["X-Api-Key"] == "test-api-key"

    def test_handle_response_success(self):
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = {"person": {"id": "123"}}
        assert self.client._handle_response(response) == {"person": {"id": "123"}}

    @pytest.mark.parametrize(
        "status_code,expected_substring",
        [
            (401, "Invalid Apollo API key"),
            (403, "Insufficient credits"),
            (404, "not found"),
            (422, "Invalid parameters"),
            (429, "rate limit"),
        ],
    )
    def test_handle_response_errors(self, status_code, expected_substring):
        response = MagicMock()
        response.status_code = status_code
        response.json.return_value = {"error": "Test error"}
        response.text = "Test error"
        result = self.client._handle_response(response)
        assert "error" in result
        assert expected_substring in result["error"]

    def test_handle_response_generic_error(self):
        response = MagicMock()
        response.status_code = 500
        response.json.return_value = {"error": "Internal Server Error"}
        result = self.client._handle_response(response)
        assert "error" in result
        assert "500" in result["error"]

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_person_by_email(self, mock_post):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "person": {
                "id": "p123",
                "first_name": "John",
                "last_name": "Doe",
                "name": "John Doe",
                "title": "VP Sales",
                "email": "john@acme.com",
                "email_status": "verified",
                "phone_numbers": [{"sanitized_number": "+1234567890"}],
                "linkedin_url": "https://linkedin.com/in/johndoe",
                "twitter_url": None,
                "city": "San Francisco",
                "state": "California",
                "country": "United States",
                "organization": {
                    "id": "o456",
                    "name": "Acme Inc",
                    "primary_domain": "acme.com",
                    "industry": "Technology",
                    "estimated_num_employees": 250,
                },
            }
        }
        mock_post.return_value = mock_response

        result = self.client.enrich_person(email="john@acme.com")

        mock_post.assert_called_once_with(
            f"{APOLLO_API_BASE}/people/match",
            headers=self.client._headers,
            params=None,
            json={
                "email": "john@acme.com",
                "reveal_personal_emails": False,
                "reveal_phone_number": False,
            },
            timeout=30.0,
        )
        assert result["match_found"] is True
        assert result["person"]["first_name"] == "John"
        assert result["person"]["title"] == "VP Sales"
        assert result["person"]["organization"]["name"] == "Acme Inc"

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_person_by_linkedin(self, mock_post):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "person": {
                "id": "p456",
                "first_name": "Jane",
                "last_name": "Smith",
                "name": "Jane Smith",
                "title": "CTO",
                "email": "jane@startup.io",
                "linkedin_url": "https://linkedin.com/in/janesmith",
                "organization": {},
            }
        }
        mock_post.return_value = mock_response

        result = self.client.enrich_person(linkedin_url="https://linkedin.com/in/janesmith")

        call_json = mock_post.call_args.kwargs["json"]
        assert call_json["linkedin_url"] == "https://linkedin.com/in/janesmith"
        assert result["match_found"] is True
        assert result["person"]["title"] == "CTO"

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_person_by_name_and_domain(self, mock_post):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"person": {"id": "p123"}}
        mock_post.return_value = mock_response

        self.client.enrich_person(name="John Doe", domain="acme.com")

        call_json = mock_post.call_args.kwargs["json"]
        assert call_json["name"] == "John Doe"
        assert call_json["domain"] == "acme.com"

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_person_with_reveal_flags(self, mock_post):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"person": {"id": "p123"}}
        mock_post.return_value = mock_response

        self.client.enrich_person(
            email="john@acme.com",
            reveal_personal_emails=True,
            reveal_phone_number=True,
        )

        call_json = mock_post.call_args.kwargs["json"]
        assert call_json["reveal_personal_emails"] is True
        assert call_json["reveal_phone_number"] is True

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_person_with_optional_params(self, mock_post):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"person": {"id": "p789"}}
        mock_post.return_value = mock_response

        self.client.enrich_person(
            email="john@acme.com",
            first_name="John",
            last_name="Doe",
            domain="acme.com",
        )

        call_json = mock_post.call_args.kwargs["json"]
        assert call_json["email"] == "john@acme.com"
        assert call_json["first_name"] == "John"
        assert call_json["last_name"] == "Doe"
        assert call_json["domain"] == "acme.com"

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_person_not_found(self, mock_post):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"person": None}
        mock_post.return_value = mock_response

        result = self.client.enrich_person(email="nobody@nowhere.xyz")

        assert result["match_found"] is False
        assert "No matching person found" in result["message"]

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_company(self, mock_post):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "organization": {
                "id": "o123",
                "name": "OpenAI",
                "primary_domain": "openai.com",
                "website_url": "https://openai.com",
                "linkedin_url": "https://linkedin.com/company/openai",
                "industry": "Artificial Intelligence",
                "keywords": ["ai", "machine learning", "gpt"],
                "estimated_num_employees": 1500,
                "employee_count_range": "1001-5000",
                "annual_revenue": 1000000000,
                "annual_revenue_printed": "$1B",
                "total_funding": 11000000000,
                "total_funding_printed": "$11B",
                "latest_funding_round_date": "2023-01-23",
                "latest_funding_stage": "Series D",
                "founded_year": 2015,
                "phone": "+1-415-123-4567",
                "city": "San Francisco",
                "state": "California",
                "country": "United States",
                "street_address": "123 Mission St",
                "technologies": ["python", "kubernetes", "aws"],
                "short_description": "AI research and deployment company",
            }
        }
        mock_post.return_value = mock_response

        result = self.client.enrich_company("openai.com")

        mock_post.assert_called_once_with(
            f"{APOLLO_API_BASE}/organizations/enrich",
            headers=self.client._headers,
            json={"domain": "openai.com"},
            timeout=30.0,
        )
        assert result["match_found"] is True
        assert result["organization"]["name"] == "OpenAI"
        assert result["organization"]["industry"] == "Artificial Intelligence"
        assert result["organization"]["employee_count"] == 1500
        assert "python" in result["organization"]["technologies"]

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_company_not_found(self, mock_post):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"organization": None}
        mock_post.return_value = mock_response

        result = self.client.enrich_company("notarealcompany12345.xyz")

        assert result["match_found"] is False
        assert "No matching company found" in result["message"]

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_search_people(self, mock_post):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "pagination": {"total_entries": 150, "page": 1, "per_page": 10},
            "people": [
                {
                    "id": "p1",
                    "first_name": "Alice",
                    "last_name": "Johnson",
                    "name": "Alice Johnson",
                    "title": "VP Sales",
                    "email": "alice@company.com",
                    "email_status": "verified",
                    "linkedin_url": "https://linkedin.com/in/alicejohnson",
                    "city": "New York",
                    "state": "New York",
                    "country": "United States",
                    "seniority": "vp",
                    "organization": {
                        "id": "o1",
                        "name": "Company Inc",
                        "primary_domain": "company.com",
                    },
                },
                {
                    "id": "p2",
                    "first_name": "Bob",
                    "last_name": "Smith",
                    "name": "Bob Smith",
                    "title": "Director of Sales",
                    "email": "bob@another.com",
                    "email_status": "verified",
                    "linkedin_url": "https://linkedin.com/in/bobsmith",
                    "city": "Chicago",
                    "state": "Illinois",
                    "country": "United States",
                    "seniority": "director",
                    "organization": None,
                },
            ],
        }
        mock_post.return_value = mock_response

        result = self.client.search_people(
            titles=["VP Sales", "Director of Sales"],
            seniorities=["vp", "director"],
            company_sizes=["51-200", "201-500"],
            limit=10,
        )

        mock_post.assert_called_once()
        call_json = mock_post.call_args.kwargs["json"]
        assert call_json["person_titles"] == ["VP Sales", "Director of Sales"]
        assert call_json["person_seniorities"] == ["vp", "director"]
        assert call_json["organization_num_employees_ranges"] == ["51-200", "201-500"]
        assert call_json["per_page"] == 10

        assert result["total"] == 150
        assert len(result["results"]) == 2
        assert result["results"][0]["title"] == "VP Sales"
        assert result["results"][0]["organization"]["name"] == "Company Inc"
        # Bob has no organization
        assert result["results"][1]["organization"]["name"] is None

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_search_people_limit_capped(self, mock_post):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"pagination": {}, "people": []}
        mock_post.return_value = mock_response

        self.client.search_people(limit=200)

        call_json = mock_post.call_args.kwargs["json"]
        assert call_json["per_page"] == 100

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_search_companies(self, mock_post):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "pagination": {"total_entries": 50, "page": 1, "per_page": 10},
            "organizations": [
                {
                    "id": "o1",
                    "name": "Tech Startup",
                    "primary_domain": "techstartup.io",
                    "website_url": "https://techstartup.io",
                    "linkedin_url": "https://linkedin.com/company/techstartup",
                    "industry": "Technology",
                    "estimated_num_employees": 75,
                    "employee_count_range": "51-200",
                    "annual_revenue_printed": "$10M",
                    "city": "Austin",
                    "state": "Texas",
                    "country": "United States",
                    "short_description": "A tech startup",
                },
            ],
        }
        mock_post.return_value = mock_response

        result = self.client.search_companies(
            industries=["technology"],
            employee_counts=["51-200"],
            technologies=["kubernetes"],
            limit=10,
        )

        mock_post.assert_called_once()
        call_json = mock_post.call_args.kwargs["json"]
        assert call_json["organization_industry_tag_ids"] == ["technology"]
        assert call_json["organization_num_employees_ranges"] == ["51-200"]
        assert call_json["currently_using_any_of_technology_uids"] == ["kubernetes"]

        assert result["total"] == 50
        assert len(result["results"]) == 1
        assert result["results"][0]["name"] == "Tech Startup"
        assert result["results"][0]["industry"] == "Technology"


# --- MCP tool registration and credential tests ---


class TestToolRegistration:
    def test_register_tools_registers_all_tools(self):
        mcp = MagicMock()
        mcp.tool.return_value = lambda fn: fn
        register_tools(mcp)
        assert mcp.tool.call_count == 4

    def test_no_credentials_returns_error(self):
        mcp = MagicMock()
        registered_fns = []
        mcp.tool.return_value = lambda fn: registered_fns.append(fn) or fn

        with patch.dict("os.environ", {}, clear=True):
            register_tools(mcp, credentials=None)

        enrich_fn = next(fn for fn in registered_fns if fn.__name__ == "apollo_enrich_person")
        result = enrich_fn(email="test@test.com")
        assert "error" in result
        assert "not configured" in result["error"]

    def test_credentials_from_credential_manager(self):
        mcp = MagicMock()
        registered_fns = []
        mcp.tool.return_value = lambda fn: registered_fns.append(fn) or fn

        cred_manager = MagicMock()
        cred_manager.get.return_value = "test-api-key"

        register_tools(mcp, credentials=cred_manager)

        enrich_fn = next(fn for fn in registered_fns if fn.__name__ == "apollo_enrich_company")

        with patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post") as mock_post:
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_response.json.return_value = {"organization": {"id": "123", "name": "Test"}}
            mock_post.return_value = mock_response

            result = enrich_fn(domain="test.com")

        cred_manager.get.assert_called_with("apollo")
        assert result["match_found"] is True

    def test_credentials_from_env_var(self):
        mcp = MagicMock()
        registered_fns = []
        mcp.tool.return_value = lambda fn: registered_fns.append(fn) or fn

        register_tools(mcp, credentials=None)

        enrich_fn = next(fn for fn in registered_fns if fn.__name__ == "apollo_enrich_company")

        with (
            patch.dict("os.environ", {"APOLLO_API_KEY": "env-api-key"}),
            patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post") as mock_post,
        ):
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_response.json.return_value = {"organization": {"id": "123", "name": "Test"}}
            mock_post.return_value = mock_response

            result = enrich_fn(domain="test.com")

        assert result["match_found"] is True
        # Verify API key was used in X-Api-Key header
        call_headers = mock_post.call_args.kwargs["headers"]
        assert call_headers["X-Api-Key"] == "env-api-key"


# --- Individual tool function tests ---


class TestEnrichPersonTool:
    def setup_method(self):
        self.mcp = MagicMock()
        self.fns = []
        self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
        cred = MagicMock()
        cred.get.return_value = "test-key"
        register_tools(self.mcp, credentials=cred)

    def _fn(self, name):
        return next(f for f in self.fns if f.__name__ == name)

    def test_enrich_person_requires_email_or_linkedin(self):
        result = self._fn("apollo_enrich_person")()
        assert "error" in result
        assert "Invalid search criteria" in result["error"]

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_person_success(self, mock_post):
        mock_post.return_value = MagicMock(
            status_code=200,
            json=MagicMock(
                return_value={
                    "person": {
                        "id": "p1",
                        "first_name": "John",
                        "last_name": "Doe",
                        "title": "CEO",
                        "organization": {},
                    }
                }
            ),
        )
        result = self._fn("apollo_enrich_person")(email="john@acme.com")
        assert result["match_found"] is True
        assert result["person"]["title"] == "CEO"

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_person_timeout(self, mock_post):
        mock_post.side_effect = httpx.TimeoutException("timed out")
        result = self._fn("apollo_enrich_person")(email="test@test.com")
        assert "error" in result
        assert "timed out" in result["error"]

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_person_network_error(self, mock_post):
        mock_post.side_effect = httpx.RequestError("connection failed")
        result = self._fn("apollo_enrich_person")(email="test@test.com")
        assert "error" in result
        assert "Network error" in result["error"]


class TestEnrichCompanyTool:
    def setup_method(self):
        self.mcp = MagicMock()
        self.fns = []
        self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
        cred = MagicMock()
        cred.get.return_value = "test-key"
        register_tools(self.mcp, credentials=cred)

    def _fn(self, name):
        return next(f for f in self.fns if f.__name__ == name)

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_company_success(self, mock_post):
        mock_post.return_value = MagicMock(
            status_code=200,
            json=MagicMock(
                return_value={
                    "organization": {
                        "id": "o1",
                        "name": "Acme Inc",
                        "industry": "Technology",
                        "estimated_num_employees": 500,
                    }
                }
            ),
        )
        result = self._fn("apollo_enrich_company")(domain="acme.com")
        assert result["match_found"] is True
        assert result["organization"]["name"] == "Acme Inc"

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_enrich_company_not_found(self, mock_post):
        mock_post.return_value = MagicMock(
            status_code=200, json=MagicMock(return_value={"organization": None})
        )
        result = self._fn("apollo_enrich_company")(domain="notreal.xyz")
        assert result["match_found"] is False


class TestSearchPeopleTool:
    def setup_method(self):
        self.mcp = MagicMock()
        self.fns = []
        self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
        cred = MagicMock()
        cred.get.return_value = "test-key"
        register_tools(self.mcp, credentials=cred)

    def _fn(self, name):
        return next(f for f in self.fns if f.__name__ == name)

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_search_people_success(self, mock_post):
        mock_post.return_value = MagicMock(
            status_code=200,
            json=MagicMock(
                return_value={
                    "pagination": {"total_entries": 100},
                    "people": [{"id": "p1", "name": "Alice", "title": "VP Sales"}],
                }
            ),
        )
        result = self._fn("apollo_search_people")(titles=["VP Sales"])
        assert result["total"] == 100
        assert len(result["results"]) == 1

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_search_people_with_all_filters(self, mock_post):
        mock_post.return_value = MagicMock(
            status_code=200, json=MagicMock(return_value={"pagination": {}, "people": []})
        )
        self._fn("apollo_search_people")(
            titles=["CEO"],
            seniorities=["c_suite"],
            locations=["San Francisco"],
            company_sizes=["51-200"],
            industries=["technology"],
            technologies=["salesforce"],
            limit=25,
        )
        call_json = mock_post.call_args.kwargs["json"]
        assert call_json["person_titles"] == ["CEO"]
        assert call_json["person_seniorities"] == ["c_suite"]
        assert call_json["person_locations"] == ["San Francisco"]
        assert call_json["organization_num_employees_ranges"] == ["51-200"]


class TestSearchCompaniesTool:
    def setup_method(self):
        self.mcp = MagicMock()
        self.fns = []
        self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
        cred = MagicMock()
        cred.get.return_value = "test-key"
        register_tools(self.mcp, credentials=cred)

    def _fn(self, name):
        return next(f for f in self.fns if f.__name__ == name)

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_search_companies_success(self, mock_post):
        mock_post.return_value = MagicMock(
            status_code=200,
            json=MagicMock(
                return_value={
                    "pagination": {"total_entries": 50},
                    "organizations": [{"id": "o1", "name": "Tech Corp", "industry": "Technology"}],
                }
            ),
        )
        result = self._fn("apollo_search_companies")(industries=["technology"])
        assert result["total"] == 50
        assert len(result["results"]) == 1
        assert result["results"][0]["industry"] == "Technology"

    @patch("aden_tools.tools.apollo_tool.apollo_tool.httpx.post")
    def test_search_companies_with_all_filters(self, mock_post):
        mock_post.return_value = MagicMock(
            status_code=200, json=MagicMock(return_value={"pagination": {}, "organizations": []})
        )
        self._fn("apollo_search_companies")(
            industries=["finance"],
            employee_counts=["201-500"],
            locations=["New York"],
            technologies=["aws"],
            limit=15,
        )
        call_json = mock_post.call_args.kwargs["json"]
        assert call_json["organization_industry_tag_ids"] == ["finance"]
        assert call_json["organization_num_employees_ranges"] == ["201-500"]
        assert call_json["organization_locations"] == ["New York"]
        assert call_json["currently_using_any_of_technology_uids"] == ["aws"]
        assert call_json["per_page"] == 15


# --- Credential spec tests ---


class TestCredentialSpec:
    def test_apollo_credential_spec_exists(self):
        from aden_tools.credentials import CREDENTIAL_SPECS

        assert "apollo" in CREDENTIAL_SPECS

    def test_apollo_spec_env_var(self):
        from aden_tools.credentials import CREDENTIAL_SPECS

        spec = CREDENTIAL_SPECS["apollo"]
        assert spec.env_var == "APOLLO_API_KEY"

    def test_apollo_spec_tools(self):
        from aden_tools.credentials import CREDENTIAL_SPECS

        spec = CREDENTIAL_SPECS["apollo"]
        assert "apollo_enrich_person" in spec.tools
        assert "apollo_enrich_company" in spec.tools
        assert "apollo_search_people" in spec.tools
        assert "apollo_search_companies" in spec.tools
        assert len(spec.tools) == 4
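The parametrized cases in `TestApolloClient.test_handle_response_errors` pin down a status-code-to-message mapping. A minimal sketch consistent with the asserted substrings — the `handle_status` name, the exact wording, and the dict layout are illustrative, not the actual `_handle_response` implementation:

```python
# Hypothetical mapping chosen to contain the substrings the tests assert
# ("Invalid Apollo API key", "not found", "rate limit", and the raw code
# for unmapped statuses such as 500).
_ERROR_MESSAGES = {
    401: "Invalid Apollo API key",
    403: "Insufficient credits or permissions",
    404: "Resource not found",
    422: "Invalid parameters",
    429: "Apollo rate limit exceeded",
}


def handle_status(status_code: int, body: str) -> dict:
    """Translate an HTTP status into the error-dict convention; {} means OK."""
    if status_code == 200:
        return {}
    message = _ERROR_MESSAGES.get(status_code, f"Apollo API error {status_code}: {body}")
    return {"error": message}
```

Keeping the substrings stable lets the parametrized test assert on meaning ("rate limit") without coupling to exact phrasing.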