Compare commits
211 Commits
@@ -0,0 +1,9 @@
{
  "mcpServers": {
    "agent-builder": {
      "command": "uv",
      "args": ["run", "--directory", "core", "-m", "framework.mcp.agent_builder_server"],
      "disabled": false
    }
  }
}
Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive

Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive-concepts

Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive-create

Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive-credentials

Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive-patterns

Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive-test
@@ -0,0 +1,5 @@
---
description: hive-concepts
---

use hive-concepts skill

@@ -0,0 +1,5 @@
---
description: hive-create
---

use hive-create skill

@@ -0,0 +1,5 @@
---
description: hive-credentials
---

use hive-credentials skill

@@ -0,0 +1,5 @@
---
description: hive-patterns
---

use hive-patterns skill

@@ -0,0 +1,5 @@
---
description: hive-test
---

use hive-test skill

@@ -0,0 +1,5 @@
---
description: hive
---

use hive skill
Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive

Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive-concepts

Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive-create

Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive-credentials

Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive-patterns

Symlink
+1
@@ -0,0 +1 @@
../../.claude/skills/hive-test
@@ -0,0 +1,34 @@
{
  "permissions": {
    "allow": [
      "mcp__agent-builder__create_session",
      "mcp__agent-builder__set_goal",
      "mcp__agent-builder__add_node",
      "mcp__agent-builder__add_edge",
      "mcp__agent-builder__configure_loop",
      "mcp__agent-builder__add_mcp_server",
      "mcp__agent-builder__validate_graph",
      "mcp__agent-builder__export_graph",
      "mcp__agent-builder__load_session_by_id",
      "Bash(git status:*)",
      "Bash(gh run view:*)",
      "Bash(uv run:*)",
      "Bash(env:*)",
      "mcp__agent-builder__test_node",
      "mcp__agent-builder__list_mcp_tools",
      "Bash(python -m py_compile:*)",
      "Bash(python -m pytest:*)",
      "Bash(source:*)",
      "mcp__agent-builder__update_node",
      "mcp__agent-builder__check_missing_credentials",
      "mcp__agent-builder__list_stored_credentials",
      "Bash(find:*)",
      "mcp__agent-builder__run_tests",
      "Bash(PYTHONPATH=core:exports:tools/src uv run pytest:*)",
      "mcp__agent-builder__list_agent_sessions",
      "mcp__agent-builder__generate_constraint_tests",
      "mcp__agent-builder__generate_success_tests"
    ]
  },
  "enabledMcpjsonServers": ["agent-builder", "tools"]
}
@@ -553,6 +553,26 @@ AskUserQuestion(questions=[{
- condition_expr (Python expression, only if conditional)
- priority (positive = forward, negative = feedback/loop-back)

**DETERMINE the graph lifecycle.** Not every agent needs a terminal node:

| Pattern | `terminal_nodes` | When to Use |
|---------|-------------------|-------------|
| **Linear (finish)** | `["last-node"]` | Agent completes a task and exits (batch processing, one-shot generation) |
| **Forever-alive (loop)** | `[]` (empty) | Agent stays alive for continuous interaction (research assistant, personal assistant, monitoring) |

**Forever-alive pattern:** The deep_research_agent example uses `terminal_nodes=[]`. Every leaf node has edges that loop back to earlier nodes, creating a perpetual session. The agent only stops when the user explicitly exits. This is the preferred pattern for interactive, multi-turn agents.

**Key design rules for forever-alive graphs** (see the sketch after this list):
- Every node must have at least one outgoing edge (no dead ends)
- Client-facing nodes block for user input — these are the natural "pause points"
- The user controls when to stop, not the graph
- Sessions accumulate memory across loops — plan for conversation compaction
- Use `conversation_mode="continuous"` to preserve conversation history across node transitions
- `max_iterations` should be set high (e.g., 100) since the agent is designed to run indefinitely
- The agent will NOT enter a "completed" execution state — this is intentional, not a bug
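
A condensed sketch of this wiring. It reuses the loop-back edges from the deep_research_agent diff later on this page; the forward edges and node definitions are elided:

```python
from framework.graph import EdgeSpec, EdgeCondition

edges = [
    # ... forward edges: intake -> research -> review -> report ...
    # Loop-back edges from the leaf node. The two conditions are complementary,
    # so one of them always fires and the graph never dead-ends.
    EdgeSpec(
        id="report-to-research",   # user wants deeper research
        source="report",
        target="research",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="str(next_action).lower() == 'more_research'",
        priority=2,
    ),
    EdgeSpec(
        id="report-to-intake",     # new topic (the default route)
        source="report",
        target="intake",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="str(next_action).lower() != 'more_research'",
        priority=1,
    ),
]

terminal_nodes = []  # empty list = forever-alive; only the user ends the session
```
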
**Ask the user** which lifecycle pattern fits their agent. Default to forever-alive for interactive agents, linear for batch/one-shot tasks.

**RENDER the complete graph as ASCII art.** Make it large and clear — the user needs to see and understand the full workflow at a glance.

**IMPORTANT: Make the ASCII art BIG and READABLE.** Use a box-and-arrow style with generous spacing. Do NOT make it tiny or compressed. Example format:
@@ -669,6 +689,7 @@ AskUserQuestion(questions=[{
|------|---------------|
| `config.py` | `AgentMetadata.name` — the display name shown in TUI agent selection |
| `config.py` | `AgentMetadata.description` — agent description |
| `config.py` | `AgentMetadata.intro_message` — greeting shown to user when TUI loads |
| `agent.py` | Module docstring (line 1) |
| `agent.py` | `class OldNameAgent:` → `class NewNameAgent:` |
| `agent.py` | `GraphSpec(id="old-name-graph")` → `GraphSpec(id="new-name-graph")` — shown in TUI status bar |

@@ -735,7 +756,7 @@ mcp__agent-builder__export_graph()

**THEN write the Python package files** using the exported data. Create these files in `exports/AGENT_NAME/`:

1. `config.py` - Runtime configuration with model settings
1. `config.py` - Runtime configuration with model settings and `AgentMetadata` (including `intro_message` — the greeting shown when TUI loads; sketched after this list)
2. `nodes/__init__.py` - All NodeSpec definitions
3. `agent.py` - Goal, edges, graph config, and agent class
4. `__init__.py` - Package exports
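
A minimal sketch of the `config.py` shape implied by item 1. The field names come from the `AgentMetadata` diff further down this page; the `@dataclass` decorator and the example values are assumptions:

```python
from dataclasses import dataclass


@dataclass
class AgentMetadata:
    # Display name and description shown in TUI agent selection.
    name: str = "Deep Research Agent"
    description: str = "Researches a topic across sources and writes a cited report."
    # Greeting shown to the user when the TUI loads.
    intro_message: str = "Hi! Tell me a topic and I'll investigate it thoroughly."


metadata = AgentMetadata()
```
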
@@ -911,6 +932,46 @@ result = await executor.execute(graph=graph, goal=goal, input_data=input_data)

---

## REFERENCE: Graph Lifecycle & Conversation Memory

### Terminal vs Forever-Alive Graphs

Agents have two lifecycle patterns:

**Linear (terminal) graphs** have `terminal_nodes=["last-node"]`. Execution ends when the terminal node completes. The session enters a "completed" state. Use for batch processing, one-shot generation, and fire-and-forget tasks.

**Forever-alive graphs** have `terminal_nodes=[]` (empty). Every node has at least one outgoing edge — the graph loops indefinitely. The session **never enters a "completed" state** — this is intentional. The agent stays alive until the user explicitly exits. Use for interactive assistants, research tools, and any agent where the user drives the conversation.

The deep_research_agent example demonstrates this: `report` loops back to either `research` (dig deeper) or `intake` (new topic). The agent is a persistent, interactive assistant.

### Continuous Conversation Mode

When `conversation_mode="continuous"` is set on the GraphSpec, the framework preserves a **single conversation thread** across all node transitions:

**What the framework does automatically:**
- **Inherits conversation**: Same message history carries forward to the next node
- **Composes layered system prompts**: Identity (agent-level) + Narrative (auto-generated state summary) + Focus (per-node instructions)
- **Inserts transition markers**: At each node boundary, a "State of the World" message showing completed phases, current memory, and available data files
- **Accumulates tools**: Once a tool becomes available, it stays available in subsequent nodes
- **Compacts opportunistically**: At phase transitions, old tool results are pruned to stay within token budget

**What this means for agent builders:**
- Nodes don't need to re-explain context — the conversation carries it forward
- Output keys from earlier nodes are available in memory for edge conditions and later nodes
- For forever-alive agents, conversation memory persists across the entire session lifetime
- Plan for compaction: very long sessions will have older tool results summarized automatically

**When to use continuous mode** (enabling it is sketched after the next list):
- Interactive agents with client-facing nodes (always)
- Multi-phase workflows where context matters across phases
- Forever-alive agents that loop indefinitely

**When NOT to use continuous mode:**
- Embarrassingly parallel fan-out nodes (each branch should be independent)
- Stateless utility agents that process items independently

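Enabling continuous mode is a single field on the graph spec. A minimal sketch; the field names mirror the deep_research_agent `_build_graph` diff below, and the remaining constructor arguments are elided:

```python
graph = GraphSpec(
    id="deep-research-graph",
    # ... nodes, edges, entry/terminal configuration ...
    conversation_mode="continuous",  # one conversation thread across node transitions
    identity_prompt=(
        "You are a rigorous research agent."  # agent-level identity layer, composed
    ),                                        # automatically with per-node focus prompts
)
```
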
---

## REFERENCE: Framework Capabilities for Qualification

Use this reference during STEP 2 to give accurate, honest assessments.

@@ -943,7 +1004,7 @@ Use this reference during STEP 2 to give accurate, honest assessments.

| Use Case | Why It's Problematic | Alternative |
|----------|---------------------|-------------|
| Long-running daemons | Framework is request-response, not persistent | External scheduler + agent |
| Persistent background daemons (no user) | Forever-alive graphs need a user at client-facing nodes; no autonomous background polling without user | External scheduler triggering agent runs |
| Sub-second responses | LLM latency is inherent | Traditional code, no LLM |
| Processing millions of items | Context windows and rate limits | Batch processing + sampling |
| Real-time streaming data | No built-in pub/sub or streaming input | Custom MCP server + agent |

@@ -978,3 +1039,6 @@ Use this reference during STEP 2 to give accurate, honest assessments.

11. **Adding framework gating for LLM behavior** - Fix prompts or use judges, not ad-hoc code
12. **Writing code before user approves the graph** - Always get approval on goal, nodes, and graph BEFORE writing any agent code
13. **Wrong mcp_servers.json format** - Use flat format (no `"mcpServers"` wrapper), `cwd` must be `"../../tools"`, and `command` must be `"uv"` with args `["run", "python", ...]` (see the sketch after this list)
14. **Assuming all agents need terminal nodes** - Interactive agents often work best with `terminal_nodes=[]` (forever-alive pattern). The agent never enters "completed" state — this is intentional. Only batch/one-shot agents need terminal nodes
15. **Creating dead-end nodes in forever-alive graphs** - Every node must have at least one outgoing edge. A node with no outgoing edges will cause execution to end unexpectedly, breaking the forever-alive loop
16. **Not using continuous conversation mode for interactive agents** - Multi-phase interactive agents should use `conversation_mode="continuous"` to preserve context across node transitions. Without it, each node starts with a blank conversation and loses all prior context

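A sketch of the flat per-agent `mcp_servers.json` described in pitfall 13. Only the flat shape (no `"mcpServers"` wrapper), the `"uv"` command, the `["run", "python", ...]` args prefix, and the `"../../tools"` cwd come from the text above; the server name and module path are hypothetical placeholders:

```json
{
  "tools": {
    "command": "uv",
    "args": ["run", "python", "-m", "your_tool_server_module"],
    "cwd": "../../tools"
  }
}
```
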
@@ -1,12 +1,15 @@
"""Agent graph construction for Deep Research Agent."""

from pathlib import Path

from framework.graph import EdgeSpec, EdgeCondition, Goal, SuccessCriterion, Constraint
from framework.graph.edge import GraphSpec
from framework.graph.executor import ExecutionResult, GraphExecutor
from framework.runtime.event_bus import EventBus
from framework.runtime.core import Runtime
from framework.graph.executor import ExecutionResult
from framework.graph.checkpoint_config import CheckpointConfig
from framework.llm import LiteLLMProvider
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
from framework.runtime.execution_stream import EntryPointSpec

from .config import default_config, metadata
from .nodes import (
@@ -120,13 +123,31 @@ edges = [
        condition_expr="needs_more_research == False",
        priority=2,
    ),
    # report -> research (user wants deeper research on current topic)
    EdgeSpec(
        id="report-to-research",
        source="report",
        target="research",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="str(next_action).lower() == 'more_research'",
        priority=2,
    ),
    # report -> intake (user wants a new topic — default when not more_research)
    EdgeSpec(
        id="report-to-intake",
        source="report",
        target="intake",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="str(next_action).lower() != 'more_research'",
        priority=1,
    ),
]

# Graph configuration
entry_node = "intake"
entry_points = {"start": "intake"}
pause_nodes = []
terminal_nodes = ["report"]
terminal_nodes = []


class DeepResearchAgent:
@@ -136,6 +157,12 @@ class DeepResearchAgent:

    Flow: intake -> research -> review -> report
                       ^           |
                       +-- feedback loop (if user wants more)

    Uses AgentRuntime for proper session management:
    - Session-scoped storage (sessions/{session_id}/)
    - Checkpointing for resume capability
    - Runtime logging
    - Data folder for save_data/load_data
    """

    def __init__(self, config=None):

@@ -147,10 +174,10 @@ class DeepResearchAgent:
        self.entry_points = entry_points
        self.pause_nodes = pause_nodes
        self.terminal_nodes = terminal_nodes
        self._executor: GraphExecutor | None = None
        self._graph: GraphSpec | None = None
        self._event_bus: EventBus | None = None
        self._agent_runtime: AgentRuntime | None = None
        self._tool_registry: ToolRegistry | None = None
        self._storage_path: Path | None = None

    def _build_graph(self) -> GraphSpec:
        """Build the GraphSpec."""

@@ -171,16 +198,20 @@ class DeepResearchAgent:
                "max_tool_calls_per_turn": 20,
                "max_history_tokens": 32000,
            },
            conversation_mode="continuous",
            identity_prompt=(
                "You are a rigorous research agent. You search for information "
                "from diverse, authoritative sources, analyze findings critically, "
                "and produce well-cited reports. You never fabricate information — "
                "every claim must trace back to a source you actually retrieved."
            ),
        )

    def _setup(self, mock_mode=False) -> GraphExecutor:
        """Set up the executor with all components."""
        from pathlib import Path
    def _setup(self, mock_mode=False) -> None:
        """Set up the agent runtime with sessions, checkpoints, and logging."""
        self._storage_path = Path.home() / ".hive" / "agents" / "deep_research_agent"
        self._storage_path.mkdir(parents=True, exist_ok=True)

        storage_path = Path.home() / ".hive" / "agents" / "deep_research_agent"
        storage_path.mkdir(parents=True, exist_ok=True)

        self._event_bus = EventBus()
        self._tool_registry = ToolRegistry()

        mcp_config_path = Path(__file__).parent / "mcp_servers.json"
@@ -199,47 +230,63 @@ class DeepResearchAgent:
        tools = list(self._tool_registry.get_tools().values())

        self._graph = self._build_graph()
        runtime = Runtime(storage_path)

        self._executor = GraphExecutor(
            runtime=runtime,
        checkpoint_config = CheckpointConfig(
            enabled=True,
            checkpoint_on_node_start=False,
            checkpoint_on_node_complete=True,
            checkpoint_max_age_days=7,
            async_checkpoint=True,
        )

        entry_point_specs = [
            EntryPointSpec(
                id="default",
                name="Default",
                entry_node=self.entry_node,
                trigger_type="manual",
                isolation_level="shared",
            )
        ]

        self._agent_runtime = create_agent_runtime(
            graph=self._graph,
            goal=self.goal,
            storage_path=self._storage_path,
            entry_points=entry_point_specs,
            llm=llm,
            tools=tools,
            tool_executor=tool_executor,
            event_bus=self._event_bus,
            storage_path=storage_path,
            loop_config=self._graph.loop_config,
            checkpoint_config=checkpoint_config,
        )

        return self._executor

    async def start(self, mock_mode=False) -> None:
        """Set up the agent (initialize executor and tools)."""
        if self._executor is None:
        """Set up and start the agent runtime."""
        if self._agent_runtime is None:
            self._setup(mock_mode=mock_mode)
        if not self._agent_runtime.is_running:
            await self._agent_runtime.start()

    async def stop(self) -> None:
        """Clean up resources."""
        self._executor = None
        self._event_bus = None
        """Stop the agent runtime and clean up."""
        if self._agent_runtime and self._agent_runtime.is_running:
            await self._agent_runtime.stop()
        self._agent_runtime = None

    async def trigger_and_wait(
        self,
        entry_point: str,
        input_data: dict,
        entry_point: str = "default",
        input_data: dict | None = None,
        timeout: float | None = None,
        session_state: dict | None = None,
    ) -> ExecutionResult | None:
        """Execute the graph and wait for completion."""
        if self._executor is None:
        if self._agent_runtime is None:
            raise RuntimeError("Agent not started. Call start() first.")
        if self._graph is None:
            raise RuntimeError("Graph not built. Call start() first.")

        return await self._executor.execute(
            graph=self._graph,
            goal=self.goal,
            input_data=input_data,
        return await self._agent_runtime.trigger_and_wait(
            entry_point_id=entry_point,
            input_data=input_data or {},
            session_state=session_state,
        )

@@ -250,7 +297,7 @@ class DeepResearchAgent:
        await self.start(mock_mode=mock_mode)
        try:
            result = await self.trigger_and_wait(
                "start", context, session_state=session_state
                "default", context, session_state=session_state
            )
            return result or ExecutionResult(success=False, error="Execution timeout")
        finally:
@@ -16,6 +16,11 @@ class AgentMetadata:
        "multi-source search, quality evaluation, and synthesis - with TUI conversation "
        "at key checkpoints for user guidance and feedback."
    )
    intro_message: str = (
        "Hi! I'm your deep research assistant. Tell me a topic and I'll investigate it "
        "thoroughly — searching multiple sources, evaluating quality, and synthesizing "
        "a comprehensive report. What would you like me to research?"
    )


metadata = AgentMetadata()

@@ -10,8 +10,13 @@ intake_node = NodeSpec(
    description="Discuss the research topic with the user, clarify scope, and confirm direction",
    node_type="event_loop",
    client_facing=True,
    max_node_visits=0,
    input_keys=["topic"],
    output_keys=["research_brief"],
    success_criteria=(
        "The research brief is specific and actionable: it states the topic, "
        "the key questions to answer, the desired scope, and depth."
    ),
    system_prompt="""\
You are a research intake specialist. The user wants to research a topic.
Have a brief conversation to clarify what they need.
@@ -38,10 +43,14 @@ research_node = NodeSpec(
    name="Research",
    description="Search the web, fetch source content, and compile findings",
    node_type="event_loop",
    max_node_visits=3,
    max_node_visits=0,
    input_keys=["research_brief", "feedback"],
    output_keys=["findings", "sources", "gaps"],
    nullable_output_keys=["feedback"],
    success_criteria=(
        "Findings reference at least 3 distinct sources with URLs. "
        "Key claims are substantiated by fetched content, not generated."
    ),
    system_prompt="""\
You are a research agent. Given a research brief, find and analyze sources.

@@ -56,18 +65,26 @@ Work in phases:
   and any contradictions between sources.

Important:
- Work in batches of 3-4 tool calls at a time to manage context
- Work in batches of 3-4 tool calls at a time — never more than 10 per turn
- After each batch, assess whether you have enough material
- Prefer quality over quantity — 5 good sources beat 15 thin ones
- Track which URL each finding comes from (you'll need citations later)
- Call set_output for each key in a SEPARATE turn (not in the same turn as other tool calls)

When done, use set_output:
When done, use set_output (one key at a time, separate turns):
- set_output("findings", "Structured summary: key findings with source URLs for each claim. \
  Include themes, contradictions, and confidence levels.")
- set_output("sources", [{"url": "...", "title": "...", "summary": "..."}])
- set_output("gaps", "What aspects of the research brief are NOT well-covered yet, if any.")
""",
    tools=["web_search", "web_scrape", "load_data", "save_data", "list_data_files"],
    tools=[
        "web_search",
        "web_scrape",
        "load_data",
        "save_data",
        "append_data",
        "list_data_files",
    ],
)
# Node 3: Review (client-facing)

@@ -78,9 +95,13 @@ review_node = NodeSpec(
    description="Present findings to user and decide whether to research more or write the report",
    node_type="event_loop",
    client_facing=True,
    max_node_visits=3,
    max_node_visits=0,
    input_keys=["findings", "sources", "gaps", "research_brief"],
    output_keys=["needs_more_research", "feedback"],
    success_criteria=(
        "The user has been presented with findings and has explicitly indicated "
        "whether they want more research or are ready for the report."
    ),
    system_prompt="""\
Present the research findings to the user clearly and concisely.

@@ -109,49 +130,70 @@ report_node = NodeSpec(
    description="Write a cited HTML report from the findings and present it to the user",
    node_type="event_loop",
    client_facing=True,
    max_node_visits=0,
    input_keys=["findings", "sources", "research_brief"],
    output_keys=["delivery_status"],
    output_keys=["delivery_status", "next_action"],
    success_criteria=(
        "An HTML report has been saved, the file link has been presented to the user, "
        "and the user has indicated what they want to do next."
    ),
    system_prompt="""\
Write a comprehensive research report as an HTML file and present it to the user.
Write a research report as an HTML file and present it to the user.

**STEP 1 — Write the HTML report (tool calls, NO text to user yet):**
IMPORTANT: save_data requires TWO separate arguments: filename and data.
Call it like: save_data(filename="report.html", data="<html>...</html>")
Do NOT use _raw, do NOT nest arguments inside a JSON string.

1. Compose a complete, self-contained HTML document with embedded CSS styling.
   Use a clean, readable design: max-width container, pleasant typography,
   numbered citation links, a table of contents, and a references section.
**STEP 1 — Write and save the HTML report (tool calls, NO text to user yet):**

Report structure inside the HTML:
- Title & date
- Executive Summary (2-3 paragraphs)
- Table of Contents
- Findings (organized by theme, with [n] citation links)
- Analysis (synthesis, implications, areas of debate)
- Conclusion (key takeaways, confidence assessment)
- References (numbered list with clickable URLs)
Build a clean HTML document. Keep the HTML concise — aim for clarity over length.
Use minimal embedded CSS (a few lines of style, not a full framework).

Requirements:
- Every factual claim must cite its source with [n] notation
- Be objective — present multiple viewpoints where sources disagree
- Distinguish well-supported conclusions from speculation
- Answer the original research questions from the brief
Report structure:
- Title & date
- Executive Summary (2-3 paragraphs)
- Key Findings (organized by theme, with [n] citation links)
- Analysis (synthesis, implications)
- Conclusion (key takeaways)
- References (numbered list with clickable URLs)

2. Save the HTML file:
   save_data(filename="report.html", data=<your_html>)
Requirements:
- Every factual claim must cite its source with [n] notation
- Be objective — present multiple viewpoints where sources disagree
- Answer the original research questions from the brief

3. Get the clickable link:
   serve_file_to_user(filename="report.html", label="Research Report")
Save the HTML:
save_data(filename="report.html", data="<html>...</html>")

Then get the clickable link:
serve_file_to_user(filename="report.html", label="Research Report")

If save_data fails, simplify and shorten the HTML, then retry.

**STEP 2 — Present the link to the user (text only, NO tool calls):**

Tell the user the report is ready and include the file:// URI from
serve_file_to_user so they can click it to open. Give a brief summary
of what the report covers. Ask if they have questions.
of what the report covers. Ask if they have questions or want to continue.

**STEP 3 — After the user responds:**
- Answer follow-up questions from the research material
- When the user is satisfied: set_output("delivery_status", "completed")
- Answer any follow-up questions from the research material
- When the user is ready to move on, ask what they'd like to do next:
  - Research a new topic?
  - Dig deeper into the current topic?
- Then call set_output:
  - set_output("delivery_status", "completed")
  - set_output("next_action", "new_topic") — if they want a new topic
  - set_output("next_action", "more_research") — if they want deeper research
""",
    tools=["save_data", "serve_file_to_user", "load_data", "list_data_files"],
    tools=[
        "save_data",
        "append_data",
        "edit_data",
        "serve_file_to_user",
        "load_data",
        "list_data_files",
    ],
)

__all__ = [

@@ -141,6 +141,12 @@ for f in ~/.zshrc ~/.bashrc ~/.profile; do [ -f "$f" ] && grep -q 'HIVE_CREDENTI
- **In shell config but NOT in current session** — run `source ~/.zshrc` (or `~/.bashrc`) first, then proceed
- **Not set anywhere** — `EncryptedFileStorage` will auto-generate one. After storing, tell the user to persist it: `export HIVE_CREDENTIAL_KEY="{generated_key}"` in their shell profile (a sketch of this flow follows below)

> **⚠️ IMPORTANT: After adding `HIVE_CREDENTIAL_KEY` to the user's shell config, always display:**
> ```
> ⚠️ Environment variables were added to your shell config.
> Open a NEW TERMINAL for them to take effect outside this session.
> ```

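A shell sketch of that check-and-persist flow. The config file names and the `{generated_key}` placeholder come from the steps above; everything else is illustrative:

```bash
# Already exported in this session?
if [ -z "$HIVE_CREDENTIAL_KEY" ]; then
  # Defined in a shell config but not sourced yet? Tell the user to source it.
  for f in ~/.zshrc ~/.bashrc ~/.profile; do
    [ -f "$f" ] && grep -q 'HIVE_CREDENTIAL_KEY' "$f" && echo "found in $f; run: source $f"
  done
fi

# Not set anywhere: persist the auto-generated key, then remind the user to
# open a NEW TERMINAL for it to take effect outside this session.
echo 'export HIVE_CREDENTIAL_KEY="{generated_key}"' >> ~/.zshrc
```
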
#### Option 1: Aden Platform (OAuth)

This is the recommended flow for supported integrations (HubSpot, etc.).

@@ -202,6 +208,12 @@ if success:
    print(f"Run: {source_cmd}")
```

> **⚠️ IMPORTANT: After adding `ADEN_API_KEY` to the user's shell config, always display:**
> ```
> ⚠️ Environment variables were added to your shell config.
> Open a NEW TERMINAL for them to take effect outside this session.
> ```

Also save to `~/.hive/configuration.json` for the framework:

```python

@@ -460,9 +472,14 @@ result: HealthCheckResult = check_credential_health("hubspot", token_value)
The local encrypted store requires `HIVE_CREDENTIAL_KEY` to encrypt/decrypt credentials.

- If the user doesn't have one, `EncryptedFileStorage` will auto-generate one and log it
- The user MUST persist this key (e.g., in `~/.bashrc` or a secrets manager)
- The user MUST persist this key (e.g., in `~/.bashrc`/`~/.zshrc` or a secrets manager)
- Without this key, stored credentials cannot be decrypted
- This is the ONLY secret that should live in `~/.bashrc` or environment config

**Shell config rule:** Only TWO keys belong in shell config (`~/.zshrc`/`~/.bashrc`):
- `HIVE_CREDENTIAL_KEY` — encryption key for the credential store
- `ADEN_API_KEY` — Aden platform auth key (needed before the store can sync)

All other API keys (Brave, Google, HubSpot, etc.) must go in the encrypted store only. **Never offer to add them to shell config.**

If `HIVE_CREDENTIAL_KEY` is not set:

@@ -475,6 +492,7 @@ If `HIVE_CREDENTIAL_KEY` is not set:
- **NEVER** log, print, or echo credential values in tool output
- **NEVER** store credentials in plaintext files, git-tracked files, or agent configs
- **NEVER** hardcode credentials in source code
- **NEVER** offer to save API keys to shell config (`~/.zshrc`/`~/.bashrc`) — the **only** keys that belong in shell config are `HIVE_CREDENTIAL_KEY` and `ADEN_API_KEY`. All other credentials (Brave, Google, HubSpot, GitHub, Resend, etc.) go in the encrypted store only.
- **ALWAYS** use `SecretStr` from Pydantic when handling credential values in Python
- **ALWAYS** use the local encrypted store (`~/.hive/credentials`) for persistence
- **ALWAYS** run health checks before storing credentials (when possible)

@@ -490,7 +508,7 @@ All credential specs are defined in `tools/src/aden_tools/credentials/`:
| `llm.py` | LLM Providers | `anthropic` | No |
| `search.py` | Search Tools | `brave_search`, `google_search`, `google_cse` | No |
| `email.py` | Email | `resend` | No |
| `integrations.py` | Integrations | `github`, `hubspot` | No / Yes |
| `integrations.py` | Integrations | `github`, `hubspot`, `google_calendar_oauth` | No / Yes |

**Note:** Additional LLM providers (Cerebras, Groq, OpenAI) are handled by LiteLLM via environment
variables (`CEREBRAS_API_KEY`, `GROQ_API_KEY`, `OPENAI_API_KEY`) but are not yet in CREDENTIAL_SPECS.

@@ -601,18 +619,22 @@ All credentials are now configured:
│ ✅ CREDENTIALS CONFIGURED                                                   │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  OPEN A NEW TERMINAL before running commands below.                         │
│  Environment variables were saved to your shell config but                  │
│  only take effect in new terminal sessions.                                 │
│                                                                             │
│  NEXT STEPS:                                                                │
│                                                                             │
│  1. RUN YOUR AGENT:                                                         │
│                                                                             │
│     PYTHONPATH=core:exports python -m research-agent tui                    │
│     hive tui                                                                │
│                                                                             │
│  2. IF YOU ENCOUNTER ISSUES, USE THE DEBUGGER:                              │
│                                                                             │
│     /hive-debugger                                                          │
│                                                                             │
│     The debugger analyzes runtime logs, identifies retry loops, tool        │
│     failures, stalled execution, and provides actionable fix suggestions.   │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
```

@@ -26,6 +26,17 @@ Use `/hive-debugger` when:

This skill works alongside agents running in TUI mode and provides supervisor-level insights into execution behavior.

### Forever-Alive Agent Awareness

Some agents use `terminal_nodes=[]` (the "forever-alive" pattern), meaning they loop indefinitely and never enter a "completed" execution state. For these agents:
- Sessions with status "in_progress" or "paused" are **normal**, not failures
- High step counts, long durations, and many node visits are expected behavior
- The agent stops only when the user explicitly exits — there is no graph-driven completion
- Debug focus should be on **quality of individual node visits and iterations**, not whether the session reached a terminal state
- Conversation memory accumulates across loops — watch for context overflow and stale data issues

**How to identify forever-alive agents:** Check `agent.py` or `agent.json` for `terminal_nodes=[]` (empty list). If empty, the agent is forever-alive (a quick check is sketched below).

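A quick way to run that check from the repo root. The `grep` line matches the `agent.py` shown elsewhere on this page; the `jq` path into `agent.json` is a guess at the schema, so treat it as illustrative:

```bash
# Forever-alive if terminal_nodes is the empty list.
grep -n 'terminal_nodes' exports/deep_research_agent/agent.py
jq '.terminal_nodes' exports/deep_research_agent/agent.json
```
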
---

## Prerequisites

@@ -47,7 +58,7 @@ Before using this skill, ensure:
**What to do:**

1. **Ask the developer which agent needs debugging:**
   - Get agent name (e.g., "twitter_outreach", "deep_research_agent")
   - Get agent name (e.g., "deep_research_agent", "deep_research_agent")
   - Confirm the agent exists in `exports/{agent_name}/`

2. **Determine agent working directory:**

@@ -66,7 +77,7 @@ Before using this skill, ensure:

4. **Store context for the debugging session:**
   - agent_name
   - agent_work_dir (e.g., `/home/user/.hive/twitter_outreach`)
   - agent_work_dir (e.g., `/home/user/.hive/deep_research_agent`)
   - goal_id
   - success_criteria
   - constraints

@@ -74,19 +85,19 @@ Before using this skill, ensure:

**Example:**
```
Developer: "My twitter_outreach agent keeps failing"
Developer: "My deep_research_agent agent keeps failing"

You: "I'll help debug the twitter_outreach agent. Let me gather context..."
You: "I'll help debug the deep_research_agent agent. Let me gather context..."

[Read exports/twitter_outreach/agent.json]
[Read exports/deep_research_agent/agent.json]

Context gathered:
- Agent: twitter_outreach
- Goal: twitter-outreach-multi-loop
- Working Directory: /home/user/.hive/twitter_outreach
- Success Criteria: ["Successfully send 5 personalized outreach messages"]
- Constraints: ["Must verify handle exists", "Must personalize message"]
- Nodes: ["intake-collector", "profile-analyzer", "message-composer", "outreach-sender"]
- Agent: deep_research_agent
- Goal: deep-research
- Working Directory: /home/user/.hive/deep_research_agent
- Success Criteria: ["Produce a comprehensive research report with cited sources"]
- Constraints: ["Must cite all sources", "Must cover multiple perspectives"]
- Nodes: ["intake", "research", "analysis", "report-writer"]
```

---
@@ -142,6 +153,7 @@ Store the selected mode for the session.
   - Check `attention_summary.categories` for issue types
   - Note the `run_id` of problematic sessions
   - Check `status` field: "degraded", "failure", "in_progress"
   - **For forever-alive agents:** Sessions with status "in_progress" or "paused" are normal — these agents never reach "completed". Only flag sessions with `needs_attention: true` or actual error indicators (tool failures, retry loops, missing outputs). High step counts alone do not indicate a problem.

3. **Attention flag triggers to understand:**
   From runtime_logger.py, runs are flagged when:

@@ -199,13 +211,20 @@ Which run would you like to investigate?
| **Tool Errors** | `tool_error_count > 0`, `attention_reasons` contains "tool_failures" | Tool calls failed (API errors, timeouts, auth issues) |
| **Retry Loops** | `retry_count > 3`, `verdict_counts.RETRY > 5` | Judge repeatedly rejecting outputs |
| **Guard Failures** | `guard_reject_count > 0` | Output validation failed (wrong types, missing keys) |
| **Stalled Execution** | `total_steps > 20`, `verdict_counts.CONTINUE > 10` | EventLoopNode not making progress |
| **Stalled Execution** | `total_steps > 20`, `verdict_counts.CONTINUE > 10` | EventLoopNode not making progress. **Caveat:** Forever-alive agents may legitimately have high step counts — check if agent is blocked at a client-facing node (normal) vs genuinely stuck in a loop |
| **High Latency** | `latency_ms > 60000`, `avg_step_latency > 5000` | Slow tool calls or LLM responses |
| **Client-Facing Issues** | `client_input_requested` but no `user_input_received` | Premature set_output before user input |
| **Edge Routing Errors** | `exit_status == "no_valid_edge"`, `attention_reasons` contains "routing_issue" | No edges match current state |
| **Memory/Context Issues** | `tokens_used > 100000`, `context_overflow_count > 0` | Conversation history too long |
| **Constraint Violations** | Compare output against goal constraints | Agent violated goal-level rules |

**Forever-Alive Agent Caveat:** If the agent uses `terminal_nodes=[]`, sessions will never reach "completed" status. This is by design. When debugging these agents, focus on:
- Whether individual node visits succeed (not whether the graph "finishes")
- Quality of each loop iteration — are outputs improving or degrading across loops?
- Whether client-facing nodes are correctly blocking for user input
- Memory accumulation issues: stale data from previous loops, context overflow across many iterations
- Conversation compaction behavior: is the conversation growing unbounded?

3. **Analyze each flagged node:**
   - Node ID and name
   - Exit status

@@ -224,7 +243,7 @@ Which run would you like to investigate?
```
Diagnosis for session_20260206_115718_e22339c5:

Problem Node: intake-collector
Problem Node: research
├─ Exit Status: escalate
├─ Retry Count: 5 (HIGH)
├─ Verdict Counts: {RETRY: 5, ESCALATE: 1}

@@ -232,7 +251,7 @@ Problem Node: intake-collector
├─ Total Steps: 8
└─ Categories: Missing Outputs + Retry Loops

Root Issue: The intake-collector node is stuck in a retry loop because it's not setting required outputs.
Root Issue: The research node is stuck in a retry loop because it's not setting required outputs.
```

---
@@ -293,25 +312,25 @@ Root Issue: The intake-collector node is stuck in a retry loop because it's not

**Example Output:**
```
Root Cause Analysis for intake-collector:
Root Cause Analysis for research:

Step-by-step breakdown:

Step 3:
- Tool Call: web_search(query="@RomuloNevesOf")
- Result: Found Twitter profile information
- Tool Call: web_search(query="latest AI regulations 2026")
- Result: Found relevant articles and sources
- Verdict: RETRY
- Feedback: "Missing required output 'twitter_handles'. You found the handle but didn't call set_output."
- Feedback: "Missing required output 'research_findings'. You found sources but didn't call set_output."

Step 4:
- Tool Call: web_search(query="@RomuloNevesOf twitter")
- Result: Found additional Twitter information
- Tool Call: web_search(query="AI regulation policy 2026")
- Result: Found additional policy information
- Verdict: RETRY
- Feedback: "Still missing 'twitter_handles'. Use set_output to save your findings."
- Feedback: "Still missing 'research_findings'. Use set_output to save your findings."

Steps 5-7: Similar pattern continues...

ROOT CAUSE: The node is successfully finding Twitter handles via web_search, but the LLM is not calling set_output to save the results. It keeps searching for more information instead of completing the task.
ROOT CAUSE: The node is successfully finding research sources via web_search, but the LLM is not calling set_output to save the results. It keeps searching for more information instead of completing the task.
```

---
@@ -562,15 +581,33 @@ PYTHONPATH=core:exports python -m {agent_name} --tui

### Find Available Checkpoints:

```bash
# In TUI:
/sessions {session_id}
Use MCP tools to programmatically find and inspect checkpoints:

# This shows all checkpoints with timestamps:
Available Checkpoints: (3)
1. cp_node_complete_intake_143030
2. cp_node_complete_research_143115
3. cp_pause_research_143130
```
# List all sessions to find the failed one
list_agent_sessions(agent_work_dir="~/.hive/agents/{agent_name}", status="failed")

# Inspect session state
get_agent_session_state(agent_work_dir="~/.hive/agents/{agent_name}", session_id="{session_id}")

# Find clean checkpoints to resume from
list_agent_checkpoints(agent_work_dir="~/.hive/agents/{agent_name}", session_id="{session_id}", is_clean="true")

# Compare checkpoints to understand what changed
compare_agent_checkpoints(
    agent_work_dir="~/.hive/agents/{agent_name}",
    session_id="{session_id}",
    checkpoint_id_before="cp_node_complete_intake_143030",
    checkpoint_id_after="cp_node_complete_research_143115"
)

# Inspect memory at a specific checkpoint
get_agent_checkpoint(agent_work_dir="~/.hive/agents/{agent_name}", session_id="{session_id}", checkpoint_id="cp_node_complete_intake_143030")
```

Or in TUI:
```bash
/sessions {session_id}
```

**Verification:**
@@ -653,7 +690,7 @@ Available Checkpoints: (3)

**Example interaction:**
```
Developer: "I applied the fix to intake-collector. How do I verify it works?"
Developer: "I applied the fix to research. How do I verify it works?"

You: "Great! Let's verify the fix with these steps:

@@ -665,11 +702,11 @@ You: "Great! Let's verify the fix with these steps:
   [Use query_runtime_logs to check for attention flags]

3. Verify the specific node:
   [Use query_runtime_log_details for intake-collector]
   [Use query_runtime_log_details for research]

Expected results:
- No 'needs_attention' flags
- intake-collector shows exit_status='success'
- research shows exit_status='success'
- retry_count should be 0

Let me know when you've run it and I'll help check the logs!"

@@ -687,7 +724,7 @@ Let me know when you've run it and I'll help check the logs!"
- **Example:**
  ```
  query_runtime_logs(
      agent_work_dir="/home/user/.hive/twitter_outreach",
      agent_work_dir="/home/user/.hive/deep_research_agent",
      status="needs_attention",
      limit=20
  )

@@ -699,7 +736,7 @@ Let me know when you've run it and I'll help check the logs!"
- **Example:**
  ```
  query_runtime_log_details(
      agent_work_dir="/home/user/.hive/twitter_outreach",
      agent_work_dir="/home/user/.hive/deep_research_agent",
      run_id="session_20260206_115718_e22339c5",
      needs_attention_only=True
  )
@@ -711,9 +748,83 @@ Let me know when you've run it and I'll help check the logs!"
- **Example:**
  ```
  query_runtime_log_raw(
      agent_work_dir="/home/user/.hive/twitter_outreach",
      agent_work_dir="/home/user/.hive/deep_research_agent",
      run_id="session_20260206_115718_e22339c5",
      node_id="intake-collector"
      node_id="research"
  )
  ```

### Session & Checkpoint Tools

**list_agent_sessions** - Browse sessions with filtering
- **When to use:** Finding resumable sessions, identifying failed sessions, Stage 3 triage
- **Returns:** Session list with status, timestamps, is_resumable, current_node, quality
- **Example:**
  ```
  list_agent_sessions(
      agent_work_dir="/home/user/.hive/agents/twitter_outreach",
      status="failed",
      limit=10
  )
  ```

**get_agent_session_state** - Load full session state (excludes memory values)
- **When to use:** Inspecting session progress, checking is_resumable, examining path
- **Returns:** Full state with memory_keys/memory_size instead of memory values
- **Example:**
  ```
  get_agent_session_state(
      agent_work_dir="/home/user/.hive/agents/twitter_outreach",
      session_id="session_20260208_143022_abc12345"
  )
  ```

**get_agent_session_memory** - Get memory contents from a session
- **When to use:** Stage 5 root cause analysis, inspecting produced data
- **Returns:** All memory keys+values, or a single key's value
- **Example:**
  ```
  get_agent_session_memory(
      agent_work_dir="/home/user/.hive/agents/twitter_outreach",
      session_id="session_20260208_143022_abc12345",
      key="twitter_handles"
  )
  ```

**list_agent_checkpoints** - List checkpoints for a session
- **When to use:** Stage 6 recovery, finding clean checkpoints to resume from
- **Returns:** Checkpoint summaries with type, node, clean status
- **Example:**
  ```
  list_agent_checkpoints(
      agent_work_dir="/home/user/.hive/agents/twitter_outreach",
      session_id="session_20260208_143022_abc12345",
      is_clean="true"
  )
  ```

**get_agent_checkpoint** - Load a specific checkpoint with full state
- **When to use:** Inspecting exact state at a checkpoint, comparing to current state
- **Returns:** Full checkpoint: memory snapshot, execution path, metrics
- **Example:**
  ```
  get_agent_checkpoint(
      agent_work_dir="/home/user/.hive/agents/twitter_outreach",
      session_id="session_20260208_143022_abc12345",
      checkpoint_id="cp_node_complete_intake_143030"
  )
  ```

**compare_agent_checkpoints** - Diff memory between two checkpoints
- **When to use:** Understanding data flow, finding where state diverged
- **Returns:** Memory diff (added/removed/changed keys) + execution path diff
- **Example:**
  ```
  compare_agent_checkpoints(
      agent_work_dir="/home/user/.hive/agents/twitter_outreach",
      session_id="session_20260208_143022_abc12345",
      checkpoint_id_before="cp_node_complete_intake_143030",
      checkpoint_id_after="cp_node_complete_research_143115"
  )
  ```

@@ -739,27 +850,37 @@ Loop every 10 seconds:
2. If found: Alert and drill into L2
```

**Pattern 4: Session State + Checkpoint Recovery**
```
1. list_agent_sessions: Find failed/paused sessions
2. get_agent_session_state: Check is_resumable, see execution path
3. get_agent_session_memory: Inspect what data was produced
4. list_agent_checkpoints: Find clean checkpoints before failure
5. compare_agent_checkpoints: Understand what changed between checkpoints
6. Recommend resume command with specific checkpoint
```

---

## Complete Example Walkthrough

### Scenario: twitter_outreach Agent Failing
### Scenario: deep_research_agent Agent Failing

**Developer says:** "My twitter_outreach agent keeps failing. It's not finding Twitter handles correctly."
**Developer says:** "My deep_research_agent keeps failing. It's not saving research findings correctly."

#### Stage 1: Setup
```
You: "I'll help debug the twitter_outreach agent. Let me gather context..."
You: "I'll help debug the deep_research_agent agent. Let me gather context..."

[Read exports/twitter_outreach/agent.json]
[Read exports/deep_research_agent/agent.json]

Context:
- Agent: twitter_outreach
- Goal: twitter-outreach-multi-loop
- Working Dir: ~/.hive/agents/twitter_outreach
- Agent: deep_research_agent
- Goal: deep-research
- Working Dir: ~/.hive/agents/deep_research_agent
- Success Criteria: ["Successfully send 5 personalized outreach messages"]
- Constraints: ["Must verify handle exists", "Must personalize message"]
- Nodes: intake-collector, profile-analyzer, message-composer, outreach-sender
- Nodes: research, profile-analyzer, message-composer, outreach-sender
```

#### Stage 2: Mode Selection
@@ -791,31 +912,31 @@ You: "Analyzing node-level details..."
|
||||
[Call query_runtime_log_details for the session]
|
||||
|
||||
Results:
|
||||
- Problem Node: intake-collector
|
||||
- Problem Node: research
|
||||
- Exit Status: escalate
|
||||
- Retry Count: 5
|
||||
- Verdict Counts: {RETRY: 5, ESCALATE: 1}
|
||||
- Attention Reasons: ["high_retry_count", "missing_outputs"]
|
||||
- Total Steps: 8
|
||||
|
||||
Diagnosis: intake-collector has retry loop + missing outputs issue
|
||||
Diagnosis: research has retry loop + missing outputs issue
|
||||
Category: Missing Outputs + Retry Loops
|
||||
```
|
||||
|
||||
#### Stage 5: Root Cause (L3)
|
||||
```
|
||||
You: "Drilling into detailed logs for intake-collector..."
|
||||
You: "Drilling into detailed logs for research..."
|
||||
|
||||
[Call query_runtime_log_raw for node intake-collector]
|
||||
[Call query_runtime_log_raw for node research]
|
||||
|
||||
Evidence:
|
||||
Step 3: web_search("@RomuloNevesOf") → Found profile
|
||||
Step 3: web_search("latest AI regulations 2026") → Found articles
|
||||
Verdict: RETRY
|
||||
Feedback: "Missing required output 'twitter_handles'. You found the handle but didn't call set_output."
|
||||
Feedback: "Missing required output 'research_findings'. You found sources but didn't call set_output."
|
||||
|
||||
Step 4: web_search("@RomuloNevesOf twitter") → Found more info
|
||||
Step 4: web_search("AI regulation policy 2026") → Found more info
|
||||
Verdict: RETRY
|
||||
Feedback: "Still missing 'twitter_handles'. Use set_output to save your findings."
|
||||
Feedback: "Still missing 'research_findings'. Use set_output to save your findings."
|
||||
|
||||
Steps 5-7: Same pattern repeats...
|
||||
|
||||
@@ -826,37 +947,37 @@ ROOT CAUSE: Node successfully finds handles but LLM doesn't call set_output. It
|
||||
```
|
||||
You: "Here's the fix for this issue:
|
||||
|
||||
## Issue: intake-collector Not Setting Outputs
|
||||
## Issue: research Node Not Setting Outputs
|
||||
|
||||
**Root Cause:** Node successfully finds Twitter handles but doesn't call set_output
|
||||
**Root Cause:** Node successfully finds research sources but doesn't call set_output
|
||||
|
||||
**Category:** Missing Outputs
|
||||
|
||||
**File to edit:** exports/twitter_outreach/nodes/intake_collector.py
|
||||
**File to edit:** exports/deep_research_agent/nodes/research.py
|
||||
|
||||
**Changes:**
|
||||
Update the system_prompt to emphasize the output requirement:
|
||||
|
||||
```python
|
||||
system_prompt = """
|
||||
You are the Intake Collector for the Twitter outreach agent.
|
||||
You are the Research node for the deep research agent.
|
||||
|
||||
Your task:
|
||||
1. Extract Twitter handle from user input (URL or @mention)
|
||||
2. Verify it's a valid Twitter handle format
|
||||
3. **CRITICAL: Call set_output(twitter_handles=[handle]) immediately after finding valid handle**
|
||||
1. Search for relevant sources on the research topic
|
||||
2. Extract key findings from each source
|
||||
3. **CRITICAL: Call set_output(research_findings=[...]) immediately after gathering sufficient findings**
|
||||
|
||||
IMPORTANT: You MUST call set_output even if you found the handle on the first try.
|
||||
Do not continue searching once you have a valid handle - set output and finish.
|
||||
IMPORTANT: You MUST call set_output even if you found findings on the first search.
|
||||
Do not continue searching indefinitely - set output and finish once you have enough data.
|
||||
|
||||
Output format:
|
||||
{
|
||||
"twitter_handles": ["@handle1", "@handle2", ...]
|
||||
"research_findings": [{"source": "...", "summary": "...", "url": "..."}, ...]
|
||||
}
|
||||
"""
|
||||
```
|
||||
|
||||
**Why this works:** The explicit instruction to call set_output immediately prevents the LLM from continuing to search after finding the handle.
|
||||
**Why this works:** The explicit instruction to call set_output immediately prevents the LLM from continuing to search after finding sufficient sources.
|
||||
```
|
||||
|
||||
#### Stage 7: Verification
|
||||
@@ -880,13 +1001,13 @@ You: "Checking now..."
|
||||
[Call query_runtime_log_details for the new run]
|
||||
|
||||
Results:
|
||||
- Node: intake-collector
|
||||
- Node: research
|
||||
- Exit Status: success ✓
|
||||
- Retry Count: 0 ✓
|
||||
- Total Steps: 2 ✓
|
||||
- No attention flags ✓
|
||||
|
||||
"Perfect! The fix worked. The intake-collector now:
|
||||
"Perfect! The fix worked. The research now:
|
||||
- Completes successfully on first try
|
||||
- No retries needed
|
||||
- Calls set_output properly
|
||||
@@ -913,6 +1034,9 @@ Your agent should now work correctly!"
|
||||
3. **Don't ignore edge conditions** - Missing edges cause routing failures
|
||||
4. **Don't overlook judge configuration** - Mismatched expectations cause retry loops
|
||||
5. **Don't forget nullable_output_keys** - Optional inputs need explicit marking
|
||||
6. **Don't diagnose "in_progress" as a failure for forever-alive agents** - Agents with `terminal_nodes=[]` are designed to never enter the "completed" state. This is intentional. Focus on the quality of individual node visits, not on session completion status (see the sketch after this list)
|
||||
7. **Don't ignore conversation memory issues in long-running sessions** - In continuous conversation mode, history grows across node transitions and loop iterations. Watch for context overflow (tokens_used > 100K), stale data from previous loops affecting edge conditions, and compaction failures that cause the LLM to lose important context
|
||||
8. **Don't confuse "waiting for user" with "stalled"** - Client-facing nodes in forever-alive agents block for user input by design. A session paused at a client-facing node is working correctly, not stalled
|
||||
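A minimal sketch of the forever-alive rules behind points 6 and 8 (the helper is hypothetical, not a framework API):

```python
# Hedged sketch: classify a session the way points 6 and 8 describe.
def classify_session(status: str, terminal_nodes: list[str], at_client_facing_node: bool) -> str:
    if status == "in_progress" and not terminal_nodes:
        return "healthy: forever-alive agents never reach 'completed' by design"
    if at_client_facing_node:
        return "healthy: blocked waiting for user input, not stalled"
    return status

print(classify_session("in_progress", terminal_nodes=[], at_client_facing_node=False))
```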
|
||||
---
|
||||
|
||||
|
||||
+719
-973
File diff suppressed because it is too large
@@ -1,351 +1,333 @@
|
||||
# Example: Testing a YouTube Research Agent
|
||||
# Example: Iterative Testing of a Research Agent
|
||||
|
||||
This example walks through testing a YouTube research agent that finds relevant videos based on a topic.
|
||||
This example walks through the full iterative test loop for a research agent that searches the web, reviews findings, and produces a cited report.
|
||||
|
||||
## Prerequisites
|
||||
## Agent Structure
|
||||
|
||||
- Agent built with hive-create skill at `exports/youtube-research/`
|
||||
- Goal defined with success criteria and constraints
|
||||
|
||||
## Step 1: Load the Goal
|
||||
|
||||
First, load the goal that was defined during the Goal stage:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "youtube-research",
|
||||
"name": "YouTube Research Agent",
|
||||
"description": "Find relevant YouTube videos on a given topic",
|
||||
"success_criteria": [
|
||||
{
|
||||
"id": "find_videos",
|
||||
"description": "Find 3-5 relevant videos",
|
||||
"metric": "video_count",
|
||||
"target": "3-5",
|
||||
"weight": 1.0
|
||||
},
|
||||
{
|
||||
"id": "relevance",
|
||||
"description": "Videos must be relevant to the topic",
|
||||
"metric": "relevance_score",
|
||||
"target": ">0.8",
|
||||
"weight": 0.8
|
||||
}
|
||||
],
|
||||
"constraints": [
|
||||
{
|
||||
"id": "api_limits",
|
||||
"description": "Must not exceed YouTube API rate limits",
|
||||
"constraint_type": "hard",
|
||||
"category": "technical"
|
||||
},
|
||||
{
|
||||
"id": "content_safety",
|
||||
"description": "Must filter out inappropriate content",
|
||||
"constraint_type": "hard",
|
||||
"category": "safety"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
```
exports/deep_research_agent/
|
||||
├── agent.py # Goal + graph: intake → research → review → report
|
||||
├── nodes/__init__.py # Node definitions (system_prompt, input/output keys)
|
||||
├── config.py # Model config
|
||||
├── mcp_servers.json # Tools: web_search, web_scrape
|
||||
└── tests/ # Test files (we'll create these)
|
||||
```
|
||||
|
||||
## Step 2: Get Constraint Test Guidelines
|
||||
**Goal:** "Rigorous Interactive Research" — find 5+ diverse sources, cite every claim, produce a complete report.
|
||||
|
||||
During the Goal stage (or early Eval), get test guidelines for constraints:
|
||||
---
|
||||
|
||||
## Phase 1: Generate Tests
|
||||
|
||||
### Read the goal
|
||||
|
||||
```python
|
||||
result = generate_constraint_tests(
|
||||
goal_id="youtube-research",
|
||||
goal_json='<goal JSON above>',
|
||||
agent_path="exports/youtube-research"
|
||||
)
|
||||
Read(file_path="exports/deep_research_agent/agent.py")
|
||||
# Extract: goal_id="rigorous-interactive-research"
|
||||
# success_criteria: source-diversity (>=5), citation-coverage (100%), report-completeness (90%)
|
||||
# constraints: no-hallucination, source-attribution
|
||||
```
|
||||
|
||||
**The result contains guidelines (not generated tests); a sketch of consuming them follows this list:**
|
||||
- `output_file`: Where to write tests
|
||||
- `file_header`: Imports and fixtures to use
|
||||
- `test_template`: Format for test functions
|
||||
- `constraints_formatted`: The constraints to test
|
||||
- `test_guidelines`: Rules for writing tests
|
||||
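A minimal sketch of consuming these fields (key names as listed above; exact value shapes are assumptions):

```python
# result = generate_constraint_tests(...)  # the call shown earlier
test_path = result["output_file"]     # where the tests should be written
header = result["file_header"]        # imports + fixtures to prepend
template = result["test_template"]    # skeleton for each test function
rules = result["test_guidelines"]     # rules the hand-written tests must follow

content = header + "\n# tests written by hand, following `template` and `rules`\n"
# Write(file_path=test_path, content=content)  # as done in the next step
```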
|
||||
## Step 3: Write Constraint Tests
|
||||
|
||||
Using the guidelines, write tests directly with the Write tool:
|
||||
|
||||
```python
|
||||
# Write constraint tests using the provided file_header and guidelines
|
||||
Write(
|
||||
file_path="exports/youtube-research/tests/test_constraints.py",
|
||||
content='''
|
||||
"""Constraint tests for youtube-research agent."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from exports.youtube_research import default_agent
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not os.environ.get("ANTHROPIC_API_KEY") and not os.environ.get("MOCK_MODE"),
|
||||
reason="API key required for real testing."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_constraint_api_limits_respected():
|
||||
"""Verify API rate limits are not exceeded."""
|
||||
import time
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
|
||||
for i in range(10):
|
||||
result = await default_agent.run({"topic": f"test_{i}"}, mock_mode=mock_mode)
|
||||
time.sleep(0.1)
|
||||
|
||||
# Should complete without rate limit errors
|
||||
assert "rate limit" not in str(result).lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_constraint_content_safety_filter():
|
||||
"""Verify inappropriate content is filtered."""
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
result = await default_agent.run({"topic": "general topic"}, mock_mode=mock_mode)
|
||||
|
||||
for video in result.videos:
|
||||
assert video.safe_for_work is True
|
||||
assert video.age_restricted is False
|
||||
'''
|
||||
)
|
||||
```
|
||||
|
||||
## Step 4: Get Success Criteria Test Guidelines
|
||||
|
||||
After the agent is built, get success criteria test guidelines:
|
||||
### Get test guidelines
|
||||
|
||||
```python
|
||||
result = generate_success_tests(
|
||||
goal_id="youtube-research",
|
||||
goal_json='<goal JSON>',
|
||||
node_names="search_node,filter_node,rank_node,format_node",
|
||||
tool_names="youtube_search,video_details,channel_info",
|
||||
agent_path="exports/youtube-research"
|
||||
goal_id="rigorous-interactive-research",
|
||||
goal_json='{"id": "rigorous-interactive-research", "success_criteria": [{"id": "source-diversity", "description": "Use multiple diverse sources", "target": ">=5"}, {"id": "citation-coverage", "description": "Every claim cites its source", "target": "100%"}, {"id": "report-completeness", "description": "Report answers the research questions", "target": "90%"}]}',
|
||||
node_names="intake,research,review,report",
|
||||
tool_names="web_search,web_scrape",
|
||||
agent_path="exports/deep_research_agent"
|
||||
)
|
||||
```
|
||||
|
||||
## Step 5: Write Success Criteria Tests
|
||||
|
||||
Using the guidelines, write success criteria tests:
|
||||
### Write tests
|
||||
|
||||
```python
|
||||
Write(
|
||||
file_path="exports/youtube-research/tests/test_success_criteria.py",
|
||||
content='''
|
||||
"""Success criteria tests for youtube-research agent."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from exports.youtube_research import default_agent
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not os.environ.get("ANTHROPIC_API_KEY") and not os.environ.get("MOCK_MODE"),
|
||||
reason="API key required for real testing."
|
||||
)
|
||||
|
||||
file_path="exports/deep_research_agent/tests/test_success_criteria.py",
|
||||
content=result["file_header"] + '''
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_videos_happy_path():
|
||||
"""Test finding videos for a common topic."""
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
result = await default_agent.run({"topic": "machine learning"}, mock_mode=mock_mode)
|
||||
|
||||
assert result.success
|
||||
assert 3 <= len(result.videos) <= 5
|
||||
assert all(v.title for v in result.videos)
|
||||
assert all(v.video_id for v in result.videos)
|
||||
|
||||
async def test_success_source_diversity(runner, auto_responder, mock_mode):
|
||||
"""At least 5 diverse sources are found."""
|
||||
await auto_responder.start()
|
||||
try:
|
||||
result = await runner.run({"query": "impact of remote work on productivity"})
|
||||
finally:
|
||||
await auto_responder.stop()
|
||||
assert result.success, f"Agent failed: {result.error}"
|
||||
output = result.output or {}
|
||||
sources = output.get("sources", [])
|
||||
if isinstance(sources, list):
|
||||
assert len(sources) >= 5, f"Expected >= 5 sources, got {len(sources)}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_videos_minimum_boundary():
|
||||
"""Test at minimum threshold (3 videos)."""
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
result = await default_agent.run({"topic": "niche topic xyz"}, mock_mode=mock_mode)
|
||||
|
||||
assert len(result.videos) >= 3
|
||||
|
||||
async def test_success_citation_coverage(runner, auto_responder, mock_mode):
|
||||
"""Every factual claim in the report cites its source."""
|
||||
await auto_responder.start()
|
||||
try:
|
||||
result = await runner.run({"query": "climate change effects on agriculture"})
|
||||
finally:
|
||||
await auto_responder.stop()
|
||||
assert result.success, f"Agent failed: {result.error}"
|
||||
output = result.output or {}
|
||||
report = output.get("report", "")
|
||||
# Check that report contains numbered references
|
||||
assert "[1]" in str(report) or "[source" in str(report).lower(), "Report lacks citations"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relevance_score_threshold():
|
||||
"""Test relevance scoring meets threshold."""
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
result = await default_agent.run({"topic": "python programming"}, mock_mode=mock_mode)
|
||||
|
||||
for video in result.videos:
|
||||
assert video.relevance_score > 0.8
|
||||
|
||||
async def test_success_report_completeness(runner, auto_responder, mock_mode):
|
||||
"""Report addresses the original research question."""
|
||||
query = "pros and cons of nuclear energy"
|
||||
await auto_responder.start()
|
||||
try:
|
||||
result = await runner.run({"query": query})
|
||||
finally:
|
||||
await auto_responder.stop()
|
||||
assert result.success, f"Agent failed: {result.error}"
|
||||
output = result.output or {}
|
||||
report = output.get("report", "")
|
||||
assert len(str(report)) > 200, f"Report too short: {len(str(report))} chars"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_videos_no_results_graceful():
|
||||
"""Test graceful handling of no results."""
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
result = await default_agent.run({"topic": "xyznonexistent123"}, mock_mode=mock_mode)
|
||||
async def test_empty_query_handling(runner, auto_responder, mock_mode):
|
||||
"""Agent handles empty input gracefully."""
|
||||
await auto_responder.start()
|
||||
try:
|
||||
result = await runner.run({"query": ""})
|
||||
finally:
|
||||
await auto_responder.stop()
|
||||
output = result.output or {}
|
||||
assert not result.success or output.get("error"), "Should handle empty query"
|
||||
|
||||
# Should not crash, return empty or message
|
||||
assert result.videos == [] or result.message
|
||||
@pytest.mark.asyncio
|
||||
async def test_feedback_loop_terminates(runner, auto_responder, mock_mode):
|
||||
"""Feedback loop between review and research terminates."""
|
||||
await auto_responder.start()
|
||||
try:
|
||||
result = await runner.run({"query": "quantum computing basics"})
|
||||
finally:
|
||||
await auto_responder.stop()
|
||||
visits = result.node_visit_counts or {}
|
||||
for node_id, count in visits.items():
|
||||
assert count <= 5, f"Node {node_id} visited {count} times"
|
||||
'''
|
||||
)
|
||||
```
|
||||
|
||||
## Step 6: Run All Tests
|
||||
---
|
||||
|
||||
Execute all tests:
|
||||
## Phase 2: First Execution
|
||||
|
||||
```python
|
||||
result = run_tests(
|
||||
goal_id="youtube-research",
|
||||
agent_path="exports/youtube-research",
|
||||
test_types='["all"]',
|
||||
parallel=4
|
||||
run_tests(
|
||||
goal_id="rigorous-interactive-research",
|
||||
agent_path="exports/deep_research_agent",
|
||||
fail_fast=True
|
||||
)
|
||||
```
|
||||
|
||||
**Results:**
|
||||
|
||||
**Result:**
|
||||
```json
|
||||
{
|
||||
"goal_id": "youtube-research",
|
||||
"overall_passed": false,
|
||||
"summary": {
|
||||
"total": 6,
|
||||
"passed": 5,
|
||||
"failed": 1,
|
||||
"pass_rate": "83.3%"
|
||||
},
|
||||
"duration_ms": 4521,
|
||||
"results": [
|
||||
{"test_id": "test_constraint_api_001", "passed": true, "duration_ms": 1234},
|
||||
{"test_id": "test_constraint_content_001", "passed": true, "duration_ms": 456},
|
||||
{"test_id": "test_success_001", "passed": true, "duration_ms": 789},
|
||||
{"test_id": "test_success_002", "passed": true, "duration_ms": 654},
|
||||
{"test_id": "test_success_003", "passed": true, "duration_ms": 543},
|
||||
{"test_id": "test_success_004", "passed": false, "duration_ms": 845,
|
||||
"error_category": "IMPLEMENTATION_ERROR",
|
||||
"error_message": "TypeError: 'NoneType' object has no attribute 'videos'"}
|
||||
]
|
||||
"overall_passed": false,
|
||||
"summary": {"total": 5, "passed": 3, "failed": 2, "pass_rate": "60.0%"},
|
||||
"failures": [
|
||||
{"test_name": "test_success_source_diversity", "details": "AssertionError: Expected >= 5 sources, got 2"},
|
||||
{"test_name": "test_success_citation_coverage", "details": "AssertionError: Report lacks citations"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Step 7: Debug the Failed Test
|
||||
---
|
||||
|
||||
## Phase 3: Analyze (Iteration 1)
|
||||
|
||||
### Debug the first failure
|
||||
|
||||
```python
|
||||
result = debug_test(
|
||||
goal_id="youtube-research",
|
||||
test_name="test_find_videos_no_results_graceful",
|
||||
agent_path="exports/youtube-research"
|
||||
debug_test(
|
||||
goal_id="rigorous-interactive-research",
|
||||
test_name="test_success_source_diversity",
|
||||
agent_path="exports/deep_research_agent"
|
||||
)
|
||||
# Category: ASSERTION_FAILURE — Expected >= 5 sources, got 2
|
||||
```
|
||||
|
||||
### Find the session and inspect memory
|
||||
|
||||
```python
|
||||
list_agent_sessions(
|
||||
agent_work_dir="~/.hive/agents/deep_research_agent",
|
||||
status="completed",
|
||||
limit=1
|
||||
)
|
||||
# → session_20260209_150000_abc12345
|
||||
|
||||
get_agent_session_memory(
|
||||
agent_work_dir="~/.hive/agents/deep_research_agent",
|
||||
session_id="session_20260209_150000_abc12345",
|
||||
key="research_results"
|
||||
)
|
||||
# → Only 2 sources found. LLM stopped searching after 2 queries.
|
||||
```
|
||||
|
||||
### Check LLM behavior in the research node
|
||||
|
||||
```python
|
||||
query_runtime_log_raw(
|
||||
agent_work_dir="~/.hive/agents/deep_research_agent",
|
||||
run_id="session_20260209_150000_abc12345",
|
||||
node_id="research"
|
||||
)
|
||||
# → LLM called web_search twice, got results, immediately called set_output.
|
||||
# → Prompt doesn't instruct it to find at least 5 sources.
|
||||
```
|
||||
|
||||
**Root cause:** The research node's system_prompt doesn't specify minimum source requirements.
|
||||
|
||||
---
|
||||
|
||||
## Phase 4: Fix (Iteration 1)
|
||||
|
||||
```python
|
||||
Read(file_path="exports/deep_research_agent/nodes/__init__.py")
|
||||
|
||||
# Fix the research node prompt
|
||||
Edit(
|
||||
file_path="exports/deep_research_agent/nodes/__init__.py",
|
||||
old_string='system_prompt="Search for information on the user\'s topic using web search."',
|
||||
new_string='system_prompt="Search for information on the user\'s topic using web search. You MUST find at least 5 diverse, authoritative sources. Use multiple different search queries with varied keywords. Do NOT call set_output until you have gathered at least 5 distinct sources from different domains."'
|
||||
)
|
||||
```
|
||||
|
||||
**Debug Output:**
|
||||
---
|
||||
|
||||
## Phase 5: Recover & Resume (Iteration 1)
|
||||
|
||||
The fix is to the `research` node. Since this was a `run_tests` execution (no checkpoints), we re-run from scratch:
|
||||
|
||||
```python
|
||||
run_tests(
|
||||
goal_id="rigorous-interactive-research",
|
||||
agent_path="exports/deep_research_agent",
|
||||
fail_fast=True
|
||||
)
|
||||
```
|
||||
|
||||
**Result:**
|
||||
```json
|
||||
{
|
||||
"test_id": "test_success_004",
|
||||
"test_name": "test_find_videos_no_results_graceful",
|
||||
"input": {"topic": "xyznonexistent123"},
|
||||
"expected": "Empty list or message",
|
||||
"actual": {"error": "TypeError: 'NoneType' object has no attribute 'videos'"},
|
||||
"passed": false,
|
||||
"error_message": "TypeError: 'NoneType' object has no attribute 'videos'",
|
||||
"error_category": "IMPLEMENTATION_ERROR",
|
||||
"stack_trace": "Traceback (most recent call last):\n File \"filter_node.py\", line 42\n for video in result.videos:\nTypeError: 'NoneType' object has no attribute 'videos'",
|
||||
"logs": [
|
||||
{"timestamp": "2026-01-20T10:00:01", "node": "search_node", "level": "INFO", "msg": "Searching for: xyznonexistent123"},
|
||||
{"timestamp": "2026-01-20T10:00:02", "node": "search_node", "level": "WARNING", "msg": "No results found"},
|
||||
{"timestamp": "2026-01-20T10:00:02", "node": "filter_node", "level": "ERROR", "msg": "NoneType error"}
|
||||
],
|
||||
"runtime_data": {
|
||||
"execution_path": ["start", "search_node", "filter_node"],
|
||||
"node_outputs": {
|
||||
"search_node": null
|
||||
}
|
||||
},
|
||||
"suggested_fix": "Add null check in filter_node before accessing .videos attribute",
|
||||
"iteration_guidance": {
|
||||
"stage": "Agent",
|
||||
"action": "Fix the code in nodes/edges",
|
||||
"restart_required": false,
|
||||
"description": "The goal is correct, but filter_node doesn't handle null results from search_node."
|
||||
}
|
||||
"overall_passed": false,
|
||||
"summary": {"total": 5, "passed": 4, "failed": 1, "pass_rate": "80.0%"},
|
||||
"failures": [
|
||||
{"test_name": "test_success_citation_coverage", "details": "AssertionError: Report lacks citations"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Step 8: Iterate Based on Category
|
||||
Source diversity now passes. Citation coverage still fails.
|
||||
|
||||
Since this is an **IMPLEMENTATION_ERROR**, we:
|
||||
---
|
||||
|
||||
1. **Don't restart** the Goal → Agent → Eval flow
|
||||
2. **Fix the agent** using hive-create skill:
|
||||
- Modify `filter_node` to handle null results
|
||||
3. **Re-run Eval** (tests only)
|
||||
|
||||
### Fix in hive-create:
|
||||
## Phase 3: Analyze (Iteration 2)
|
||||
|
||||
```python
|
||||
# Update the filter_node to handle null
|
||||
add_node(
|
||||
node_id="filter_node",
|
||||
name="Filter Node",
|
||||
description="Filter and rank videos",
|
||||
node_type="function",
|
||||
input_keys=["search_results"],
|
||||
output_keys=["filtered_videos"],
|
||||
system_prompt="""
|
||||
Filter videos by relevance.
|
||||
IMPORTANT: Handle case where search_results is None or empty.
|
||||
Return empty list if no results.
|
||||
"""
|
||||
debug_test(
|
||||
goal_id="rigorous-interactive-research",
|
||||
test_name="test_success_citation_coverage",
|
||||
agent_path="exports/deep_research_agent"
|
||||
)
|
||||
# Category: ASSERTION_FAILURE — Report lacks citations
|
||||
|
||||
# Check what the report node produced
|
||||
list_agent_sessions(
|
||||
agent_work_dir="~/.hive/agents/deep_research_agent",
|
||||
status="completed",
|
||||
limit=1
|
||||
)
|
||||
# → session_20260209_151500_def67890
|
||||
|
||||
get_agent_session_memory(
|
||||
agent_work_dir="~/.hive/agents/deep_research_agent",
|
||||
session_id="session_20260209_151500_def67890",
|
||||
key="report"
|
||||
)
|
||||
# → Report text exists but uses no numbered references.
|
||||
# → Sources are in memory but report node doesn't cite them.
|
||||
```
|
||||
|
||||
**Root cause:** The report node's prompt doesn't instruct the LLM to include numbered citations.
|
||||
|
||||
---
|
||||
|
||||
## Phase 4: Fix (Iteration 2)
|
||||
|
||||
```python
|
||||
Edit(
|
||||
file_path="exports/deep_research_agent/nodes/__init__.py",
|
||||
old_string='system_prompt="Write a comprehensive report based on the research findings."',
|
||||
new_string='system_prompt="Write a comprehensive report based on the research findings. You MUST include numbered citations [1], [2], etc. for every factual claim. At the end, include a References section listing all sources with their URLs. Every claim must be traceable to a specific source."'
|
||||
)
|
||||
```
|
||||
|
||||
### Re-export and re-test:
|
||||
---
|
||||
|
||||
## Phase 5: Resume (Iteration 2)
|
||||
|
||||
The fix is to the `report` node (the last node). To demonstrate checkpoint recovery, run via CLI:
|
||||
|
||||
```bash
|
||||
# Run via CLI to get checkpoints
|
||||
uv run hive run exports/deep_research_agent --input '{"query": "climate change effects"}'
|
||||
|
||||
# After it runs, find the clean checkpoint before report
|
||||
list_agent_checkpoints(
|
||||
agent_work_dir="~/.hive/agents/deep_research_agent",
|
||||
session_id="session_20260209_152000_ghi34567",
|
||||
is_clean="true"
|
||||
)
|
||||
# → cp_node_complete_review_152100 (after review, before report)
|
||||
|
||||
# Resume — skips intake, research, review entirely
|
||||
uv run hive run exports/deep_research_agent \
|
||||
--resume-session session_20260209_152000_ghi34567 \
|
||||
--checkpoint cp_node_complete_review_152100
|
||||
```
|
||||
|
||||
Only the `report` node re-runs with the fixed prompt, using research data from the checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Phase 6: Final Verification
|
||||
|
||||
```python
|
||||
# Re-export the fixed agent
|
||||
export_graph(path="exports/youtube-research")
|
||||
|
||||
# Re-run tests
|
||||
result = run_tests(
|
||||
goal_id="youtube-research",
|
||||
agent_path="exports/youtube-research",
|
||||
test_types='["all"]'
|
||||
run_tests(
|
||||
goal_id="rigorous-interactive-research",
|
||||
agent_path="exports/deep_research_agent"
|
||||
)
|
||||
```
|
||||
|
||||
**Updated Results:**
|
||||
|
||||
**Result:**
|
||||
```json
|
||||
{
|
||||
"goal_id": "youtube-research",
|
||||
"overall_passed": true,
|
||||
"summary": {
|
||||
"total": 6,
|
||||
"passed": 6,
|
||||
"failed": 0,
|
||||
"pass_rate": "100.0%"
|
||||
}
|
||||
"overall_passed": true,
|
||||
"summary": {"total": 5, "passed": 5, "failed": 0, "pass_rate": "100.0%"}
|
||||
}
|
||||
```
|
||||
|
||||
All tests pass.
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
1. **Got guidelines** for constraint tests during Goal stage
|
||||
2. **Wrote** constraint tests using Write tool
|
||||
3. **Got guidelines** for success criteria tests during Eval stage
|
||||
4. **Wrote** success criteria tests using Write tool
|
||||
5. **Ran** tests in parallel
|
||||
6. **Debugged** the one failure
|
||||
7. **Categorized** as IMPLEMENTATION_ERROR
|
||||
8. **Fixed** the agent (not the goal)
|
||||
9. **Re-ran** Eval only (didn't restart full flow)
|
||||
10. **Passed** all tests
|
||||
| Iteration | Failure | Root Cause | Fix | Recovery |
|
||||
|-----------|---------|------------|-----|----------|
|
||||
| 1 | Source diversity (2 < 5) | Research prompt too vague | Added "at least 5 sources" to prompt | Re-run (no checkpoints) |
|
||||
| 2 | No citations in report | Report prompt lacks citation instructions | Added citation requirements | Checkpoint resume (skipped 3 nodes) |
|
||||
|
||||
The agent is now validated and ready for production use.
|
||||
**Key takeaways:**
|
||||
- Phase 3 analysis (session memory + L3 logs) identified root causes without guessing
|
||||
- Checkpoint recovery in iteration 2 saved time by skipping 3 expensive nodes
|
||||
- Final `run_tests` confirms all scenarios pass end-to-end
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
# Project-level Codex config for Hive.
|
||||
# Keep this file minimal: MCP connectivity + skill discovery.
|
||||
|
||||
[mcp_servers.agent-builder]
|
||||
command = "uv"
|
||||
args = ["run", "--directory", "core", "-m", "framework.mcp.agent_builder_server"]
|
||||
cwd = "."
|
||||
+3
-1
@@ -74,4 +74,6 @@ exports/*
|
||||
|
||||
docs/github-issues/*
|
||||
core/tests/*dumps/*
|
||||
screenshots/*
|
||||
|
||||
screenshots/*
|
||||
|
||||
|
||||
@@ -4,11 +4,6 @@
|
||||
"command": "uv",
|
||||
"args": ["run", "-m", "framework.mcp.agent_builder_server"],
|
||||
"cwd": "core"
|
||||
},
|
||||
"tools": {
|
||||
"command": "uv",
|
||||
"args": ["run", "mcp_server.py", "--stdio"],
|
||||
"cwd": "tools"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"agent-builder": {
|
||||
"command": "uv",
|
||||
"args": [
|
||||
"run",
|
||||
"python",
|
||||
"-m",
|
||||
"framework.mcp.agent_builder_server"
|
||||
],
|
||||
"cwd": "core",
|
||||
"env": {
|
||||
"PYTHONPATH": "../tools/src"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"command": "uv",
|
||||
"args": [
|
||||
"run",
|
||||
"python",
|
||||
"mcp_server.py",
|
||||
"--stdio"
|
||||
],
|
||||
"cwd": "tools",
|
||||
"env": {
|
||||
"PYTHONPATH": "src"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-concepts
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-create
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-credentials
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-debugger
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-patterns
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-test
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/triage-issue
|
||||
+6
-4
@@ -49,8 +49,8 @@ You may submit PRs without prior assignment for:
|
||||
make check # Lint and format checks (ruff check + ruff format --check on core/ and tools/)
|
||||
make test # Core tests (cd core && pytest tests/ -v)
|
||||
```
|
||||
6. Commit your changes following our commit conventions
|
||||
7. Push to your fork and submit a Pull Request
|
||||
8. Commit your changes following our commit conventions
|
||||
9. Push to your fork and submit a Pull Request
|
||||
|
||||
## Development Setup
|
||||
|
||||
@@ -99,8 +99,7 @@ docs(readme): update installation instructions
|
||||
2. Update documentation if needed
|
||||
3. Add tests for new functionality
|
||||
4. Ensure `make check` and `make test` pass
|
||||
5. Update the CHANGELOG.md if applicable
|
||||
6. Request review from maintainers
|
||||
5. Request review from maintainers
|
||||
|
||||
### PR Title Format
|
||||
|
||||
@@ -145,6 +144,9 @@ make test
|
||||
# Or run tests directly
|
||||
cd core && pytest tests/ -v
|
||||
|
||||
# Run tools package tests (when contributing to tools/)
|
||||
cd tools && uv run pytest tests/ -v
|
||||
|
||||
# Run tests for a specific agent
|
||||
PYTHONPATH=exports uv run python -m agent_name test
|
||||
```
|
||||
|
||||
@@ -22,7 +22,6 @@
|
||||
<img src="https://img.shields.io/badge/MCP-102_Tools-00ADD8?style=flat-square" alt="MCP" />
|
||||
</p>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/badge/AI_Agents-Self--Improving-brightgreen?style=flat-square" alt="AI Agents" />
|
||||
<img src="https://img.shields.io/badge/Multi--Agent-Systems-blue?style=flat-square" alt="Multi-Agent" />
|
||||
@@ -82,12 +81,17 @@ Use Hive when you need:
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.11+ for agent development
|
||||
- Claude Code or Cursor for utilizing agent skills
|
||||
- Claude Code, Codex CLI, or Cursor for utilizing agent skills
|
||||
|
||||
> **Note for Windows Users:** It is strongly recommended to use **WSL (Windows Subsystem for Linux)** or **Git Bash** to run this framework. Some core automation scripts may not execute correctly in standard Command Prompt or PowerShell.
|
||||
|
||||
### Installation
|
||||
|
||||
> **Note**
|
||||
> Hive uses a `uv` workspace layout and is not installed with `pip install`.
|
||||
> Running `pip install -e .` from the repository root will create a placeholder package and Hive will not function correctly.
|
||||
> Please use the quickstart script below to set up the environment.
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/adenhq/hive.git
|
||||
@@ -120,9 +124,41 @@ hive tui
|
||||
# Or run directly
|
||||
hive run exports/your_agent_name --input '{"key": "value"}'
|
||||
```
|
||||
## Coding Agent Support
|
||||
### Codex CLI
|
||||
Hive includes native support for [OpenAI Codex CLI](https://github.com/openai/codex) (v0.101.0+).
|
||||
|
||||
1. **Config:** `.codex/config.toml` with `agent-builder` MCP server (tracked in git)
|
||||
2. **Skills:** `.agents/skills/` symlinks to Hive skills (tracked in git)
|
||||
3. **Launch:** Run `codex` in the repo root, then type `use hive`
|
||||
|
||||
Example:
|
||||
```
|
||||
codex> use hive
|
||||
```
|
||||
|
||||
### Opencode
|
||||
Hive includes native support for [Opencode](https://github.com/opencode-ai/opencode).
|
||||
|
||||
1. **Setup:** Run the quickstart script
|
||||
2. **Launch:** Open Opencode in the project root.
|
||||
3. **Activate:** Type `/hive` in the chat to switch to the Hive Agent.
|
||||
4. **Verify:** Ask the agent *"List your tools"* to confirm the connection.
|
||||
|
||||
The agent has access to all Hive skills and can scaffold agents, add tools, and debug workflows directly from the chat.
|
||||
|
||||
**[📖 Complete Setup Guide](docs/environment-setup.md)** - Detailed instructions for agent development
|
||||
|
||||
### Antigravity IDE Support
|
||||
|
||||
Skills and MCP servers are also available in [Antigravity IDE](https://antigravity.google/) (Google's AI-powered IDE). **Easiest:** open a terminal in the hive repo folder and run (use `./` — the script is inside the repo):
|
||||
|
||||
```bash
|
||||
./scripts/setup-antigravity-mcp.sh
|
||||
```
|
||||
|
||||
**Important:** Always restart/refresh Antigravity IDE after running the setup script—MCP servers only load on startup. After restart, **agent-builder** and **tools** MCP servers should connect. Skills are under `.agent/skills/` (symlinks to `.claude/skills/`). See [docs/antigravity-setup.md](docs/antigravity-setup.md) for manual setup and troubleshooting.
|
||||
|
||||
## Features
|
||||
|
||||
- **[Goal-Driven Development](docs/key_concepts/goals_outcome.md)** - Define objectives in natural language; the coding agent generates the agent graph and connection code to achieve them
|
||||
@@ -302,6 +338,7 @@ subgraph Expansion
|
||||
j2["Cursor"]
|
||||
j3["Opencode"]
|
||||
j4["Antigravity"]
|
||||
j5["Codex CLI"]
|
||||
end
|
||||
subgraph plat["Platform"]
|
||||
k1["JavaScript/TypeScript SDK"]
|
||||
|
||||
@@ -400,9 +400,13 @@ class GraphBuilder:
|
||||
if not terminal_candidates and self.session.nodes:
|
||||
warnings.append("No terminal nodes found (all nodes have outgoing edges)")
|
||||
|
||||
# Check reachability
|
||||
# Check reachability from ALL entry candidates (not just the first one).
|
||||
# Agents with async entry points have multiple nodes with no incoming
|
||||
# edges (e.g., a primary entry node and an event-driven entry node).
|
||||
if entry_candidates and self.session.nodes:
|
||||
reachable = self._compute_reachable(entry_candidates[0])
|
||||
reachable = set()
|
||||
for candidate in entry_candidates:
|
||||
reachable |= self._compute_reachable(candidate)
|
||||
unreachable = [n.id for n in self.session.nodes if n.id not in reachable]
|
||||
if unreachable:
|
||||
errors.append(f"Unreachable nodes: {unreachable}")
|
||||
|
||||
@@ -6,6 +6,7 @@ helper functions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -24,7 +25,7 @@ def get_hive_config() -> dict[str, Any]:
|
||||
if not HIVE_CONFIG_FILE.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(HIVE_CONFIG_FILE) as f:
|
||||
with open(HIVE_CONFIG_FILE, encoding="utf-8-sig") as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return {}
|
||||
@@ -48,6 +49,15 @@ def get_max_tokens() -> int:
|
||||
return get_hive_config().get("llm", {}).get("max_tokens", DEFAULT_MAX_TOKENS)
|
||||
|
||||
|
||||
def get_api_key() -> str | None:
|
||||
"""Return the API key from the environment variable specified in configuration."""
|
||||
llm = get_hive_config().get("llm", {})
|
||||
api_key_env_var = llm.get("api_key_env_var")
|
||||
if api_key_env_var:
|
||||
return os.environ.get(api_key_env_var)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RuntimeConfig – shared across agent templates
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -60,5 +70,5 @@ class RuntimeConfig:
|
||||
model: str = field(default_factory=get_preferred_model)
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = field(default_factory=get_max_tokens)
|
||||
api_key: str | None = None
|
||||
api_key: str | None = field(default_factory=get_api_key)
|
||||
api_base: str | None = None
|
||||
|
||||
@@ -68,6 +68,7 @@ from .storage import (
|
||||
)
|
||||
from .store import CredentialStore
|
||||
from .template import TemplateResolver
|
||||
from .validation import ensure_credential_key_env, validate_agent_credentials
|
||||
|
||||
# Aden sync components (lazy import to avoid httpx dependency when not needed)
|
||||
# Usage: from core.framework.credentials.aden import AdenSyncProvider
|
||||
@@ -111,6 +112,9 @@ __all__ = [
|
||||
"CredentialRefreshError",
|
||||
"CredentialValidationError",
|
||||
"CredentialDecryptionError",
|
||||
# Validation
|
||||
"ensure_credential_key_env",
|
||||
"validate_agent_credentials",
|
||||
# Aden sync (optional - requires httpx)
|
||||
"AdenSyncProvider",
|
||||
"AdenCredentialClient",
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
"""Credential validation utilities.
|
||||
|
||||
Provides reusable credential validation for agents, whether run through
|
||||
the AgentRunner or directly via GraphExecutor.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def ensure_credential_key_env() -> None:
|
||||
"""Load HIVE_CREDENTIAL_KEY from shell config if not already in environment.
|
||||
|
||||
The setup-credentials skill writes the encryption key to ~/.zshrc or ~/.bashrc.
|
||||
If the user hasn't sourced their config in the current shell, this reads it
|
||||
directly so the runner (and any MCP subprocesses it spawns) can unlock the
|
||||
encrypted credential store.
|
||||
|
||||
Only HIVE_CREDENTIAL_KEY is loaded this way — all other secrets (API keys, etc.)
|
||||
come from the credential store itself.
|
||||
"""
|
||||
if os.environ.get("HIVE_CREDENTIAL_KEY"):
|
||||
return
|
||||
|
||||
try:
|
||||
from aden_tools.credentials.shell_config import check_env_var_in_shell_config
|
||||
|
||||
found, value = check_env_var_in_shell_config("HIVE_CREDENTIAL_KEY")
|
||||
if found and value:
|
||||
os.environ["HIVE_CREDENTIAL_KEY"] = value
|
||||
logger.debug("Loaded HIVE_CREDENTIAL_KEY from shell config")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def validate_agent_credentials(nodes: list) -> None:
|
||||
"""Check that required credentials are available before running an agent.
|
||||
|
||||
Scans node specs for required tools and node types, then checks whether
|
||||
the corresponding credentials exist in the credential store.
|
||||
|
||||
Raises CredentialError with actionable guidance if any are missing.
|
||||
|
||||
Args:
|
||||
nodes: List of NodeSpec objects from the agent graph.
|
||||
"""
|
||||
required_tools: set[str] = set()
|
||||
for node in nodes:
|
||||
if node.tools:
|
||||
required_tools.update(node.tools)
|
||||
node_types: set[str] = {node.node_type for node in nodes}
|
||||
|
||||
try:
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
from framework.credentials import CredentialStore
|
||||
from framework.credentials.storage import (
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
except ImportError:
|
||||
return # aden_tools not installed, skip check
|
||||
|
||||
# Build credential store
|
||||
env_mapping = {
|
||||
(spec.credential_id or name): spec.env_var for name, spec in CREDENTIAL_SPECS.items()
|
||||
}
|
||||
storages: list = [EnvVarStorage(env_mapping=env_mapping)]
|
||||
if os.environ.get("HIVE_CREDENTIAL_KEY"):
|
||||
storages.insert(0, EncryptedFileStorage())
|
||||
if len(storages) == 1:
|
||||
storage = storages[0]
|
||||
else:
|
||||
storage = CompositeStorage(primary=storages[0], fallbacks=storages[1:])
|
||||
store = CredentialStore(storage=storage)
|
||||
|
||||
# Build reverse mappings
|
||||
tool_to_cred: dict[str, str] = {}
|
||||
node_type_to_cred: dict[str, str] = {}
|
||||
for cred_name, spec in CREDENTIAL_SPECS.items():
|
||||
for tool_name in spec.tools:
|
||||
tool_to_cred[tool_name] = cred_name
|
||||
for nt in spec.node_types:
|
||||
node_type_to_cred[nt] = cred_name
|
||||
|
||||
missing: list[str] = []
|
||||
checked: set[str] = set()
|
||||
|
||||
# Check tool credentials
|
||||
for tool_name in sorted(required_tools):
|
||||
cred_name = tool_to_cred.get(tool_name)
|
||||
if cred_name is None or cred_name in checked:
|
||||
continue
|
||||
checked.add(cred_name)
|
||||
spec = CREDENTIAL_SPECS[cred_name]
|
||||
cred_id = spec.credential_id or cred_name
|
||||
if spec.required and not store.is_available(cred_id):
|
||||
affected = sorted(t for t in required_tools if t in spec.tools)
|
||||
entry = f" {spec.env_var} for {', '.join(affected)}"
|
||||
if spec.help_url:
|
||||
entry += f"\n Get it at: {spec.help_url}"
|
||||
missing.append(entry)
|
||||
|
||||
# Check node type credentials (e.g., ANTHROPIC_API_KEY for LLM nodes)
|
||||
for nt in sorted(node_types):
|
||||
cred_name = node_type_to_cred.get(nt)
|
||||
if cred_name is None or cred_name in checked:
|
||||
continue
|
||||
checked.add(cred_name)
|
||||
spec = CREDENTIAL_SPECS[cred_name]
|
||||
cred_id = spec.credential_id or cred_name
|
||||
if spec.required and not store.is_available(cred_id):
|
||||
affected_types = sorted(t for t in node_types if t in spec.node_types)
|
||||
entry = f" {spec.env_var} for {', '.join(affected_types)} nodes"
|
||||
if spec.help_url:
|
||||
entry += f"\n Get it at: {spec.help_url}"
|
||||
missing.append(entry)
|
||||
|
||||
if missing:
|
||||
from framework.credentials.models import CredentialError
|
||||
|
||||
lines = ["Missing required credentials:\n"]
|
||||
lines.extend(missing)
|
||||
lines.append(
|
||||
"\nTo fix: run /hive-credentials in Claude Code."
|
||||
"\nIf you've already set up credentials, restart your terminal to load them."
|
||||
)
|
||||
raise CredentialError("\n".join(lines))
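A minimal call-site sketch (the stub dataclass is illustrative; the validator only reads `.tools` and `.node_type` from each spec):

```python
from dataclasses import dataclass, field

@dataclass
class StubNodeSpec:  # stand-in for the framework's NodeSpec
    node_type: str
    tools: list[str] = field(default_factory=list)

nodes = [StubNodeSpec("llm", tools=["web_search"]), StubNodeSpec("function")]
ensure_credential_key_env()        # load HIVE_CREDENTIAL_KEY from shell config if needed
validate_agent_credentials(nodes)  # raises CredentialError with guidance when keys are missing
```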
|
||||
@@ -27,6 +27,9 @@ class Message:
|
||||
tool_use_id: str | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
is_error: bool = False
|
||||
# Phase-aware compaction metadata (continuous mode)
|
||||
phase_id: str | None = None
|
||||
is_transition_marker: bool = False
|
||||
|
||||
def to_llm_dict(self) -> dict[str, Any]:
|
||||
"""Convert to OpenAI-format message dict."""
|
||||
@@ -60,6 +63,10 @@ class Message:
|
||||
d["tool_calls"] = self.tool_calls
|
||||
if self.is_error:
|
||||
d["is_error"] = self.is_error
|
||||
if self.phase_id is not None:
|
||||
d["phase_id"] = self.phase_id
|
||||
if self.is_transition_marker:
|
||||
d["is_transition_marker"] = self.is_transition_marker
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
@@ -72,6 +79,8 @@ class Message:
|
||||
tool_use_id=data.get("tool_use_id"),
|
||||
tool_calls=data.get("tool_calls"),
|
||||
is_error=data.get("is_error", False),
|
||||
phase_id=data.get("phase_id"),
|
||||
is_transition_marker=data.get("is_transition_marker", False),
|
||||
)
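A quick roundtrip check of the new metadata (constructor kwargs match the dataclass fields above):

```python
msg = Message(seq=0, role="user", content="begin research phase",
              phase_id="research", is_transition_marker=True)
d = msg.to_llm_dict()
# Both fields are serialized only when set, so pre-existing payloads are unchanged:
assert d["phase_id"] == "research" and d["is_transition_marker"] is True
```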
|
||||
|
||||
|
||||
@@ -188,6 +197,7 @@ class NodeConversation:
|
||||
self._next_seq: int = 0
|
||||
self._meta_persisted: bool = False
|
||||
self._last_api_input_tokens: int | None = None
|
||||
self._current_phase: str | None = None
|
||||
|
||||
# --- Properties --------------------------------------------------------
|
||||
|
||||
@@ -195,6 +205,33 @@ class NodeConversation:
|
||||
def system_prompt(self) -> str:
|
||||
return self._system_prompt
|
||||
|
||||
def update_system_prompt(self, new_prompt: str) -> None:
|
||||
"""Update the system prompt.
|
||||
|
||||
Used in continuous conversation mode at phase transitions to swap
|
||||
Layer 3 (focus) while preserving the conversation history.
|
||||
"""
|
||||
self._system_prompt = new_prompt
|
||||
|
||||
def set_current_phase(self, phase_id: str) -> None:
|
||||
"""Set the current phase ID. Subsequent messages will be stamped with it."""
|
||||
self._current_phase = phase_id
|
||||
|
||||
async def switch_store(self, new_store: ConversationStore) -> None:
|
||||
"""Switch to a new persistence store at a phase transition.
|
||||
|
||||
Subsequent messages are written to *new_store*. Meta (system
|
||||
prompt, config) is re-persisted on the next write so the new
|
||||
store's ``meta.json`` reflects the updated prompt.
|
||||
"""
|
||||
self._store = new_store
|
||||
self._meta_persisted = False
|
||||
await new_store.write_cursor({"next_seq": self._next_seq})
|
||||
|
||||
@property
|
||||
def current_phase(self) -> str | None:
|
||||
return self._current_phase
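A hedged sketch of a phase transition in continuous mode, stitching the new hooks together (the ordering is an assumption; `next_prompt` and `new_store` are placeholders):

```python
# At a phase boundary, a continuous-mode executor would, roughly:
conv.update_system_prompt(next_prompt)   # swap Layer 3 focus, keep history
conv.set_current_phase("analysis")       # stamp subsequent messages with the phase
await conv.switch_store(new_store)       # route persistence to the new phase's store
await conv.add_user_message(             # marker survives pruning and compaction
    "=== Entering phase: analysis ===",
    is_transition_marker=True,
)
```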
|
||||
|
||||
@property
|
||||
def messages(self) -> list[Message]:
|
||||
"""Return a defensive copy of the message list."""
|
||||
@@ -216,8 +253,19 @@ class NodeConversation:
|
||||
|
||||
# --- Add messages ------------------------------------------------------
|
||||
|
||||
async def add_user_message(self, content: str) -> Message:
|
||||
msg = Message(seq=self._next_seq, role="user", content=content)
|
||||
async def add_user_message(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
is_transition_marker: bool = False,
|
||||
) -> Message:
|
||||
msg = Message(
|
||||
seq=self._next_seq,
|
||||
role="user",
|
||||
content=content,
|
||||
phase_id=self._current_phase,
|
||||
is_transition_marker=is_transition_marker,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
await self._persist(msg)
|
||||
@@ -233,6 +281,7 @@ class NodeConversation:
|
||||
role="assistant",
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
phase_id=self._current_phase,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
@@ -251,6 +300,7 @@ class NodeConversation:
|
||||
content=content,
|
||||
tool_use_id=tool_use_id,
|
||||
is_error=is_error,
|
||||
phase_id=self._current_phase,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
@@ -380,6 +430,11 @@ class NodeConversation:
|
||||
spillover filename reference (if any). Message structure (role,
|
||||
seq, tool_use_id) stays valid for the LLM API.
|
||||
|
||||
Phase-aware behavior (continuous mode): when messages have ``phase_id``
|
||||
metadata, all messages in the current phase are protected regardless of
|
||||
token budget. Transition markers are never pruned. Older phases' tool
|
||||
results are pruned more aggressively.
|
||||
|
||||
Error tool results are never pruned — they prevent re-calling
|
||||
failing tools.
|
||||
|
||||
@@ -388,13 +443,18 @@ class NodeConversation:
|
||||
if not self._messages:
|
||||
return 0
|
||||
|
||||
# Phase 1: Walk backward, classify tool results as protected vs pruneable
|
||||
# Walk backward, classify tool results as protected vs pruneable
|
||||
protected_tokens = 0
|
||||
pruneable: list[int] = [] # indices into self._messages
|
||||
pruneable_tokens = 0
|
||||
|
||||
for i in range(len(self._messages) - 1, -1, -1):
|
||||
msg = self._messages[i]
|
||||
|
||||
# Transition markers are never pruned (any role)
|
||||
if msg.is_transition_marker:
|
||||
continue
|
||||
|
||||
if msg.role != "tool":
|
||||
continue
|
||||
if msg.is_error:
|
||||
@@ -402,6 +462,10 @@ class NodeConversation:
|
||||
if msg.content.startswith("[Pruned tool result"):
|
||||
continue # already pruned
|
||||
|
||||
# Phase-aware: protect current phase messages
|
||||
if self._current_phase and msg.phase_id == self._current_phase:
|
||||
continue
|
||||
|
||||
est = len(msg.content) // 4
|
||||
if protected_tokens < protect_tokens:
|
||||
protected_tokens += est
|
||||
@@ -409,11 +473,11 @@ class NodeConversation:
|
||||
pruneable.append(i)
|
||||
pruneable_tokens += est
|
||||
|
||||
# Phase 2: Only prune if enough to be worthwhile
|
||||
# Only prune if enough to be worthwhile
|
||||
if pruneable_tokens < min_prune_tokens:
|
||||
return 0
|
||||
|
||||
# Phase 3: Replace content with compact placeholder
|
||||
# Replace content with compact placeholder
|
||||
count = 0
|
||||
for i in pruneable:
|
||||
msg = self._messages[i]
|
||||
@@ -436,6 +500,8 @@ class NodeConversation:
|
||||
tool_use_id=msg.tool_use_id,
|
||||
tool_calls=msg.tool_calls,
|
||||
is_error=msg.is_error,
|
||||
phase_id=msg.phase_id,
|
||||
is_transition_marker=msg.is_transition_marker,
|
||||
)
|
||||
count += 1
|
||||
|
||||
@@ -446,22 +512,38 @@ class NodeConversation:
|
||||
self._last_api_input_tokens = None
|
||||
return count
|
||||
|
||||
async def compact(self, summary: str, keep_recent: int = 2) -> None:
|
||||
async def compact(
|
||||
self,
|
||||
summary: str,
|
||||
keep_recent: int = 2,
|
||||
phase_graduated: bool = False,
|
||||
) -> None:
|
||||
"""Replace old messages with a summary, optionally keeping recent ones.
|
||||
|
||||
Args:
|
||||
summary: Caller-provided summary text.
|
||||
keep_recent: Number of recent messages to preserve (default 2).
|
||||
Clamped to [0, len(messages) - 1].
|
||||
phase_graduated: When True and messages have phase_id metadata,
|
||||
split at phase boundaries instead of using keep_recent.
|
||||
Keeps current + previous phase intact; compacts older phases.
|
||||
"""
|
||||
if not self._messages:
|
||||
return
|
||||
|
||||
# Clamp: must discard at least 1 message
|
||||
keep_recent = max(0, min(keep_recent, len(self._messages) - 1))
|
||||
|
||||
total = len(self._messages)
|
||||
split = total - keep_recent if keep_recent > 0 else total
|
||||
|
||||
# Phase-graduated: find the split point based on phase boundaries.
|
||||
# Keeps current phase + previous phase intact, compacts older phases.
|
||||
if phase_graduated and self._current_phase:
|
||||
split = self._find_phase_graduated_split()
|
||||
else:
|
||||
split = None
|
||||
|
||||
if split is None:
|
||||
# Fallback: use keep_recent (non-phase or single-phase conversation)
|
||||
keep_recent = max(0, min(keep_recent, total - 1))
|
||||
split = total - keep_recent if keep_recent > 0 else total
|
||||
|
||||
# Advance split past orphaned tool results at the boundary.
|
||||
# Tool-role messages reference a tool_use from the preceding
|
||||
@@ -470,6 +552,10 @@ class NodeConversation:
|
||||
while split < total and self._messages[split].role == "tool":
|
||||
split += 1
|
||||
|
||||
# Nothing to compact
|
||||
if split == 0:
|
||||
return
|
||||
|
||||
old_messages = list(self._messages[:split])
|
||||
recent_messages = list(self._messages[split:])
|
||||
|
||||
@@ -504,6 +590,33 @@ class NodeConversation:
|
||||
self._messages = [summary_msg] + recent_messages
|
||||
self._last_api_input_tokens = None # reset; next LLM call will recalibrate
|
||||
|
||||
def _find_phase_graduated_split(self) -> int | None:
|
||||
"""Find split point that preserves current + previous phase.
|
||||
|
||||
Returns the index of the first message in the protected set,
|
||||
or None if phase graduation doesn't apply (< 3 phases).
|
||||
"""
|
||||
# Collect distinct phases in order of first appearance
|
||||
phases_seen: list[str] = []
|
||||
for msg in self._messages:
|
||||
if msg.phase_id and msg.phase_id not in phases_seen:
|
||||
phases_seen.append(msg.phase_id)
|
||||
|
||||
# Need at least 3 phases for graduation to be meaningful
|
||||
# (current + previous are protected, older get compacted)
|
||||
if len(phases_seen) < 3:
|
||||
return None
|
||||
|
||||
# Protect: current phase + previous phase
|
||||
protected_phases = {phases_seen[-1], phases_seen[-2]}
|
||||
|
||||
# Find split: first message belonging to a protected phase
|
||||
for i, msg in enumerate(self._messages):
|
||||
if msg.phase_id in protected_phases:
|
||||
return i
|
||||
|
||||
return None
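A worked illustration of the graduation rule (the phase sequence is a made-up example):

```python
# Phases by first appearance: p1, p2, p3 → protected = {p2, p3}.
phase_ids = ["p1", "p1", "p2", "p2", "p3"]   # one entry per message
seen: list[str] = []
for p in phase_ids:
    if p not in seen:
        seen.append(p)
protected = {seen[-1], seen[-2]}
split = next(i for i, p in enumerate(phase_ids) if p in protected)
print(split)  # 2: messages 0-1 (all of phase p1) get compacted
```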
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Remove all messages, keep system prompt, preserve ``_next_seq``."""
|
||||
if self._store:
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
"""Level 2 Conversation-Aware Judge.
|
||||
|
||||
When a node has `success_criteria` set, the implicit judge upgrades:
|
||||
after Level 0 passes (all output keys set), a fast LLM call evaluates
|
||||
whether the conversation actually meets the criteria.
|
||||
|
||||
This prevents nodes from "checking boxes" (setting output keys) without
|
||||
doing quality work. The LLM reads the recent conversation and assesses
|
||||
whether the phase's goal was genuinely accomplished.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.llm.provider import LLMProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseVerdict:
|
||||
"""Result of Level 2 conversation-aware evaluation."""
|
||||
|
||||
action: str # "ACCEPT" or "RETRY"
|
||||
confidence: float = 0.8
|
||||
feedback: str = ""
|
||||
|
||||
|
||||
async def evaluate_phase_completion(
|
||||
llm: LLMProvider,
|
||||
conversation: NodeConversation,
|
||||
phase_name: str,
|
||||
phase_description: str,
|
||||
success_criteria: str,
|
||||
accumulator_state: dict[str, Any],
|
||||
max_history_tokens: int = 8_196,
|
||||
) -> PhaseVerdict:
|
||||
"""Level 2 judge: read the conversation and evaluate quality.
|
||||
|
||||
Only called after Level 0 passes (all output keys set).
|
||||
|
||||
Args:
|
||||
llm: LLM provider for evaluation
|
||||
conversation: The current conversation to evaluate
|
||||
phase_name: Name of the current phase/node
|
||||
phase_description: Description of the phase
|
||||
success_criteria: Natural-language criteria for phase completion
|
||||
accumulator_state: Current output key values
|
||||
max_history_tokens: Main conversation token budget (judge gets 20%)
|
||||
|
||||
Returns:
|
||||
PhaseVerdict with action and optional feedback
|
||||
"""
|
||||
# Build a compact view of the recent conversation
|
||||
recent_messages = _extract_recent_context(conversation, max_messages=10)
|
||||
outputs_summary = _format_outputs(accumulator_state)
|
||||
|
||||
system_prompt = (
|
||||
"You are a quality judge evaluating whether a phase of work is complete. "
|
||||
"Be concise. Evaluate based on the success criteria, not on style."
|
||||
)
|
||||
|
||||
user_prompt = f"""Evaluate this phase:
|
||||
|
||||
PHASE: {phase_name}
|
||||
DESCRIPTION: {phase_description}
|
||||
|
||||
SUCCESS CRITERIA:
|
||||
{success_criteria}
|
||||
|
||||
OUTPUTS SET:
|
||||
{outputs_summary}
|
||||
|
||||
RECENT CONVERSATION:
|
||||
{recent_messages}
|
||||
|
||||
Has this phase accomplished its goal based on the success criteria?
|
||||
|
||||
Respond in exactly this format:
|
||||
ACTION: ACCEPT or RETRY
|
||||
CONFIDENCE: 0.X
|
||||
FEEDBACK: (reason if RETRY, empty if ACCEPT)"""
|
||||
|
||||
try:
|
||||
response = await llm.acomplete(
|
||||
messages=[{"role": "user", "content": user_prompt}],
|
||||
system=system_prompt,
|
||||
max_tokens=max(1024, max_history_tokens // 5),
|
||||
max_retries=1,
|
||||
)
|
||||
if not response.content or not response.content.strip():
|
||||
logger.debug("Level 2 judge: empty response, accepting by default")
|
||||
return PhaseVerdict(action="ACCEPT", confidence=0.5, feedback="")
|
||||
return _parse_verdict(response.content)
|
||||
except Exception as e:
|
||||
logger.warning(f"Level 2 judge failed, accepting by default: {e}")
|
||||
# On failure, don't block — Level 0 already passed
|
||||
return PhaseVerdict(action="ACCEPT", confidence=0.5, feedback="")
|
||||
|
||||
|
||||
def _extract_recent_context(conversation: NodeConversation, max_messages: int = 10) -> str:
|
||||
"""Extract recent conversation messages for evaluation."""
|
||||
messages = conversation.messages
|
||||
recent = messages[-max_messages:] if len(messages) > max_messages else messages
|
||||
|
||||
parts = []
|
||||
for msg in recent:
|
||||
role = msg.role.upper()
|
||||
content = msg.content or ""
|
||||
# Truncate long tool results
|
||||
if msg.role == "tool" and len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
if content.strip():
|
||||
parts.append(f"[{role}]: {content.strip()}")
|
||||
|
||||
return "\n".join(parts) if parts else "(no messages)"
|
||||
|
||||
|
||||
def _format_outputs(accumulator_state: dict[str, Any]) -> str:
|
||||
"""Format output key values for evaluation.
|
||||
|
||||
Lists and dicts get structural formatting so the judge can assess
|
||||
quantity and structure, not just a truncated stringification.
|
||||
"""
|
||||
if not accumulator_state:
|
||||
return "(none)"
|
||||
parts = []
|
||||
for key, value in accumulator_state.items():
|
||||
if isinstance(value, list):
|
||||
# Show count + brief per-item preview so the judge can
|
||||
# verify quantity without the full serialization.
|
||||
items_preview = []
|
||||
for i, item in enumerate(value[:8]):
|
||||
item_str = str(item)
|
||||
if len(item_str) > 150:
|
||||
item_str = item_str[:150] + "..."
|
||||
items_preview.append(f" [{i}]: {item_str}")
|
||||
val_str = f"list ({len(value)} items):\n" + "\n".join(items_preview)
|
||||
if len(value) > 8:
|
||||
val_str += f"\n ... and {len(value) - 8} more"
|
||||
elif isinstance(value, dict):
|
||||
val_str = str(value)
|
||||
if len(val_str) > 400:
|
||||
val_str = val_str[:400] + "..."
|
||||
else:
|
||||
val_str = str(value)
|
||||
if len(val_str) > 300:
|
||||
val_str = val_str[:300] + "..."
|
||||
parts.append(f" {key}: {val_str}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _parse_verdict(response: str) -> PhaseVerdict:
|
||||
"""Parse LLM response into PhaseVerdict."""
|
||||
action = "ACCEPT"
|
||||
confidence = 0.8
|
||||
feedback = ""
|
||||
|
||||
for line in response.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("ACTION:"):
|
||||
action_str = line.split(":", 1)[1].strip().upper()
|
||||
if action_str in ("ACCEPT", "RETRY"):
|
||||
action = action_str
|
||||
elif line.startswith("CONFIDENCE:"):
|
||||
try:
|
||||
confidence = float(line.split(":", 1)[1].strip())
|
||||
except ValueError:
|
||||
pass
|
||||
elif line.startswith("FEEDBACK:"):
|
||||
feedback = line.split(":", 1)[1].strip()
|
||||
|
||||
return PhaseVerdict(action=action, confidence=confidence, feedback=feedback)
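A quick sanity check of the parser (the input string is a made-up verdict):

```python
verdict = _parse_verdict(
    "ACTION: RETRY\nCONFIDENCE: 0.6\nFEEDBACK: Only 2 sources gathered; criteria require 5."
)
assert (verdict.action, verdict.confidence) == ("RETRY", 0.6)
assert verdict.feedback.startswith("Only 2 sources")
```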

@@ -21,6 +21,9 @@ allowing the LLM to evaluate whether proceeding along an edge makes sense
given the current goal, context, and execution state.
"""

+import json
+import logging
+import re
from enum import StrEnum
from typing import Any

@@ -28,6 +31,8 @@ from pydantic import BaseModel, Field, model_validator

from framework.graph.safe_eval import safe_eval

+logger = logging.getLogger(__name__)
+
DEFAULT_MAX_TOKENS = 8192


@@ -99,7 +104,7 @@ class EdgeSpec(BaseModel):

    model_config = {"extra": "allow"}

-    def should_traverse(
+    async def should_traverse(
        self,
        source_success: bool,
        source_output: dict[str, Any],
@@ -140,7 +145,7 @@ class EdgeSpec(BaseModel):
        if llm is None or goal is None:
            # Fallback to ON_SUCCESS if LLM not available
            return source_success
-        return self._llm_decide(
+        return await self._llm_decide(
            llm=llm,
            goal=goal,
            source_success=source_success,
@@ -158,9 +163,6 @@ class EdgeSpec(BaseModel):
        memory: dict[str, Any],
    ) -> bool:
        """Evaluate a conditional expression."""
-        import logging
-
-        logger = logging.getLogger(__name__)

        if not self.condition_expr:
            return True
@@ -201,7 +203,7 @@ class EdgeSpec(BaseModel):
            logger.warning(f"  Available context keys: {list(context.keys())}")
            return False

-    def _llm_decide(
+    async def _llm_decide(
        self,
        llm: Any,
        goal: Any,
@@ -217,8 +219,6 @@ class EdgeSpec(BaseModel):
        The LLM evaluates whether proceeding to the target node
        is the best next step toward achieving the goal.
        """
-        import json
-
        # Build context for LLM
        prompt = f"""You are evaluating whether to proceed along an edge in an agent workflow.

@@ -247,15 +247,13 @@ Respond with ONLY a JSON object:
{{"proceed": true/false, "reasoning": "brief explanation"}}"""

        try:
-            response = llm.complete(
+            response = await llm.acomplete(
                messages=[{"role": "user", "content": prompt}],
                system="You are a routing agent. Respond with JSON only.",
                max_tokens=150,
            )

            # Parse response
-            import re
-
            json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL)
            if json_match:
                data = json.loads(json_match.group())
@@ -263,9 +261,6 @@ Respond with ONLY a JSON object:
            reasoning = data.get("reasoning", "")

            # Log the decision (using basic print for now)
-            import logging
-
-            logger = logging.getLogger(__name__)
            logger.info(f"  🤔 LLM routing decision: {'PROCEED' if proceed else 'SKIP'}")
            logger.info(f"     Reason: {reasoning}")

@@ -273,9 +268,6 @@ Respond with ONLY a JSON object:

        except Exception as e:
            # Fallback: proceed on success
-            import logging
-
-            logger = logging.getLogger(__name__)
            logger.warning(f"  ⚠ LLM routing failed, defaulting to on_success: {e}")
            return source_success

@@ -443,6 +435,25 @@ class GraphSpec(BaseModel):
        description="EventLoopNode configuration (max_iterations, max_tool_calls_per_turn, etc.)",
    )

+    # Conversation mode
+    conversation_mode: str = Field(
+        default="continuous",
+        description=(
+            "How conversations flow between event_loop nodes. "
+            "'continuous' (default): one conversation threads through all "
+            "event_loop nodes with cumulative tools and layered prompt composition. "
+            "'isolated': each node gets a fresh conversation."
+        ),
+    )
+    identity_prompt: str | None = Field(
+        default=None,
+        description=(
+            "Agent-level identity prompt (Layer 1 of the onion model). "
+            "In continuous mode, this is the static identity that persists "
+            "unchanged across all node transitions. In isolated mode, ignored."
+        ),
+    )
+
    # Metadata
    description: str = ""
    created_by: str = ""  # "human" or "builder_agent"
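As a rough sketch of how the two new fields combine, a continuous-mode graph
config might carry something like the following (only the fields added in this
hunk are shown; the prompt text is hypothetical):

    graph_config = {
        "conversation_mode": "continuous",
        "identity_prompt": (
            "You are a thorough research agent. You prefer clarity over jargon."
        ),
        "description": "Research pipeline with one threaded conversation",
        "created_by": "human",
    }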
File diff suppressed because it is too large
@@ -186,6 +186,55 @@ class GraphExecutor:
        # Pause/resume control
        self._pause_requested = asyncio.Event()

+    def _write_progress(
+        self,
+        current_node: str,
+        path: list[str],
+        memory: Any,
+        node_visit_counts: dict[str, int],
+    ) -> None:
+        """Update state.json with live progress at node transitions.
+
+        Reads the existing state.json (written by ExecutionStream at session
+        start) and patches the progress fields in-place. This keeps
+        state.json as the single source of truth — readers always see
+        current progress, not stale initial values.
+
+        The write is synchronous and best-effort: never blocks execution.
+        """
+        if not self._storage_path:
+            return
+        try:
+            import json as _json
+            from datetime import datetime
+
+            state_path = self._storage_path / "state.json"
+            if state_path.exists():
+                state_data = _json.loads(state_path.read_text(encoding="utf-8"))
+            else:
+                state_data = {}
+
+            # Patch progress fields
+            progress = state_data.setdefault("progress", {})
+            progress["current_node"] = current_node
+            progress["path"] = list(path)
+            progress["node_visit_counts"] = dict(node_visit_counts)
+            progress["steps_executed"] = len(path)
+
+            # Update timestamp
+            timestamps = state_data.setdefault("timestamps", {})
+            timestamps["updated_at"] = datetime.now().isoformat()
+
+            # Persist full memory so state.json is sufficient for resume
+            # even if the process dies before the final write.
+            memory_snapshot = memory.read_all()
+            state_data["memory"] = memory_snapshot
+            state_data["memory_keys"] = list(memory_snapshot.keys())
+
+            state_path.write_text(_json.dumps(state_data, indent=2), encoding="utf-8")
+        except Exception:
+            pass  # Best-effort — never block execution
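The patched file ends up shaped roughly like this (values hypothetical; fields
written at session start by ExecutionStream are left untouched):

    {
      "progress": {
        "current_node": "review",
        "path": ["scope", "research", "review"],
        "node_visit_counts": {"scope": 1, "research": 1, "review": 1},
        "steps_executed": 3
      },
      "timestamps": {"updated_at": "2025-06-01T12:00:00"},
      "memory": {"findings": "..."},
      "memory_keys": ["findings"]
    }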

    def _validate_tools(self, graph: GraphSpec) -> list[str]:
        """
        Validate that all tools declared by nodes are available.
@@ -257,6 +306,13 @@ class GraphExecutor:
        # Initialize execution state
        memory = SharedMemory()

+        # Continuous conversation mode state
+        is_continuous = getattr(graph, "conversation_mode", "isolated") == "continuous"
+        continuous_conversation = None  # NodeConversation threaded across nodes
+        cumulative_tools: list = []  # Tools accumulate, never removed
+        cumulative_tool_names: set[str] = set()
+        cumulative_output_keys: list[str] = []  # Output keys from all visited nodes
+
        # Initialize checkpoint store if checkpointing is enabled
        checkpoint_store: CheckpointStore | None = None
        if checkpoint_config and checkpoint_config.enabled and self._storage_path:
@@ -273,13 +329,20 @@ class GraphExecutor:
                        f"{type(memory_data).__name__}, expected dict"
                    )
                else:
-                    # Restore memory from previous session
+                    # Restore memory from previous session.
+                    # Skip validation — this data was already validated when
+                    # originally written, and research text triggers false
+                    # positives on the code-indicator heuristic.
                    for key, value in memory_data.items():
-                        memory.write(key, value)
+                        memory.write(key, value, validate=False)
                    self.logger.info(f"📥 Restored session state with {len(memory_data)} memory keys")

-        # Write new input data to memory (each key individually)
-        if input_data:
+        # Write new input data to memory (each key individually).
+        # Skip when resuming from a paused session — restored memory already
+        # contains all state including the original input, and re-writing
+        # input_data would overwrite intermediate results with stale values.
+        _is_resuming = bool(session_state and session_state.get("paused_at"))
+        if input_data and not _is_resuming:
            for key, value in input_data.items():
                memory.write(key, value)

@@ -368,7 +431,7 @@ class GraphExecutor:
        # Check if resuming from paused_at (session state resume)
        paused_at = session_state.get("paused_at") if session_state else None
        node_ids = [n.id for n in graph.nodes]
-        self.logger.info(f"🔍 Debug: paused_at={paused_at}, available node IDs={node_ids}")
+        self.logger.debug(f"paused_at={paused_at}, available node IDs={node_ids}")

        if paused_at and graph.get_node(paused_at) is not None:
            # Resume from paused_at node directly (works for any node, not just pause_nodes)
@@ -396,9 +459,76 @@ class GraphExecutor:

        steps = 0

+        # Fresh shared-session execution: clear stale cursor so the entry
+        # node doesn't restore a filled OutputAccumulator from the previous
+        # webhook run (which would cause the judge to accept immediately).
+        # The conversation history is preserved (continuous memory).
+        _is_fresh_shared = bool(
+            session_state
+            and session_state.get("resume_session_id")
+            and not session_state.get("paused_at")
+            and not session_state.get("resume_from_checkpoint")
+        )
+        if _is_fresh_shared and is_continuous and self._storage_path:
+            try:
+                from framework.storage.conversation_store import FileConversationStore
+
+                entry_conv_path = self._storage_path / "conversations" / current_node_id
+                if entry_conv_path.exists():
+                    _store = FileConversationStore(base_path=entry_conv_path)
+
+                    # Read cursor to find next seq for the transition marker.
+                    _cursor = await _store.read_cursor() or {}
+                    _next_seq = _cursor.get("next_seq", 0)
+                    if _next_seq == 0:
+                        # Fallback: scan part files for max seq
+                        _parts = await _store.read_parts()
+                        if _parts:
+                            _next_seq = max(p.get("seq", 0) for p in _parts) + 1
+
+                    # Reset cursor — clears stale accumulator outputs and
+                    # iteration counter so the node starts fresh work while
+                    # the conversation thread carries forward.
+                    await _store.write_cursor({})
+
+                    # Append a transition marker so the LLM knows a new
+                    # event arrived and previous results are outdated.
+                    await _store.write_part(
+                        _next_seq,
+                        {
+                            "role": "user",
+                            "content": (
+                                "--- NEW EVENT TRIGGER ---\n"
+                                "A new event has been received. "
+                                "Process this as a fresh request — "
+                                "previous outputs are no longer valid."
+                            ),
+                            "seq": _next_seq,
+                            "is_transition_marker": True,
+                        },
+                    )
+                    self.logger.info(
+                        "🔄 Cleared stale cursor and added transition marker "
+                        "for shared-session entry node '%s'",
+                        current_node_id,
+                    )
+            except Exception:
+                self.logger.debug(
+                    "Could not prepare conversation store for shared-session entry node '%s'",
+                    current_node_id,
+                    exc_info=True,
+                )
+
        if session_state and current_node_id != graph.entry_node:
            self.logger.info(f"🔄 Resuming from: {current_node_id}")

+            # Emit resume event
+            if self._event_bus:
+                await self._event_bus.emit_execution_resumed(
+                    stream_id=self._stream_id,
+                    node_id=current_node_id,
+                )
+
        # Start run
        _run_id = self.runtime.start_run(
            goal_id=goal.id,
@@ -435,6 +565,14 @@ class GraphExecutor:
                if self._pause_requested.is_set():
                    self.logger.info("⏸ Pause detected - stopping at node boundary")

+                    # Emit pause event
+                    if self._event_bus:
+                        await self._event_bus.emit_execution_paused(
+                            stream_id=self._stream_id,
+                            node_id=current_node_id,
+                            reason="User requested pause (Ctrl+Z)",
+                        )
+
                    # Create session state for pause
                    saved_memory = memory.read_all()
                    pause_session_state: dict[str, Any] = {
@@ -489,7 +627,7 @@ class GraphExecutor:
                    )
                    # Skip execution — follow outgoing edges using current memory
                    skip_result = NodeResult(success=True, output=memory.read_all())
-                    next_node = self._follow_edges(
+                    next_node = await self._follow_edges(
                        graph=graph,
                        goal=goal,
                        current_node_id=current_node_id,
@@ -505,6 +643,21 @@ class GraphExecutor:

                path.append(current_node_id)

+                # Clear stale nullable outputs from previous visits.
+                # When a node is re-visited (e.g. review → process-batch → review),
+                # nullable outputs from the PREVIOUS visit linger in shared memory.
+                # This causes stale edge conditions to fire (e.g. "feedback is not None"
+                # from visit 1 triggers even when visit 2 sets "final_summary" instead).
+                # Clearing them ensures only the CURRENT visit's outputs affect routing.
+                if node_visit_counts.get(current_node_id, 0) > 1:
+                    nullable_keys = getattr(node_spec, "nullable_output_keys", None) or []
+                    for key in nullable_keys:
+                        if memory.read(key) is not None:
+                            memory.write(key, None, validate=False)
+                            self.logger.info(
+                                f"  🧹 Cleared stale nullable output '{key}' from previous visit"
+                            )
+
                # Check if pause (HITL) before execution
                if current_node_id in graph.pause_nodes:
                    self.logger.info(f"⏸ Paused at HITL node: {node_spec.name}")
@@ -515,6 +668,17 @@ class GraphExecutor:
                self.logger.info(f"   Inputs: {node_spec.input_keys}")
                self.logger.info(f"   Outputs: {node_spec.output_keys}")

+                # Continuous mode: accumulate tools and output keys from this node
+                if is_continuous and node_spec.tools:
+                    for t in self.tools:
+                        if t.name in node_spec.tools and t.name not in cumulative_tool_names:
+                            cumulative_tools.append(t)
+                            cumulative_tool_names.add(t.name)
+                if is_continuous and node_spec.output_keys:
+                    for k in node_spec.output_keys:
+                        if k not in cumulative_output_keys:
+                            cumulative_output_keys.append(k)
+
                # Build context for node
                ctx = self._build_context(
                    node_spec=node_spec,
@@ -522,6 +686,10 @@ class GraphExecutor:
                    goal=goal,
                    input_data=input_data or {},
                    max_tokens=graph.max_tokens,
+                    continuous_mode=is_continuous,
+                    inherited_conversation=continuous_conversation if is_continuous else None,
+                    override_tools=cumulative_tools if is_continuous else None,
+                    cumulative_output_keys=cumulative_output_keys if is_continuous else None,
                )

                # Log actual input data being read
@@ -689,6 +857,17 @@ class GraphExecutor:
                        self.logger.info(
                            f"  ↻ Retrying ({node_retry_counts[current_node_id]}/{max_retries})..."
                        )

+                        # Emit retry event
+                        if self._event_bus:
+                            await self._event_bus.emit_node_retry(
+                                stream_id=self._stream_id,
+                                node_id=current_node_id,
+                                retry_count=retry_count,
+                                max_retries=max_retries,
+                                error=result.error or "",
+                            )
+
                        _is_retry = True
                        continue
                    else:
@@ -698,7 +877,7 @@ class GraphExecutor:
                        )

                        # Check if there's an ON_FAILURE edge to follow
-                        next_node = self._follow_edges(
+                        next_node = await self._follow_edges(
                            graph=graph,
                            goal=goal,
                            current_node_id=current_node_id,
@@ -748,6 +927,7 @@ class GraphExecutor:
                        "memory": saved_memory,
                        "execution_path": list(path),
                        "node_visit_counts": dict(node_visit_counts),
+                        "resume_from": current_node_id,
                    }

                    return ExecutionResult(
@@ -774,11 +954,22 @@ class GraphExecutor:
                # This must happen BEFORE determining next node, since pause nodes may have no edges
                if node_spec.id in graph.pause_nodes:
                    self.logger.info("💾 Saving session state after pause node")

+                    # Emit pause event
+                    if self._event_bus:
+                        await self._event_bus.emit_execution_paused(
+                            stream_id=self._stream_id,
+                            node_id=node_spec.id,
+                            reason="HITL pause node",
+                        )
+
                    saved_memory = memory.read_all()
                    session_state_out = {
                        "paused_at": node_spec.id,
                        "resume_from": f"{node_spec.id}_resume",  # Resume key
                        "memory": saved_memory,
                        "execution_path": list(path),
+                        "node_visit_counts": dict(node_visit_counts),
                        "next_node": None,  # Will resume from entry point
                    }

@@ -827,10 +1018,21 @@ class GraphExecutor:
                if result.next_node:
                    # Router explicitly set next node
                    self.logger.info(f"  → Router directing to: {result.next_node}")

+                    # Emit edge traversed event for router-directed edge
+                    if self._event_bus:
+                        await self._event_bus.emit_edge_traversed(
+                            stream_id=self._stream_id,
+                            source_node=current_node_id,
+                            target_node=result.next_node,
+                            edge_condition="router",
+                        )
+
                    current_node_id = result.next_node
+                    self._write_progress(current_node_id, path, memory, node_visit_counts)
                else:
                    # Get all traversable edges for fan-out detection
-                    traversable_edges = self._get_all_traversable_edges(
+                    traversable_edges = await self._get_all_traversable_edges(
                        graph=graph,
                        goal=goal,
                        current_node_id=current_node_id,
@@ -849,6 +1051,18 @@ class GraphExecutor:
                        targets = [e.target for e in traversable_edges]
                        fan_in_node = self._find_convergence_node(graph, targets)

+                        # Emit edge traversed events for fan-out branches
+                        if self._event_bus:
+                            for edge in traversable_edges:
+                                await self._event_bus.emit_edge_traversed(
+                                    stream_id=self._stream_id,
+                                    source_node=current_node_id,
+                                    target_node=edge.target,
+                                    edge_condition=edge.condition.value
+                                    if hasattr(edge.condition, "value")
+                                    else str(edge.condition),
+                                )
+
                        # Execute branches in parallel
                        (
                            _branch_results,
@@ -871,13 +1085,14 @@ class GraphExecutor:
                        if fan_in_node:
                            self.logger.info(f"  ⑃ Fan-in: converging at {fan_in_node}")
                            current_node_id = fan_in_node
+                            self._write_progress(current_node_id, path, memory, node_visit_counts)
                        else:
                            # No convergence point - branches are terminal
                            self.logger.info("  → Parallel branches completed (no convergence)")
                            break
                    else:
                        # Sequential: follow single edge (existing logic via _follow_edges)
-                        next_node = self._follow_edges(
+                        next_node = await self._follow_edges(
                            graph=graph,
                            goal=goal,
                            current_node_id=current_node_id,
@@ -891,6 +1106,14 @@ class GraphExecutor:
                        next_spec = graph.get_node(next_node)
                        self.logger.info(f"  → Next: {next_spec.name if next_spec else next_node}")

+                        # Emit edge traversed event for sequential edge
+                        if self._event_bus:
+                            await self._event_bus.emit_edge_traversed(
+                                stream_id=self._stream_id,
+                                source_node=current_node_id,
+                                target_node=next_node,
+                            )
+
                        # CHECKPOINT: node_complete (after determining next node)
                        if (
                            checkpoint_store
@@ -925,6 +1148,84 @@ class GraphExecutor:

                        current_node_id = next_node

+                        # Write progress snapshot at node transition
+                        self._write_progress(current_node_id, path, memory, node_visit_counts)
+
+                        # Continuous mode: thread conversation forward with transition marker
+                        if is_continuous and result.conversation is not None:
+                            continuous_conversation = result.conversation
+
+                            # Look up the next node spec for the transition marker
+                            next_spec = graph.get_node(current_node_id)
+                            if next_spec and next_spec.node_type == "event_loop":
+                                from framework.graph.prompt_composer import (
+                                    build_narrative,
+                                    build_transition_marker,
+                                    compose_system_prompt,
+                                )
+
+                                # Build Layer 2 (narrative) from current state
+                                narrative = build_narrative(memory, path, graph)
+
+                                # Compose new system prompt (Layer 1 + 2 + 3)
+                                new_system = compose_system_prompt(
+                                    identity_prompt=getattr(graph, "identity_prompt", None),
+                                    focus_prompt=next_spec.system_prompt,
+                                    narrative=narrative,
+                                )
+                                continuous_conversation.update_system_prompt(new_system)
+
+                                # Switch conversation store to the next node's directory
+                                # so the transition marker and all subsequent messages are
+                                # persisted there instead of the first node's directory.
+                                if self._storage_path:
+                                    from framework.storage.conversation_store import (
+                                        FileConversationStore,
+                                    )
+
+                                    next_store_path = self._storage_path / "conversations" / next_spec.id
+                                    next_store = FileConversationStore(base_path=next_store_path)
+                                    await continuous_conversation.switch_store(next_store)
+
+                                # Insert transition marker into conversation
+                                data_dir = str(self._storage_path / "data") if self._storage_path else None
+                                marker = build_transition_marker(
+                                    previous_node=node_spec,
+                                    next_node=next_spec,
+                                    memory=memory,
+                                    cumulative_tool_names=sorted(cumulative_tool_names),
+                                    data_dir=data_dir,
+                                )
+                                await continuous_conversation.add_user_message(
+                                    marker,
+                                    is_transition_marker=True,
+                                )
+
+                                # Set current phase for phase-aware compaction
+                                continuous_conversation.set_current_phase(next_spec.id)
+
+                                # Opportunistic compaction at transition:
+                                # 1. Prune old tool results (free, no LLM call)
+                                # 2. If still over 80%, do a phase-graduated compact
+                                if continuous_conversation.usage_ratio() > 0.5:
+                                    await continuous_conversation.prune_old_tool_results(
+                                        protect_tokens=2000,
+                                    )
+                                    if continuous_conversation.needs_compaction():
+                                        self.logger.info(
+                                            "  Phase-boundary compaction (%.0f%% usage)",
+                                            continuous_conversation.usage_ratio() * 100,
+                                        )
+                                        summary = (
+                                            f"Summary of earlier phases (before {next_spec.name}). "
+                                            "See transition markers for phase details."
+                                        )
+                                        await continuous_conversation.compact(
+                                            summary,
+                                            keep_recent=4,
+                                            phase_graduated=True,
+                                        )
+
                # Update input_data for next node
                input_data = result.output

@@ -978,6 +1279,11 @@ class GraphExecutor:
                had_partial_failures=len(nodes_failed) > 0,
                execution_quality=exec_quality,
                node_visit_counts=dict(node_visit_counts),
+                session_state={
+                    "memory": output,  # output IS memory.read_all()
+                    "execution_path": list(path),
+                    "node_visit_counts": dict(node_visit_counts),
+                },
            )

        except asyncio.CancelledError:
@@ -1067,6 +1373,7 @@ class GraphExecutor:
                "memory": saved_memory,
                "execution_path": list(path),
                "node_visit_counts": dict(node_visit_counts),
+                "resume_from": current_node_id,
            }

            # Mark latest checkpoint for resume on failure
@@ -1119,12 +1426,20 @@ class GraphExecutor:
        goal: Goal,
        input_data: dict[str, Any],
        max_tokens: int = 4096,
+        continuous_mode: bool = False,
+        inherited_conversation: Any = None,
+        override_tools: list | None = None,
+        cumulative_output_keys: list[str] | None = None,
    ) -> NodeContext:
        """Build execution context for a node."""
-        # Filter tools to those available to this node
-        available_tools = []
-        if node_spec.tools:
-            available_tools = [t for t in self.tools if t.name in node_spec.tools]
+        if override_tools is not None:
+            # Continuous mode: use cumulative tool set
+            available_tools = list(override_tools)
+        else:
+            available_tools = []
+            if node_spec.tools:
+                available_tools = [t for t in self.tools if t.name in node_spec.tools]

        # Create scoped memory view
        scoped_memory = memory.with_permissions(
@@ -1145,6 +1460,9 @@ class GraphExecutor:
            max_tokens=max_tokens,
            runtime_logger=self.runtime_logger,
            pause_event=self._pause_requested,  # Pass pause event for granular control
+            continuous_mode=continuous_mode,
+            inherited_conversation=inherited_conversation,
+            cumulative_output_keys=cumulative_output_keys or [],
        )

        # Valid node types - no ambiguous "llm" type allowed
@@ -1269,7 +1587,7 @@ class GraphExecutor:
        # Should never reach here due to validation above
        raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")

-    def _follow_edges(
+    async def _follow_edges(
        self,
        graph: GraphSpec,
        goal: Goal,
@@ -1284,7 +1602,7 @@ class GraphExecutor:
        for edge in edges:
            target_node_spec = graph.get_node(edge.target)

-            if edge.should_traverse(
+            if await edge.should_traverse(
                source_success=result.success,
                source_output=result.output,
                memory=memory.read_all(),
@@ -1310,7 +1628,7 @@ class GraphExecutor:
                self.logger.warning(f"⚠ Output validation failed: {validation.errors}")

                # Clean the output
-                cleaned_output = self.output_cleaner.clean_output(
+                cleaned_output = await self.output_cleaner.clean_output(
                    output=output_to_validate,
                    source_node_id=current_node_id,
                    target_node_spec=target_node_spec,
@@ -1348,7 +1666,7 @@ class GraphExecutor:

        return None

-    def _get_all_traversable_edges(
+    async def _get_all_traversable_edges(
        self,
        graph: GraphSpec,
        goal: Goal,
@@ -1368,7 +1686,7 @@ class GraphExecutor:

        for edge in edges:
            target_node_spec = graph.get_node(edge.target)
-            if edge.should_traverse(
+            if await edge.should_traverse(
                source_success=result.success,
                source_output=result.output,
                memory=memory.read_all(),
@@ -1510,7 +1828,7 @@ class GraphExecutor:
                        f"⚠ Output validation failed for branch "
                        f"{branch.node_id}: {validation.errors}"
                    )
-                    cleaned_output = self.output_cleaner.clean_output(
+                    cleaned_output = await self.output_cleaner.clean_output(
                        output=mem_snapshot,
                        source_node_id=source_node_spec.id if source_node_spec else "unknown",
                        target_node_spec=node_spec,

@@ -203,7 +203,7 @@ class HybridJudge:
        user_prompt = self._build_llm_user_prompt(step, result, context, rule_result)

        try:
-            response = self.llm.complete(
+            response = await self.llm.acomplete(
                messages=[{"role": "user", "content": user_prompt}],
                system=system_prompt,
            )

@@ -238,6 +238,16 @@ class NodeSpec(BaseModel):
        description="If True, this node streams output to the end user and can request input.",
    )

+    # Phase completion criteria for conversation-aware judge (Level 2)
+    success_criteria: str | None = Field(
+        default=None,
+        description=(
+            "Natural-language criteria for phase completion. When set, the "
+            "implicit judge upgrades to Level 2: after output keys are satisfied, "
+            "a fast LLM evaluates whether the conversation meets these criteria."
+        ),
+    )
+
    model_config = {"extra": "allow", "arbitrary_types_allowed": True}


@@ -483,6 +493,11 @@
    # Pause control (optional) - asyncio.Event for pause requests
    pause_event: Any = None  # asyncio.Event | None

+    # Continuous conversation mode
+    continuous_mode: bool = False  # True when graph has conversation_mode="continuous"
+    inherited_conversation: Any = None  # NodeConversation | None (from prior node)
+    cumulative_output_keys: list[str] = field(default_factory=list)  # All output keys from path
+

@dataclass
class NodeResult:
@@ -511,6 +526,9 @@
    # Pydantic validation errors (if any)
    validation_errors: list[str] = field(default_factory=list)

+    # Continuous conversation mode: return conversation for threading to next node
+    conversation: Any = None  # NodeConversation | None
+
    def to_summary(self, node_spec: Any = None) -> str:
        """
        Generate a human-readable summary of this node's execution and output.
@@ -913,7 +931,7 @@ Keep the same JSON structure but with shorter content values.
            )
            return result

-        response = ctx.llm.complete_with_tools(
+        response = await ctx.llm.acomplete_with_tools(
            messages=messages,
            system=system,
            tools=ctx.available_tools,
@@ -933,7 +951,7 @@ Keep the same JSON structure but with shorter content values.
                f"  📋 Expecting JSON output with keys: {ctx.node_spec.output_keys}"
            )

-        response = ctx.llm.complete(
+        response = await ctx.llm.acomplete(
            messages=messages,
            system=system,
            json_mode=use_json_mode,
@@ -967,7 +985,7 @@ Keep the same JSON structure but with shorter content values.

        # Retry the call with compaction instruction
        if ctx.available_tools and self.tool_executor:
-            response = ctx.llm.complete_with_tools(
+            response = await ctx.llm.acomplete_with_tools(
                messages=compaction_messages,
                system=system,
                tools=ctx.available_tools,
@@ -975,7 +993,7 @@ Keep the same JSON structure but with shorter content values.
                max_tokens=ctx.max_tokens,
            )
        else:
-            response = ctx.llm.complete(
+            response = await ctx.llm.acomplete(
                messages=compaction_messages,
                system=system,
                json_mode=use_json_mode,
@@ -1056,7 +1074,7 @@ Keep the same JSON structure but with shorter content values.

        # Re-call LLM with feedback
        if ctx.available_tools and self.tool_executor:
-            response = ctx.llm.complete_with_tools(
+            response = await ctx.llm.acomplete_with_tools(
                messages=current_messages,
                system=system,
                tools=ctx.available_tools,
@@ -1064,7 +1082,7 @@ Keep the same JSON structure but with shorter content values.
                max_tokens=ctx.max_tokens,
            )
        else:
-            response = ctx.llm.complete(
+            response = await ctx.llm.acomplete(
                messages=current_messages,
                system=system,
                json_mode=use_json_mode,
@@ -1134,7 +1152,7 @@ Keep the same JSON structure but with shorter content values.
            decision_id=decision_id,
            success=True,
            result=response.content,
-            tokens_used=response.input_tokens + response.output_tokens,
+            tokens_used=total_input_tokens + total_output_tokens,
            latency_ms=latency_ms,
        )

@@ -1233,7 +1251,7 @@ Keep the same JSON structure but with shorter content values.
                success=False,
                error=_extraction_error,
                output={},
-                tokens_used=response.input_tokens + response.output_tokens,
+                tokens_used=total_input_tokens + total_output_tokens,
                latency_ms=latency_ms,
            )
        # JSON extraction failed completely - still strip code blocks
@@ -1275,7 +1293,7 @@ Keep the same JSON structure but with shorter content values.
        return NodeResult(
            success=True,
            output=output,
-            tokens_used=response.input_tokens + response.output_tokens,
+            tokens_used=total_input_tokens + total_output_tokens,
            latency_ms=latency_ms,
        )

@@ -1804,7 +1822,7 @@ Respond with ONLY a JSON object:
        logger.info("  🤔 Router using LLM to choose path...")

        try:
-            response = ctx.llm.complete(
+            response = await ctx.llm.acomplete(
                messages=[{"role": "user", "content": prompt}],
                system=ctx.node_spec.system_prompt
                or "You are a routing agent. Respond with JSON only.",

@@ -206,7 +206,7 @@ class OutputCleaner:
            warnings=warnings,
        )

-    def clean_output(
+    async def clean_output(
        self,
        output: dict[str, Any],
        source_node_id: str,
@@ -288,7 +288,7 @@ Return ONLY valid JSON matching the expected schema. No explanations, no markdown.
            f"🧹 Cleaning output from '{source_node_id}' using {self.config.fast_model}"
        )

-        response = self.llm.complete(
+        response = await self.llm.acomplete(
            messages=[{"role": "user", "content": prompt}],
            system=(
                "You clean malformed agent outputs. Return only valid JSON matching the schema."

@@ -0,0 +1,185 @@
+"""Prompt composition for continuous agent mode.
+
+Composes the three-layer system prompt (onion model) and generates
+transition markers inserted into the conversation at phase boundaries.
+
+Layer 1 — Identity (static, defined at agent level, never changes):
+    "You are a thorough research agent. You prefer clarity over jargon..."
+
+Layer 2 — Narrative (auto-generated from conversation/memory state):
+    "We've finished scoping the project. The user wants to focus on..."
+
+Layer 3 — Focus (per-node system_prompt, reframed as focus directive):
+    "Your current attention: synthesize findings into a report..."
+"""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from framework.graph.edge import GraphSpec
+    from framework.graph.node import NodeSpec, SharedMemory
+
+logger = logging.getLogger(__name__)
+
+
+def compose_system_prompt(
+    identity_prompt: str | None,
+    focus_prompt: str | None,
+    narrative: str | None = None,
+) -> str:
+    """Compose the three-layer system prompt.
+
+    Args:
+        identity_prompt: Layer 1 — static agent identity (from GraphSpec).
+        focus_prompt: Layer 3 — per-node focus directive (from NodeSpec.system_prompt).
+        narrative: Layer 2 — auto-generated from conversation state.
+
+    Returns:
+        Composed system prompt with all layers present.
+    """
+    parts: list[str] = []
+
+    # Layer 1: Identity (always first, anchors the personality)
+    if identity_prompt:
+        parts.append(identity_prompt)
+
+    # Layer 2: Narrative (what's happened so far)
+    if narrative:
+        parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}")
+
+    # Layer 3: Focus (current phase directive)
+    if focus_prompt:
+        parts.append(f"\n--- Current Focus ---\n{focus_prompt}")
+
+    return "\n".join(parts) if parts else ""
+
+
+def build_narrative(
+    memory: SharedMemory,
+    execution_path: list[str],
+    graph: GraphSpec,
+) -> str:
+    """Build Layer 2 (narrative) from structured state.
+
+    Deterministic — no LLM call. Reads SharedMemory and execution path
+    to describe what has happened so far. Cheap and fast.
+
+    Args:
+        memory: Current shared memory state.
+        execution_path: List of node IDs visited so far.
+        graph: Graph spec (for node names/descriptions).
+
+    Returns:
+        Narrative string describing the session state.
+    """
+    parts: list[str] = []
+
+    # Describe execution path
+    if execution_path:
+        phase_descriptions: list[str] = []
+        for node_id in execution_path:
+            node_spec = graph.get_node(node_id)
+            if node_spec:
+                phase_descriptions.append(f"- {node_spec.name}: {node_spec.description}")
+            else:
+                phase_descriptions.append(f"- {node_id}")
+        parts.append("Phases completed:\n" + "\n".join(phase_descriptions))
+
+    # Describe key memory values (skip very long values)
+    all_memory = memory.read_all()
+    if all_memory:
+        memory_lines: list[str] = []
+        for key, value in all_memory.items():
+            if value is None:
+                continue
+            val_str = str(value)
+            if len(val_str) > 200:
+                val_str = val_str[:200] + "..."
+            memory_lines.append(f"- {key}: {val_str}")
+        if memory_lines:
+            parts.append("Current state:\n" + "\n".join(memory_lines))
+
+    return "\n\n".join(parts) if parts else ""
+
+
+def build_transition_marker(
+    previous_node: NodeSpec,
+    next_node: NodeSpec,
+    memory: SharedMemory,
+    cumulative_tool_names: list[str],
+    data_dir: Path | str | None = None,
+) -> str:
+    """Build a 'State of the World' transition marker.
+
+    Inserted into the conversation as a user message at phase boundaries.
+    Gives the LLM full situational awareness: what happened, what's stored,
+    what tools are available, and what to focus on next.
+
+    Args:
+        previous_node: NodeSpec of the phase just completed.
+        next_node: NodeSpec of the phase about to start.
+        memory: Current shared memory state.
+        cumulative_tool_names: All tools available (cumulative set).
+        data_dir: Path to spillover data directory.
+
+    Returns:
+        Transition marker message text.
+    """
+    sections: list[str] = []
+
+    # Header
+    sections.append(f"--- PHASE TRANSITION: {previous_node.name} → {next_node.name} ---")
+
+    # What just completed
+    sections.append(f"\nCompleted: {previous_node.name}")
+    sections.append(f"  {previous_node.description}")
+
+    # Outputs in memory
+    all_memory = memory.read_all()
+    if all_memory:
+        memory_lines: list[str] = []
+        for key, value in all_memory.items():
+            if value is None:
+                continue
+            val_str = str(value)
+            if len(val_str) > 300:
+                val_str = val_str[:300] + "..."
+            memory_lines.append(f"  {key}: {val_str}")
+        if memory_lines:
+            sections.append("\nOutputs available:\n" + "\n".join(memory_lines))
+
+    # Files in data directory
+    if data_dir:
+        data_path = Path(data_dir)
+        if data_path.exists():
+            files = sorted(data_path.iterdir())
+            if files:
+                file_lines = [
+                    f"  {f.name} ({f.stat().st_size:,} bytes)" for f in files if f.is_file()
+                ]
+                if file_lines:
+                    sections.append(
+                        "\nData files (use load_data to access):\n" + "\n".join(file_lines)
+                    )
+
+    # Available tools
+    if cumulative_tool_names:
+        sections.append("\nAvailable tools: " + ", ".join(sorted(cumulative_tool_names)))
+
+    # Next phase
+    sections.append(f"\nNow entering: {next_node.name}")
+    sections.append(f"  {next_node.description}")
+
+    # Reflection prompt (engineered metacognition)
+    sections.append(
+        "\nBefore proceeding, briefly reflect: what went well in the "
+        "previous phase? Are there any gaps or surprises worth noting?"
+    )
+
+    sections.append("\n--- END TRANSITION ---")
+
+    return "\n".join(sections)

@@ -145,7 +145,7 @@ class SafeEvalVisitor(ast.NodeVisitor):

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        # value.attr
-        # STIRCT CHECK: No access to private attributes (starting with _)
+        # STRICT CHECK: No access to private attributes (starting with _)
        if node.attr.startswith("_"):
            raise ValueError(f"Access to private attribute '{node.attr}' is not allowed")


@@ -314,7 +314,7 @@ class WorkerNode:

        messages = [{"role": "user", "content": prompt}]

-        response = self.llm.complete(
+        response = await self.llm.acomplete(
            messages=messages,
            system=action.system_prompt,
        )

@@ -70,6 +70,7 @@ class AnthropicProvider(LLMProvider):
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
+        max_retries: int | None = None,
    ) -> LLMResponse:
        """Generate a completion from Claude (via LiteLLM)."""
        return self._provider.complete(
@@ -79,6 +80,7 @@ class AnthropicProvider(LLMProvider):
            max_tokens=max_tokens,
            response_format=response_format,
            json_mode=json_mode,
+            max_retries=max_retries,
        )

    def complete_with_tools(
@@ -97,3 +99,41 @@ class AnthropicProvider(LLMProvider):
            tool_executor=tool_executor,
            max_iterations=max_iterations,
        )
+
+    async def acomplete(
+        self,
+        messages: list[dict[str, Any]],
+        system: str = "",
+        tools: list[Tool] | None = None,
+        max_tokens: int = 1024,
+        response_format: dict[str, Any] | None = None,
+        json_mode: bool = False,
+        max_retries: int | None = None,
+    ) -> LLMResponse:
+        """Async completion via LiteLLM."""
+        return await self._provider.acomplete(
+            messages=messages,
+            system=system,
+            tools=tools,
+            max_tokens=max_tokens,
+            response_format=response_format,
+            json_mode=json_mode,
+            max_retries=max_retries,
+        )
+
+    async def acomplete_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        system: str,
+        tools: list[Tool],
+        tool_executor: Callable[[ToolUse], ToolResult],
+        max_iterations: int = 10,
+    ) -> LLMResponse:
+        """Async tool-use loop via LiteLLM."""
+        return await self._provider.acomplete_with_tools(
+            messages=messages,
+            system=system,
+            tools=tools,
+            tool_executor=tool_executor,
+            max_iterations=max_iterations,
+        )
+383 −16
@@ -30,6 +30,7 @@ logger = logging.getLogger(__name__)

RATE_LIMIT_MAX_RETRIES = 10
RATE_LIMIT_BACKOFF_BASE = 2  # seconds
+RATE_LIMIT_MAX_DELAY = 120  # seconds - cap to prevent absurd waits

# Directory for dumping failed requests
FAILED_REQUESTS_DIR = Path.home() / ".hive" / "failed_requests"
@@ -84,6 +85,91 @@ def _dump_failed_request(
    return str(filepath)


+def _compute_retry_delay(
+    attempt: int,
+    exception: BaseException | None = None,
+    backoff_base: int = RATE_LIMIT_BACKOFF_BASE,
+    max_delay: int = RATE_LIMIT_MAX_DELAY,
+) -> float:
+    """Compute retry delay, preferring server-provided Retry-After headers.
+
+    Priority:
+    1. retry-after-ms header (milliseconds, float)
+    2. retry-after header as seconds (float)
+    3. retry-after header as HTTP-date (RFC 7231)
+    4. Exponential backoff: backoff_base * 2^attempt
+
+    All values are capped at max_delay seconds.
+    """
+    if exception is not None:
+        response = getattr(exception, "response", None)
+        if response is not None:
+            headers = getattr(response, "headers", None)
+            if headers is not None:
+                # Priority 1: retry-after-ms (milliseconds)
+                retry_after_ms = headers.get("retry-after-ms")
+                if retry_after_ms is not None:
+                    try:
+                        delay = float(retry_after_ms) / 1000.0
+                        return min(max(delay, 0), max_delay)
+                    except (ValueError, TypeError):
+                        pass
+
+                # Priority 2: retry-after (seconds or HTTP-date)
+                retry_after = headers.get("retry-after")
+                if retry_after is not None:
+                    # Try as seconds (float)
+                    try:
+                        delay = float(retry_after)
+                        return min(max(delay, 0), max_delay)
+                    except (ValueError, TypeError):
+                        pass
+
+                    # Try as HTTP-date (e.g., "Fri, 31 Dec 2025 23:59:59 GMT")
+                    try:
+                        from email.utils import parsedate_to_datetime
+
+                        retry_date = parsedate_to_datetime(retry_after)
+                        now = datetime.now(retry_date.tzinfo)
+                        delay = (retry_date - now).total_seconds()
+                        return min(max(delay, 0), max_delay)
+                    except (ValueError, TypeError, OverflowError):
+                        pass
+
+    # Fallback: exponential backoff
+    delay = backoff_base * (2**attempt)
+    return min(delay, max_delay)
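A quick feel for the fallback schedule, assuming no Retry-After header is
present (pure exponential backoff with the module defaults):

    for attempt in range(8):
        print(attempt, _compute_retry_delay(attempt))
    # 2, 4, 8, 16, 32, 64, then capped at RATE_LIMIT_MAX_DELAY (120)
    # for attempts 6 and 7.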
+
+
+def _is_stream_transient_error(exc: BaseException) -> bool:
+    """Classify whether a streaming exception is transient (recoverable).
+
+    Transient errors (recoverable=True): network issues, server errors, timeouts.
+    Permanent errors (recoverable=False): auth, bad request, context window, etc.
+    """
+    try:
+        from litellm.exceptions import (
+            APIConnectionError,
+            BadGatewayError,
+            InternalServerError,
+            ServiceUnavailableError,
+        )
+
+        transient_types: tuple[type[BaseException], ...] = (
+            APIConnectionError,
+            InternalServerError,
+            BadGatewayError,
+            ServiceUnavailableError,
+            TimeoutError,
+            ConnectionError,
+            OSError,
+        )
+    except ImportError:
+        transient_types = (TimeoutError, ConnectionError, OSError)
+
+    return isinstance(exc, transient_types)
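A minimal sketch of the classification in action, using builtin exception
instances constructed directly for illustration:

    assert _is_stream_transient_error(TimeoutError("read timed out"))
    assert _is_stream_transient_error(ConnectionError("reset by peer"))
    assert not _is_stream_transient_error(ValueError("bad request body"))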


class LiteLLMProvider(LLMProvider):
    """
    LiteLLM-based LLM provider for multi-provider support.
@@ -150,10 +236,13 @@ class LiteLLMProvider(LLMProvider):
                "LiteLLM is not installed. Please install it with: uv pip install litellm"
            )

-    def _completion_with_rate_limit_retry(self, **kwargs: Any) -> Any:
+    def _completion_with_rate_limit_retry(
+        self, max_retries: int | None = None, **kwargs: Any
+    ) -> Any:
        """Call litellm.completion with retry on 429 rate limit errors and empty responses."""
        model = kwargs.get("model", self.model)
-        for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
+        retries = max_retries if max_retries is not None else RATE_LIMIT_MAX_RETRIES
+        for attempt in range(retries + 1):
            try:
                response = litellm.completion(**kwargs)  # type: ignore[union-attr]

@@ -194,22 +283,22 @@ class LiteLLMProvider(LLMProvider):
                        f"Full request dumped to: {dump_path}"
                    )

-                    if attempt == RATE_LIMIT_MAX_RETRIES:
+                    if attempt == retries:
                        logger.error(
-                            f"[retry] GAVE UP on {model} after {RATE_LIMIT_MAX_RETRIES + 1} "
+                            f"[retry] GAVE UP on {model} after {retries + 1} "
                            f"attempts — empty response "
                            f"(finish_reason={finish_reason}, "
                            f"choices={len(response.choices) if response.choices else 0})"
                        )
                        return response
-                    wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
+                    wait = _compute_retry_delay(attempt)
                    logger.warning(
                        f"[retry] {model} returned empty response "
                        f"(finish_reason={finish_reason}, "
                        f"choices={len(response.choices) if response.choices else 0}) — "
                        f"likely rate limited or quota exceeded. "
                        f"Retrying in {wait}s "
-                        f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
+                        f"(attempt {attempt + 1}/{retries})"
                    )
                    time.sleep(wait)
                    continue
@@ -225,21 +314,21 @@ class LiteLLMProvider(LLMProvider):
                    error_type="rate_limit",
                    attempt=attempt,
                )
-                if attempt == RATE_LIMIT_MAX_RETRIES:
+                if attempt == retries:
                    logger.error(
-                        f"[retry] GAVE UP on {model} after {RATE_LIMIT_MAX_RETRIES + 1} "
+                        f"[retry] GAVE UP on {model} after {retries + 1} "
                        f"attempts — rate limit error: {e!s}. "
                        f"~{token_count} tokens ({token_method}). "
                        f"Full request dumped to: {dump_path}"
                    )
                    raise
-                wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
+                wait = _compute_retry_delay(attempt, exception=e)
                logger.warning(
                    f"[retry] {model} rate limited (429): {e!s}. "
                    f"~{token_count} tokens ({token_method}). "
                    f"Full request dumped to: {dump_path}. "
                    f"Retrying in {wait}s "
-                    f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
+                    f"(attempt {attempt + 1}/{retries})"
                )
                time.sleep(wait)
        # unreachable, but satisfies type checker
@@ -253,6 +342,7 @@ class LiteLLMProvider(LLMProvider):
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
+        max_retries: int | None = None,
    ) -> LLMResponse:
        """Generate a completion using LiteLLM."""
        # Prepare messages with system prompt
@@ -293,12 +383,17 @@ class LiteLLMProvider(LLMProvider):
            kwargs["response_format"] = response_format

        # Make the call
-        response = self._completion_with_rate_limit_retry(**kwargs)
+        response = self._completion_with_rate_limit_retry(max_retries=max_retries, **kwargs)

        # Extract content
        content = response.choices[0].message.content or ""

-        # Get usage info
+        # Get usage info.
+        # NOTE: completion_tokens includes reasoning/thinking tokens for models
+        # that use them (o1, gpt-5-mini, etc.). LiteLLM does not reliably expose
+        # usage.completion_tokens_details.reasoning_tokens across all providers.
+        # This means output_tokens may be inflated for reasoning models.
+        # Compaction is unaffected — it uses prompt_tokens (input-side only).
        usage = response.usage
        input_tokens = usage.prompt_tokens if usage else 0
        output_tokens = usage.completion_tokens if usage else 0
@@ -433,6 +528,267 @@ class LiteLLMProvider(LLMProvider):
            raw_response=None,
        )

+    # ------------------------------------------------------------------
+    # Async variants — non-blocking on the event loop
+    # ------------------------------------------------------------------
+
+    async def _acompletion_with_rate_limit_retry(
+        self, max_retries: int | None = None, **kwargs: Any
+    ) -> Any:
+        """Async version of _completion_with_rate_limit_retry.
+
+        Uses litellm.acompletion and asyncio.sleep instead of blocking calls.
+        """
+        model = kwargs.get("model", self.model)
+        retries = max_retries if max_retries is not None else RATE_LIMIT_MAX_RETRIES
+        for attempt in range(retries + 1):
+            try:
+                response = await litellm.acompletion(**kwargs)  # type: ignore[union-attr]
+
+                content = response.choices[0].message.content if response.choices else None
+                has_tool_calls = bool(response.choices and response.choices[0].message.tool_calls)
+                if not content and not has_tool_calls:
+                    messages = kwargs.get("messages", [])
+                    last_role = next(
+                        (m["role"] for m in reversed(messages) if m.get("role") != "system"),
+                        None,
+                    )
+                    if last_role == "assistant":
+                        logger.debug(
+                            "[async-retry] Empty response after assistant message — "
+                            "expected, not retrying."
+                        )
+                        return response
+
+                    finish_reason = (
+                        response.choices[0].finish_reason if response.choices else "unknown"
+                    )
+                    token_count, token_method = _estimate_tokens(model, messages)
+                    dump_path = _dump_failed_request(
+                        model=model,
+                        kwargs=kwargs,
+                        error_type="empty_response",
+                        attempt=attempt,
+                    )
+                    logger.warning(
+                        f"[async-retry] Empty response - {len(messages)} messages, "
+                        f"~{token_count} tokens ({token_method}). "
+                        f"Full request dumped to: {dump_path}"
+                    )
+
+                    if attempt == retries:
+                        logger.error(
+                            f"[async-retry] GAVE UP on {model} after {retries + 1} "
+                            f"attempts — empty response "
+                            f"(finish_reason={finish_reason}, "
+                            f"choices={len(response.choices) if response.choices else 0})"
+                        )
+                        return response
+                    wait = _compute_retry_delay(attempt)
+                    logger.warning(
+                        f"[async-retry] {model} returned empty response "
+                        f"(finish_reason={finish_reason}, "
+                        f"choices={len(response.choices) if response.choices else 0}) — "
+                        f"likely rate limited or quota exceeded. "
+                        f"Retrying in {wait}s "
+                        f"(attempt {attempt + 1}/{retries})"
+                    )
+                    await asyncio.sleep(wait)
+                    continue
+
+                return response
+            except RateLimitError as e:
+                messages = kwargs.get("messages", [])
+                token_count, token_method = _estimate_tokens(model, messages)
+                dump_path = _dump_failed_request(
+                    model=model,
+                    kwargs=kwargs,
+                    error_type="rate_limit",
+                    attempt=attempt,
+                )
+                if attempt == retries:
+                    logger.error(
+                        f"[async-retry] GAVE UP on {model} after {retries + 1} "
+                        f"attempts — rate limit error: {e!s}. "
+                        f"~{token_count} tokens ({token_method}). "
+                        f"Full request dumped to: {dump_path}"
+                    )
+                    raise
+                wait = _compute_retry_delay(attempt, exception=e)
+                logger.warning(
+                    f"[async-retry] {model} rate limited (429): {e!s}. "
+                    f"~{token_count} tokens ({token_method}). "
+                    f"Full request dumped to: {dump_path}. "
+                    f"Retrying in {wait}s "
+                    f"(attempt {attempt + 1}/{retries})"
+                )
+                await asyncio.sleep(wait)
+        raise RuntimeError("Exhausted rate limit retries")
+
+    async def acomplete(
+        self,
+        messages: list[dict[str, Any]],
+        system: str = "",
+        tools: list[Tool] | None = None,
+        max_tokens: int = 1024,
+        response_format: dict[str, Any] | None = None,
+        json_mode: bool = False,
+        max_retries: int | None = None,
+    ) -> LLMResponse:
+        """Async version of complete(). Uses litellm.acompletion — non-blocking."""
+        full_messages: list[dict[str, Any]] = []
+        if system:
+            full_messages.append({"role": "system", "content": system})
+        full_messages.extend(messages)
+
+        if json_mode:
+            json_instruction = "\n\nPlease respond with a valid JSON object."
+            if full_messages and full_messages[0]["role"] == "system":
+                full_messages[0]["content"] += json_instruction
+            else:
+                full_messages.insert(0, {"role": "system", "content": json_instruction.strip()})
+
+        kwargs: dict[str, Any] = {
+            "model": self.model,
+            "messages": full_messages,
+            "max_tokens": max_tokens,
+            **self.extra_kwargs,
+        }
+
+        if self.api_key:
+            kwargs["api_key"] = self.api_key
+        if self.api_base:
+            kwargs["api_base"] = self.api_base
+        if tools:
+            kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
+        if response_format:
+            kwargs["response_format"] = response_format
+
+        response = await self._acompletion_with_rate_limit_retry(max_retries=max_retries, **kwargs)
+
+        content = response.choices[0].message.content or ""
+        usage = response.usage
+        input_tokens = usage.prompt_tokens if usage else 0
+        output_tokens = usage.completion_tokens if usage else 0
+
+        return LLMResponse(
+            content=content,
+            model=response.model or self.model,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            stop_reason=response.choices[0].finish_reason or "",
+            raw_response=response,
+        )
+
+    async def acomplete_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        system: str,
+        tools: list[Tool],
+        tool_executor: Callable[[ToolUse], ToolResult],
+        max_iterations: int = 10,
+        max_tokens: int = 4096,
+    ) -> LLMResponse:
+        """Async version of complete_with_tools(). Uses litellm.acompletion — non-blocking."""
+        current_messages: list[dict[str, Any]] = []
+        if system:
+            current_messages.append({"role": "system", "content": system})
+        current_messages.extend(messages)
+
+        total_input_tokens = 0
+        total_output_tokens = 0
+        openai_tools = [self._tool_to_openai_format(t) for t in tools]
+
+        for _ in range(max_iterations):
+            kwargs: dict[str, Any] = {
+                "model": self.model,
+                "messages": current_messages,
+                "max_tokens": max_tokens,
+                "tools": openai_tools,
+                **self.extra_kwargs,
+            }
+
+            if self.api_key:
+                kwargs["api_key"] = self.api_key
+            if self.api_base:
+                kwargs["api_base"] = self.api_base
+
+            response = await self._acompletion_with_rate_limit_retry(**kwargs)
+
+            usage = response.usage
+            if usage:
+                total_input_tokens += usage.prompt_tokens
+                total_output_tokens += usage.completion_tokens
+
+            choice = response.choices[0]
+            message = choice.message
+
+            if choice.finish_reason == "stop" or not message.tool_calls:
+                return LLMResponse(
+                    content=message.content or "",
+                    model=response.model or self.model,
+                    input_tokens=total_input_tokens,
+                    output_tokens=total_output_tokens,
+                    stop_reason=choice.finish_reason or "stop",
+                    raw_response=response,
+                )
+
+            current_messages.append(
+                {
+                    "role": "assistant",
+                    "content": message.content,
+                    "tool_calls": [
+                        {
+                            "id": tc.id,
+                            "type": "function",
+                            "function": {
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
+                            },
+                        }
+                        for tc in message.tool_calls
+                    ],
+                }
+            )
+
+            for tool_call in message.tool_calls:
+                try:
+                    args = json.loads(tool_call.function.arguments)
+                except json.JSONDecodeError:
+                    current_messages.append(
+                        {
+                            "role": "tool",
+                            "tool_call_id": tool_call.id,
+                            "content": "Invalid JSON arguments provided to tool.",
+                        }
+                    )
+                    continue
+
+                tool_use = ToolUse(
+                    id=tool_call.id,
+                    name=tool_call.function.name,
+                    input=args,
+                )
+
+                result = tool_executor(tool_use)
+
+                current_messages.append(
+                    {
+                        "role": "tool",
+                        "tool_call_id": result.tool_use_id,
+                        "content": result.content,
+                    }
+                )
+
+        return LLMResponse(
+            content="Max tool iterations reached",
+            model=self.model,
+            input_tokens=total_input_tokens,
+            output_tokens=total_output_tokens,
+            stop_reason="max_iterations",
+            raw_response=None,
+        )
|
||||
|
||||
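    # Sketch of a tool_executor callable for acomplete_with_tools() (the echo
    # tool is an assumption; Tool/ToolUse/ToolResult come from this module):
    #
    #     def echo_executor(tool_use: ToolUse) -> ToolResult:
    #         # Run the named tool and wrap its output for the next LLM turn.
    #         if tool_use.name == "echo":
    #             return ToolResult(tool_use_id=tool_use.id, content=str(tool_use.input))
    #         return ToolResult(tool_use_id=tool_use.id, content=f"Unknown tool: {tool_use.name}")
    #
    #     # result = await provider.acomplete_with_tools(
    #     #     messages, system="...", tools=[echo_tool], tool_executor=echo_executor)
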
    def _tool_to_openai_format(self, tool: Tool) -> dict[str, Any]:
        """Convert Tool to OpenAI function calling format."""
        return {
@@ -591,7 +947,7 @@ class LiteLLMProvider(LLMProvider):
                    for event in tail_events:
                        yield event
                    return
                wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                wait = _compute_retry_delay(attempt)
                token_count, token_method = _estimate_tokens(
                    self.model,
                    full_messages,
@@ -619,10 +975,10 @@ class LiteLLMProvider(LLMProvider):

            except RateLimitError as e:
                if attempt < RATE_LIMIT_MAX_RETRIES:
                    wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                    wait = _compute_retry_delay(attempt, exception=e)
                    logger.warning(
                        f"[stream-retry] {self.model} rate limited (429): {e!s}. "
                        f"Retrying in {wait}s "
                        f"Retrying in {wait:.1f}s "
                        f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                    )
                    await asyncio.sleep(wait)
@@ -631,5 +987,16 @@ class LiteLLMProvider(LLMProvider):
                    return

            except Exception as e:
                yield StreamErrorEvent(error=str(e), recoverable=False)
                if _is_stream_transient_error(e) and attempt < RATE_LIMIT_MAX_RETRIES:
                    wait = _compute_retry_delay(attempt, exception=e)
                    logger.warning(
                        f"[stream-retry] {self.model} transient error "
                        f"({type(e).__name__}): {e!s}. "
                        f"Retrying in {wait:.1f}s "
                        f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                    )
                    await asyncio.sleep(wait)
                    continue
                recoverable = _is_stream_transient_error(e)
                yield StreamErrorEvent(error=str(e), recoverable=recoverable)
                return

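    # _compute_retry_delay and _is_stream_transient_error are referenced above
    # but defined elsewhere in this file. A minimal sketch of the delay helper,
    # assuming exponential backoff with jitter and an optional Retry-After hint
    # on the exception (the retry_after attribute name is an assumption):
    #
    #     import random
    #
    #     def _compute_retry_delay(attempt: int, exception: Exception | None = None) -> float:
    #         # Prefer a server-provided Retry-After hint when one is attached.
    #         hinted = getattr(exception, "retry_after", None)
    #         if hinted:
    #             return float(hinted)
    #         # Otherwise exponential backoff plus jitter to avoid thundering herds.
    #         return RATE_LIMIT_BACKOFF_BASE * (2**attempt) + random.uniform(0.0, 1.0)
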
@@ -120,6 +120,7 @@ class MockLLMProvider(LLMProvider):
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
        max_retries: int | None = None,
    ) -> LLMResponse:
        """
        Generate a mock completion without calling a real LLM.
@@ -182,6 +183,44 @@ class MockLLMProvider(LLMProvider):
            stop_reason="mock_complete",
        )

    async def acomplete(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
        max_retries: int | None = None,
    ) -> LLMResponse:
        """Async mock completion (no I/O, returns immediately)."""
        return self.complete(
            messages=messages,
            system=system,
            tools=tools,
            max_tokens=max_tokens,
            response_format=response_format,
            json_mode=json_mode,
            max_retries=max_retries,
        )

    async def acomplete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable[[ToolUse], ToolResult],
        max_iterations: int = 10,
    ) -> LLMResponse:
        """Async mock tool-use completion (no I/O, returns immediately)."""
        return self.complete_with_tools(
            messages=messages,
            system=system,
            tools=tools,
            tool_executor=tool_executor,
            max_iterations=max_iterations,
        )

    async def stream(
        self,
        messages: list[dict[str, Any]],

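    # Test sketch (not part of this diff): the async wrappers above delegate to
    # their sync counterparts, so they can be awaited directly in
    # pytest-asyncio tests. The no-argument MockLLMProvider() constructor is an
    # assumption.
    #
    #     import pytest
    #
    #     @pytest.mark.asyncio
    #     async def test_mock_acomplete_returns_immediately():
    #         provider = MockLLMProvider()
    #         response = await provider.acomplete([{"role": "user", "content": "hi"}])
    #         assert response.stop_reason == "mock_complete"
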
@@ -1,8 +1,10 @@
"""LLM Provider abstraction for pluggable LLM backends."""

import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Callable
from dataclasses import dataclass, field
from functools import partial
from typing import Any


@@ -65,6 +67,7 @@ class LLMProvider(ABC):
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
        max_retries: int | None = None,
    ) -> LLMResponse:
        """
        Generate a completion from the LLM.
@@ -79,6 +82,8 @@ class LLMProvider(ABC):
                - {"type": "json_schema", "json_schema": {"name": "...", "schema": {...}}}
                  for strict JSON schema enforcement
            json_mode: If True, request structured JSON output from the LLM
            max_retries: Override retry count for rate-limit/empty-response retries.
                None uses the provider default.

        Returns:
            LLMResponse with content and metadata
@@ -109,6 +114,62 @@ class LLMProvider(ABC):
        """
        pass

    async def acomplete(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list["Tool"] | None = None,
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
        max_retries: int | None = None,
    ) -> "LLMResponse":
        """Async version of complete(). Non-blocking on the event loop.

        Default implementation offloads the sync complete() to a thread pool.
        Subclasses SHOULD override for native async I/O.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            partial(
                self.complete,
                messages=messages,
                system=system,
                tools=tools,
                max_tokens=max_tokens,
                response_format=response_format,
                json_mode=json_mode,
                max_retries=max_retries,
            ),
        )

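    # The default above uses a standard offload pattern: wrap the blocking call
    # with functools.partial (run_in_executor forwards positional args only)
    # and hand it to the loop's default thread pool. A standalone sketch with
    # illustrative names:
    #
    #     import asyncio
    #     import time
    #     from functools import partial
    #
    #     def blocking_work(x: int, scale: int = 2) -> int:
    #         time.sleep(0.1)  # stands in for blocking I/O
    #         return x * scale
    #
    #     async def main() -> None:
    #         loop = asyncio.get_running_loop()
    #         print(await loop.run_in_executor(None, partial(blocking_work, 21, scale=2)))
    #
    #     asyncio.run(main())  # prints 42
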
    async def acomplete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list["Tool"],
        tool_executor: Callable[["ToolUse"], "ToolResult"],
        max_iterations: int = 10,
    ) -> "LLMResponse":
        """Async version of complete_with_tools(). Non-blocking on the event loop.

        Default implementation offloads the sync complete_with_tools() to a thread pool.
        Subclasses SHOULD override for native async I/O.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            partial(
                self.complete_with_tools,
                messages=messages,
                system=system,
                tools=tools,
                tool_executor=tool_executor,
                max_iterations=max_iterations,
            ),
        )

    async def stream(
        self,
        messages: list[dict[str, Any]],
@@ -132,7 +193,7 @@ class LLMProvider(ABC):
            TextEndEvent,
        )

        response = self.complete(
        response = await self.acomplete(
            messages=messages,
            system=system,
            tools=tools,

@@ -8,19 +8,22 @@ Usage:
"""

import json
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Annotated

# Project root resolution. This file lives at core/framework/mcp/agent_builder_server.py,
# so the project root (where exports/ lives) is four parents up.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent

# Ensure exports/ is on sys.path so AgentRunner can import agent modules.
_framework_dir = Path(__file__).resolve().parent.parent  # core/framework/ -> core/
_project_root = _framework_dir.parent  # core/ -> project root
_exports_dir = _project_root / "exports"
_exports_dir = _PROJECT_ROOT / "exports"
if _exports_dir.is_dir() and str(_exports_dir) not in sys.path:
    sys.path.insert(0, str(_exports_dir))
del _framework_dir, _project_root, _exports_dir
del _exports_dir

from mcp.server import FastMCP  # noqa: E402
from pydantic import ValidationError  # noqa: E402
@@ -176,8 +179,8 @@ def _load_active_session() -> BuildSession | None:

        if session_id:
            return _load_session(session_id)
    except Exception:
        pass
    except Exception as e:
        logging.warning("Failed to load active session: %s", e)

    return None

@@ -542,6 +545,9 @@ def _validate_agent_path(agent_path: str) -> tuple[Path | None, str | None]:
    """
    Validate and normalize agent_path.

    Resolves relative paths against _PROJECT_ROOT since the MCP server's
    cwd (core/) differs from the user's cwd (project root).

    Returns:
        (Path, None) if valid
        (None, error_json) if invalid
@@ -556,6 +562,12 @@ def _validate_agent_path(agent_path: str) -> tuple[Path | None, str | None]:

    path = Path(agent_path)

    # Resolve relative paths against project root (not MCP server's cwd)
    if not path.is_absolute() and not path.exists():
        resolved = _PROJECT_ROOT / path
        if resolved.exists():
            path = resolved

    if not path.exists():
        return None, json.dumps(
            {
@@ -1100,11 +1112,11 @@ def validate_graph() -> str:
    if entry_candidates:
        reachable = set()

        # For pause/resume agents, start from ALL entry points (including resume)
        if is_pause_resume_agent:
            to_visit = list(entry_candidates)  # All nodes without incoming edges
        else:
            to_visit = [entry_candidates[0]]  # Just the primary entry
        # Start from ALL entry candidates (nodes without incoming edges).
        # This handles both pause/resume agents and async entry point agents
        # where multiple nodes have no incoming edges (e.g., a primary entry
        # node and an event-driven entry node).
        to_visit = list(entry_candidates)

        while to_visit:
            current = to_visit.pop()
@@ -3019,18 +3031,15 @@ def _format_success_criteria(criteria: list[SuccessCriterion]) -> str:

# Test template for Claude to use when writing tests
CONSTRAINT_TEST_TEMPLATE = '''@pytest.mark.asyncio
async def test_constraint_{constraint_id}_{scenario}(mock_mode):
async def test_constraint_{constraint_id}_{scenario}(runner, auto_responder, mock_mode):
    """Test: {description}"""
    result = await default_agent.run({{"key": "value"}}, mock_mode=mock_mode)

    # IMPORTANT: result is an ExecutionResult object with these attributes:
    # - result.success: bool - whether the agent succeeded
    # - result.output: dict - the agent's output data (access data here!)
    # - result.error: str or None - error message if failed
    await auto_responder.start()
    try:
        result = await runner.run({{"key": "value"}})
    finally:
        await auto_responder.stop()

    assert result.success, f"Agent failed: {{result.error}}"

    # Access output data via result.output
    output_data = result.output or {{}}

    # Add constraint-specific assertions here
@@ -3038,18 +3047,15 @@ async def test_constraint_{constraint_id}_{scenario}(mock_mode):
'''

SUCCESS_TEST_TEMPLATE = '''@pytest.mark.asyncio
async def test_success_{criteria_id}_{scenario}(mock_mode):
async def test_success_{criteria_id}_{scenario}(runner, auto_responder, mock_mode):
    """Test: {description}"""
    result = await default_agent.run({{"key": "value"}}, mock_mode=mock_mode)

    # IMPORTANT: result is an ExecutionResult object with these attributes:
    # - result.success: bool - whether the agent succeeded
    # - result.output: dict - the agent's output data (access data here!)
    # - result.error: str or None - error message if failed
    await auto_responder.start()
    try:
        result = await runner.run({{"key": "value"}})
    finally:
        await auto_responder.stop()

    assert result.success, f"Agent failed: {{result.error}}"

    # Access output data via result.output
    output_data = result.output or {{}}

    # Add success criteria-specific assertions here
@@ -3105,7 +3111,6 @@ def generate_constraint_tests(
        test_type="Constraint",
        agent_name=agent_module,
        description=f"Tests for constraints defined in goal: {goal.name}",
        agent_module=agent_module,
    )

    # Return guidelines + data for Claude to write tests directly
@@ -3121,14 +3126,22 @@ def generate_constraint_tests(
        "max_tests": 5,
        "naming_convention": "test_constraint_<constraint_id>_<scenario>",
        "required_decorator": "@pytest.mark.asyncio",
        "required_fixture": "mock_mode",
        "agent_call_pattern": "await default_agent.run(input_dict, mock_mode=mock_mode)",
        "required_fixtures": "runner, auto_responder, mock_mode",
        "agent_call_pattern": "await runner.run(input_dict)",
        "auto_responder_pattern": (
            "await auto_responder.start()\n"
            "try:\n"
            "    result = await runner.run(input_dict)\n"
            "finally:\n"
            "    await auto_responder.stop()"
        ),
        "result_type": "ExecutionResult with .success, .output (dict), .error",
        "critical_rules": [
            "Every test function MUST be async with @pytest.mark.asyncio",
            "Every test MUST accept mock_mode as a parameter",
            "Use await default_agent.run(input, mock_mode=mock_mode)",
            "default_agent is already imported - do NOT add imports",
            "Every test MUST accept runner, auto_responder, and mock_mode fixtures",
            "Use await runner.run(input) -- NOT default_agent.run()",
            "Start auto_responder before running, stop in finally block",
            "runner and auto_responder are from conftest.py -- do NOT import them",
            "NEVER call result.get() - use result.output.get() instead",
            "Always check result.success before accessing result.output",
        ],
@@ -3192,7 +3205,6 @@ def generate_success_tests(
        test_type="Success criteria",
        agent_name=agent_module,
        description=f"Tests for success criteria defined in goal: {goal.name}",
        agent_module=agent_module,
    )

    # Return guidelines + data for Claude to write tests directly
@@ -3214,14 +3226,22 @@ def generate_success_tests(
        "max_tests": 12,
        "naming_convention": "test_success_<criteria_id>_<scenario>",
        "required_decorator": "@pytest.mark.asyncio",
        "required_fixture": "mock_mode",
        "agent_call_pattern": "await default_agent.run(input_dict, mock_mode=mock_mode)",
        "required_fixtures": "runner, auto_responder, mock_mode",
        "agent_call_pattern": "await runner.run(input_dict)",
        "auto_responder_pattern": (
            "await auto_responder.start()\n"
            "try:\n"
            "    result = await runner.run(input_dict)\n"
            "finally:\n"
            "    await auto_responder.stop()"
        ),
        "result_type": "ExecutionResult with .success, .output (dict), .error",
        "critical_rules": [
            "Every test function MUST be async with @pytest.mark.asyncio",
            "Every test MUST accept mock_mode as a parameter",
            "Use await default_agent.run(input, mock_mode=mock_mode)",
            "default_agent is already imported - do NOT add imports",
            "Every test MUST accept runner, auto_responder, and mock_mode fixtures",
            "Use await runner.run(input) -- NOT default_agent.run()",
            "Start auto_responder before running, stop in finally block",
            "runner and auto_responder are from conftest.py -- do NOT import them",
            "NEVER call result.get() - use result.output.get() instead",
            "Always check result.success before accessing result.output",
        ],
@@ -3318,11 +3338,13 @@ def run_tests(
    # Add short traceback and quiet summary
    cmd.append("--tb=short")

    # Set PYTHONPATH to project root so agents can import from core.framework
    # Set PYTHONPATH so framework and agent packages are importable
    env = os.environ.copy()
    pythonpath = env.get("PYTHONPATH", "")
    project_root = Path(__file__).parent.parent.parent.parent.resolve()
    env["PYTHONPATH"] = f"{project_root}:{pythonpath}"
    core_path = project_root / "core"
    exports_path = project_root / "exports"
    env["PYTHONPATH"] = f"{core_path}:{exports_path}:{project_root}:{pythonpath}"

    # Run pytest
    try:
@@ -3792,7 +3814,11 @@ def check_missing_credentials(

        from framework.runner import AgentRunner

        runner = AgentRunner.load(agent_path)
        path, err = _validate_agent_path(agent_path)
        if err:
            return err

        runner = AgentRunner.load(str(path))
        runner.validate()

        store = _get_credential_store()
@@ -3992,7 +4018,11 @@ def verify_credentials(
    try:
        from framework.runner import AgentRunner

        runner = AgentRunner.load(agent_path)
        path, err = _validate_agent_path(agent_path)
        if err:
            return err

        runner = AgentRunner.load(str(path))
        validation = runner.validate()

        return json.dumps(
@@ -4009,6 +4039,382 @@ def verify_credentials(
        return json.dumps({"error": str(e)})


# =============================================================================
# SESSION & CHECKPOINT TOOLS (read-only, no build session required)
# =============================================================================

_MAX_DIFF_VALUE_LEN = 500


def _read_session_json(path: Path) -> dict | None:
    """Read a JSON file, returning None on failure."""
    if not path.exists():
        return None
    try:
        return json.loads(path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError):
        return None


def _scan_agent_sessions(agent_work_dir: Path) -> list[tuple[str, Path]]:
    """Find session directories with state.json, sorted most-recent-first."""
    sessions: list[tuple[str, Path]] = []
    sessions_dir = agent_work_dir / "sessions"
    if not sessions_dir.exists():
        return sessions
    for session_dir in sessions_dir.iterdir():
        if session_dir.is_dir() and session_dir.name.startswith("session_"):
            state_path = session_dir / "state.json"
            if state_path.exists():
                sessions.append((session_dir.name, state_path))
    sessions.sort(key=lambda t: t[0], reverse=True)
    return sessions


def _truncate_value(value: object, max_len: int = _MAX_DIFF_VALUE_LEN) -> object:
    """Truncate a value's JSON representation if too long."""
    s = json.dumps(value, default=str)
    if len(s) <= max_len:
        return value
    return {"_truncated": True, "_preview": s[:max_len] + "...", "_length": len(s)}

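# Illustrative behavior of _truncate_value (values made up): short values pass
# through unchanged; anything whose JSON form exceeds 500 characters is
# replaced by a preview stub so checkpoint diffs stay readable.
#
#     _truncate_value({"a": 1})
#     # -> {"a": 1}
#     _truncate_value("x" * 1000)
#     # -> {"_truncated": True, "_preview": <first 500 chars> + "...", "_length": 1002}
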
@mcp.tool()
def list_agent_sessions(
    agent_work_dir: Annotated[
        str,
        "Path to the agent's working directory (e.g., ~/.hive/agents/my_agent)",
    ],
    status: Annotated[
        str,
        "Filter by status: 'active', 'paused', 'completed', 'failed', 'cancelled'. Empty for all.",
    ] = "",
    limit: Annotated[int, "Maximum number of results (default 20)"] = 20,
    offset: Annotated[int, "Number of sessions to skip for pagination"] = 0,
) -> str:
    """
    List sessions for an agent with optional status filter.

    Use this to discover which sessions exist, find resumable sessions,
    or identify failed sessions for debugging. Combines well with
    query_runtime_logs for correlating session state with log data.
    """
    work_dir = Path(agent_work_dir)
    all_sessions = _scan_agent_sessions(work_dir)

    if not all_sessions:
        return json.dumps({"sessions": [], "total": 0, "offset": offset, "limit": limit})

    summaries = []
    for session_id, state_path in all_sessions:
        data = _read_session_json(state_path)
        if data is None:
            continue

        session_status = data.get("status", "")
        if status and session_status != status:
            continue

        timestamps = data.get("timestamps", {})
        progress = data.get("progress", {})
        checkpoint_dir = state_path.parent / "checkpoints"

        summaries.append(
            {
                "session_id": session_id,
                "status": session_status,
                "goal_id": data.get("goal_id", ""),
                "started_at": timestamps.get("started_at", ""),
                "updated_at": timestamps.get("updated_at", ""),
                "completed_at": timestamps.get("completed_at"),
                "is_resumable": data.get("is_resumable", False),
                "is_resumable_from_checkpoint": data.get("is_resumable_from_checkpoint", False),
                "current_node": progress.get("current_node"),
                "paused_at": progress.get("paused_at"),
                "steps_executed": progress.get("steps_executed", 0),
                "execution_quality": progress.get("execution_quality", ""),
                "has_checkpoints": checkpoint_dir.exists()
                and any(checkpoint_dir.glob("cp_*.json")),
            }
        )

    total = len(summaries)
    page = summaries[offset : offset + limit]
    return json.dumps(
        {"sessions": page, "total": total, "offset": offset, "limit": limit}, indent=2
    )

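# Example call (work dir and field values illustrative): MCP tools return JSON
# strings, so callers typically parse the result before use.
#
#     raw = list_agent_sessions("~/.hive/agents/my_agent", status="paused", limit=5)
#     for s in json.loads(raw)["sessions"]:
#         print(s["session_id"], s["status"], s["is_resumable"])
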
@mcp.tool()
def get_agent_session_state(
    agent_work_dir: Annotated[str, "Path to the agent's working directory"],
    session_id: Annotated[str, "The session ID (e.g., 'session_20260208_143022_abc12345')"],
) -> str:
    """
    Load full session state for a specific session.

    Returns complete session data including status, progress, result,
    metrics, and checkpoint info. Memory values are excluded to prevent
    context bloat -- use get_agent_session_memory to retrieve memory contents.
    """
    state_path = Path(agent_work_dir) / "sessions" / session_id / "state.json"
    data = _read_session_json(state_path)
    if data is None:
        return json.dumps({"error": f"Session not found: {session_id}"})

    memory = data.get("memory", {})
    data["memory_keys"] = list(memory.keys()) if isinstance(memory, dict) else []
    data["memory_size"] = len(memory) if isinstance(memory, dict) else 0
    data.pop("memory", None)

    return json.dumps(data, indent=2, default=str)


@mcp.tool()
def get_agent_session_memory(
    agent_work_dir: Annotated[str, "Path to the agent's working directory"],
    session_id: Annotated[str, "The session ID"],
    key: Annotated[str, "Specific memory key to retrieve. Empty for all."] = "",
) -> str:
    """
    Get memory contents from a session.

    Memory stores intermediate results passed between nodes. Use this
    to inspect what data was produced during execution.

    If key is provided, returns only that memory key's value.
    If key is empty, returns all memory keys and their values.
    """
    state_path = Path(agent_work_dir) / "sessions" / session_id / "state.json"
    data = _read_session_json(state_path)
    if data is None:
        return json.dumps({"error": f"Session not found: {session_id}"})

    memory = data.get("memory", {})
    if not isinstance(memory, dict):
        memory = {}

    if key:
        if key not in memory:
            return json.dumps(
                {
                    "error": f"Memory key not found: '{key}'",
                    "available_keys": list(memory.keys()),
                }
            )
        value = memory[key]
        return json.dumps(
            {
                "session_id": session_id,
                "key": key,
                "value": value,
                "value_type": type(value).__name__,
            },
            indent=2,
            default=str,
        )

    return json.dumps(
        {"session_id": session_id, "memory": memory, "total_keys": len(memory)},
        indent=2,
        default=str,
    )


@mcp.tool()
def list_agent_checkpoints(
    agent_work_dir: Annotated[str, "Path to the agent's working directory"],
    session_id: Annotated[str, "The session ID to list checkpoints for"],
    checkpoint_type: Annotated[
        str,
        "Filter by type: 'node_start', 'node_complete', 'loop_iteration'. Empty for all.",
    ] = "",
    is_clean: Annotated[str, "Filter by clean status: 'true', 'false', or empty for all."] = "",
) -> str:
    """
    List checkpoints for a specific session.

    Checkpoints capture execution state at node boundaries for
    crash recovery and resume. Use with get_agent_checkpoint for
    detailed checkpoint inspection.
    """
    session_dir = Path(agent_work_dir) / "sessions" / session_id
    checkpoint_dir = session_dir / "checkpoints"

    if not session_dir.exists():
        return json.dumps({"error": f"Session not found: {session_id}"})

    if not checkpoint_dir.exists():
        return json.dumps(
            {
                "session_id": session_id,
                "checkpoints": [],
                "total": 0,
                "latest_checkpoint_id": None,
            }
        )

    # Try index.json first
    index_data = _read_session_json(checkpoint_dir / "index.json")
    if index_data and "checkpoints" in index_data:
        checkpoints = index_data["checkpoints"]
    else:
        # Fallback: scan individual checkpoint files
        checkpoints = []
        for cp_file in sorted(checkpoint_dir.glob("cp_*.json")):
            cp_data = _read_session_json(cp_file)
            if cp_data:
                checkpoints.append(
                    {
                        "checkpoint_id": cp_data.get("checkpoint_id", cp_file.stem),
                        "checkpoint_type": cp_data.get("checkpoint_type", ""),
                        "created_at": cp_data.get("created_at", ""),
                        "current_node": cp_data.get("current_node"),
                        "next_node": cp_data.get("next_node"),
                        "is_clean": cp_data.get("is_clean", True),
                        "description": cp_data.get("description", ""),
                    }
                )

    # Apply filters
    if checkpoint_type:
        checkpoints = [c for c in checkpoints if c.get("checkpoint_type") == checkpoint_type]
    if is_clean:
        clean_val = is_clean.lower() == "true"
        checkpoints = [c for c in checkpoints if c.get("is_clean") == clean_val]

    latest_id = None
    if index_data:
        latest_id = index_data.get("latest_checkpoint_id")
    elif checkpoints:
        latest_id = checkpoints[-1].get("checkpoint_id")

    return json.dumps(
        {
            "session_id": session_id,
            "checkpoints": checkpoints,
            "total": len(checkpoints),
            "latest_checkpoint_id": latest_id,
        },
        indent=2,
    )


@mcp.tool()
def get_agent_checkpoint(
    agent_work_dir: Annotated[str, "Path to the agent's working directory"],
    session_id: Annotated[str, "The session ID"],
    checkpoint_id: Annotated[str, "Specific checkpoint ID, or empty for latest"] = "",
) -> str:
    """
    Load a specific checkpoint with full state data.

    Returns the complete checkpoint including shared memory snapshot,
    execution path, accumulated outputs, and metrics. If checkpoint_id
    is empty, loads the latest checkpoint.
    """
    session_dir = Path(agent_work_dir) / "sessions" / session_id
    checkpoint_dir = session_dir / "checkpoints"

    if not checkpoint_dir.exists():
        return json.dumps({"error": f"No checkpoints found for session: {session_id}"})

    if not checkpoint_id:
        index_data = _read_session_json(checkpoint_dir / "index.json")
        if index_data and index_data.get("latest_checkpoint_id"):
            checkpoint_id = index_data["latest_checkpoint_id"]
        else:
            cp_files = sorted(checkpoint_dir.glob("cp_*.json"))
            if not cp_files:
                return json.dumps({"error": f"No checkpoints found for session: {session_id}"})
            checkpoint_id = cp_files[-1].stem

    cp_path = checkpoint_dir / f"{checkpoint_id}.json"
    data = _read_session_json(cp_path)
    if data is None:
        return json.dumps({"error": f"Checkpoint not found: {checkpoint_id}"})

    return json.dumps(data, indent=2, default=str)


@mcp.tool()
def compare_agent_checkpoints(
    agent_work_dir: Annotated[str, "Path to the agent's working directory"],
    session_id: Annotated[str, "The session ID"],
    checkpoint_id_before: Annotated[str, "The earlier checkpoint ID"],
    checkpoint_id_after: Annotated[str, "The later checkpoint ID"],
) -> str:
    """
    Compare memory state between two checkpoints.

    Shows what memory keys were added, removed, or changed between
    two points in execution. Useful for understanding how data flows
    through the agent graph.
    """
    checkpoint_dir = Path(agent_work_dir) / "sessions" / session_id / "checkpoints"

    before = _read_session_json(checkpoint_dir / f"{checkpoint_id_before}.json")
    if before is None:
        return json.dumps({"error": f"Checkpoint not found: {checkpoint_id_before}"})

    after = _read_session_json(checkpoint_dir / f"{checkpoint_id_after}.json")
    if after is None:
        return json.dumps({"error": f"Checkpoint not found: {checkpoint_id_after}"})

    mem_before = before.get("shared_memory", {})
    mem_after = after.get("shared_memory", {})

    keys_before = set(mem_before.keys())
    keys_after = set(mem_after.keys())

    added = {k: _truncate_value(mem_after[k]) for k in keys_after - keys_before}
    removed = list(keys_before - keys_after)
    unchanged = []
    changed = {}

    for k in keys_before & keys_after:
        if mem_before[k] == mem_after[k]:
            unchanged.append(k)
        else:
            changed[k] = {
                "before": _truncate_value(mem_before[k]),
                "after": _truncate_value(mem_after[k]),
            }

    path_before = before.get("execution_path", [])
    path_after = after.get("execution_path", [])
    new_nodes = path_after[len(path_before) :]

    return json.dumps(
        {
            "session_id": session_id,
            "before": {
                "checkpoint_id": checkpoint_id_before,
                "current_node": before.get("current_node"),
                "created_at": before.get("created_at", ""),
            },
            "after": {
                "checkpoint_id": checkpoint_id_after,
                "current_node": after.get("current_node"),
                "created_at": after.get("created_at", ""),
            },
            "memory_diff": {
                "added": added,
                "removed": removed,
                "changed": changed,
                "unchanged": unchanged,
            },
            "execution_path_diff": {
                "new_nodes": new_nodes,
                "path_before": path_before,
                "path_after": path_after,
            },
        },
        indent=2,
        default=str,
    )

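# Sketch of the JSON this tool returns (keys are real; values are made up):
#
#     {
#       "memory_diff": {
#         "added": {"summary": "..."},
#         "removed": ["raw_html"],
#         "changed": {"retry_count": {"before": 0, "after": 2}},
#         "unchanged": ["user_id"]
#       },
#       "execution_path_diff": {"new_nodes": ["summarize", "store"]}
#     }
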
# =============================================================================
# MAIN
# =============================================================================

@@ -332,10 +332,67 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
    resume_parser.set_defaults(func=cmd_resume)


def _load_resume_state(
    agent_path: str, session_id: str, checkpoint_id: str | None = None
) -> dict | None:
    """Load session or checkpoint state for headless resume.

    Args:
        agent_path: Path to the agent folder (e.g., exports/my_agent)
        session_id: Session ID to resume from
        checkpoint_id: Optional checkpoint ID within the session

    Returns:
        session_state dict for executor, or None if not found
    """
    agent_name = Path(agent_path).name
    agent_work_dir = Path.home() / ".hive" / "agents" / agent_name
    session_dir = agent_work_dir / "sessions" / session_id

    if not session_dir.exists():
        return None

    if checkpoint_id:
        # Checkpoint-based resume: load checkpoint and extract state
        cp_path = session_dir / "checkpoints" / f"{checkpoint_id}.json"
        if not cp_path.exists():
            return None
        try:
            cp_data = json.loads(cp_path.read_text())
        except (json.JSONDecodeError, OSError):
            return None
        return {
            "resume_session_id": session_id,
            "memory": cp_data.get("shared_memory", {}),
            "paused_at": cp_data.get("next_node") or cp_data.get("current_node"),
            "execution_path": cp_data.get("execution_path", []),
            "node_visit_counts": {},
        }
    else:
        # Session state resume: load state.json
        state_path = session_dir / "state.json"
        if not state_path.exists():
            return None
        try:
            state_data = json.loads(state_path.read_text())
        except (json.JSONDecodeError, OSError):
            return None
        progress = state_data.get("progress", {})
        paused_at = progress.get("paused_at") or progress.get("resume_from")
        return {
            "resume_session_id": session_id,
            "memory": state_data.get("memory", {}),
            "paused_at": paused_at,
            "execution_path": progress.get("path", []),
            "node_visit_counts": progress.get("node_visit_counts", {}),
        }

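# The dict built above is what cmd_run() later passes to runner.run() as
# session_state. A sketch of its shape, with illustrative values:
#
#     {
#         "resume_session_id": "session_20260208_143022_abc12345",
#         "memory": {"user_id": "u-1"},
#         "paused_at": "await_approval",
#         "execution_path": ["intake", "await_approval"],
#         "node_visit_counts": {},
#     }
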
def cmd_run(args: argparse.Namespace) -> int:
    """Run an exported agent."""
    import logging

    from framework.credentials.models import CredentialError
    from framework.runner import AgentRunner

    # Set logging level (quiet by default for cleaner output)
@@ -374,8 +431,10 @@ def cmd_run(args: argparse.Namespace) -> int:
            runner = AgentRunner.load(
                args.agent_path,
                model=args.model,
                enable_tui=True,
            )
        except CredentialError as e:
            print(f"\n{e}", file=sys.stderr)
            return
        except Exception as e:
            print(f"Error loading agent: {e}")
            return
@@ -415,12 +474,35 @@ def cmd_run(args: argparse.Namespace) -> int:
        runner = AgentRunner.load(
            args.agent_path,
            model=args.model,
            enable_tui=False,
        )
    except CredentialError as e:
        print(f"\n{e}", file=sys.stderr)
        return 1
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    # Load session/checkpoint state for resume (headless mode)
    session_state = None
    resume_session = getattr(args, "resume_session", None)
    checkpoint = getattr(args, "checkpoint", None)
    if resume_session:
        session_state = _load_resume_state(args.agent_path, resume_session, checkpoint)
        if session_state is None:
            print(
                f"Error: Could not load session state for {resume_session}",
                file=sys.stderr,
            )
            return 1
        if not args.quiet:
            resume_node = session_state.get("paused_at", "unknown")
            if checkpoint:
                print(f"Resuming from checkpoint: {checkpoint}")
            else:
                print(f"Resuming session: {resume_session}")
            print(f"Resume point: {resume_node}")
            print()

    # Auto-inject user_id if the agent expects it but it's not provided
    entry_input_keys = runner.graph.nodes[0].input_keys if runner.graph.nodes else []
    if "user_id" in entry_input_keys and context.get("user_id") is None:
@@ -440,7 +522,7 @@ def cmd_run(args: argparse.Namespace) -> int:
        print("=" * 60)
        print()

    result = asyncio.run(runner.run(context))
    result = asyncio.run(runner.run(context, session_state=session_state))

    # Format output
    output = {
@@ -520,10 +602,14 @@ def cmd_run(args: argparse.Namespace) -> int:

def cmd_info(args: argparse.Namespace) -> int:
    """Show agent information."""
    from framework.credentials.models import CredentialError
    from framework.runner import AgentRunner

    try:
        runner = AgentRunner.load(args.agent_path)
    except CredentialError as e:
        print(f"\n{e}", file=sys.stderr)
        return 1
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1
@@ -583,10 +669,14 @@ def cmd_info(args: argparse.Namespace) -> int:

def cmd_validate(args: argparse.Namespace) -> int:
    """Validate an exported agent."""
    from framework.credentials.models import CredentialError
    from framework.runner import AgentRunner

    try:
        runner = AgentRunner.load(args.agent_path)
    except CredentialError as e:
        print(f"\n{e}", file=sys.stderr)
        return 1
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1
@@ -903,6 +993,7 @@ def cmd_shell(args: argparse.Namespace) -> int:
    """Start an interactive agent session."""
    import logging

    from framework.credentials.models import CredentialError
    from framework.runner import AgentRunner

    # Configure logging to show runtime visibility
@@ -927,6 +1018,9 @@ def cmd_shell(args: argparse.Namespace) -> int:

    try:
        runner = AgentRunner.load(agent_path)
    except CredentialError as e:
        print(f"\n{e}", file=sys.stderr)
        return 1
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1
@@ -1136,6 +1230,7 @@ def cmd_tui(args: argparse.Namespace) -> int:
    """Browse agents and launch the interactive TUI dashboard."""
    import logging

    from framework.credentials.models import CredentialError
    from framework.runner import AgentRunner
    from framework.tui.app import AdenTUI

@@ -1185,8 +1280,10 @@ def cmd_tui(args: argparse.Namespace) -> int:
            runner = AgentRunner.load(
                agent_path,
                model=args.model,
                enable_tui=True,
            )
        except CredentialError as e:
            print(f"\n{e}", file=sys.stderr)
            return
        except Exception as e:
            print(f"Error loading agent: {e}")
            return
@@ -1352,6 +1449,7 @@ def _select_agent(agents_dir: Path) -> str | None:
    for path in agents_dir.iterdir():
        if _is_valid_agent_dir(path):
            agents.append(path)
    agents.sort(key=lambda p: p.name)

    if not agents:
        print(f"No agents found in {agents_dir}", file=sys.stderr)

@@ -183,8 +183,11 @@ class MCPClient:
        from mcp import ClientSession
        from mcp.client.stdio import stdio_client

        # Create persistent stdio client context
        self._stdio_context = stdio_client(server_params)
        # Create persistent stdio client context.
        # Redirect server stderr to devnull to prevent raw
        # output from leaking behind the TUI.
        devnull = open(os.devnull, "w")  # noqa: SIM115
        self._stdio_context = stdio_client(server_params, errlog=devnull)
        (
            self._read_stream,
            self._write_stream,

@@ -456,7 +456,7 @@ Respond with JSON only:
}}"""

        try:
            response = self._llm.complete(
            response = await self._llm.acomplete(
                messages=[{"role": "user", "content": prompt}],
                system="You are a request router. Respond with JSON only.",
                max_tokens=256,

@@ -9,6 +9,10 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any

from framework.config import get_hive_config, get_preferred_model
from framework.credentials.validation import (
    ensure_credential_key_env as _ensure_credential_key_env,
    validate_agent_credentials,
)
from framework.graph import Goal
from framework.graph.edge import (
    DEFAULT_MAX_TOKENS,
@@ -17,17 +21,13 @@ from framework.graph.edge import (
    EdgeSpec,
    GraphSpec,
)
from framework.graph.executor import ExecutionResult, GraphExecutor
from framework.graph.executor import ExecutionResult
from framework.graph.node import NodeSpec
from framework.llm.provider import LLMProvider, Tool
from framework.runner.tool_registry import ToolRegistry

# Multi-entry-point runtime imports
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
from framework.runtime.core import Runtime
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime
from framework.runtime.execution_stream import EntryPointSpec
from framework.runtime.runtime_log_store import RuntimeLogStore
from framework.runtime.runtime_logger import RuntimeLogger

if TYPE_CHECKING:
    from framework.runner.protocol import AgentMessage, CapabilityResponse
@@ -35,32 +35,6 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)


def _ensure_credential_key_env() -> None:
    """Load HIVE_CREDENTIAL_KEY from shell config if not already in environment.

    The setup-credentials skill writes the encryption key to ~/.zshrc or ~/.bashrc.
    If the user hasn't sourced their config in the current shell, this reads it
    directly so the runner (and any MCP subprocesses it spawns) can unlock the
    encrypted credential store.

    Only HIVE_CREDENTIAL_KEY is loaded this way — all other secrets (API keys, etc.)
    come from the credential store itself.
    """
    if os.environ.get("HIVE_CREDENTIAL_KEY"):
        return

    try:
        from aden_tools.credentials.shell_config import check_env_var_in_shell_config

        found, value = check_env_var_in_shell_config("HIVE_CREDENTIAL_KEY")
        if found and value:
            os.environ["HIVE_CREDENTIAL_KEY"] = value
            logger.debug("Loaded HIVE_CREDENTIAL_KEY from shell config")
    except ImportError:
        pass


CLAUDE_CREDENTIALS_FILE = Path.home() / ".claude" / ".credentials.json"


@@ -271,7 +245,8 @@ class AgentRunner:
        mock_mode: bool = False,
        storage_path: Path | None = None,
        model: str | None = None,
        enable_tui: bool = False,
        intro_message: str = "",
        runtime_config: "AgentRuntimeConfig | None" = None,
    ):
        """
        Initialize the runner (use AgentRunner.load() instead).
@@ -283,14 +258,16 @@ class AgentRunner:
            mock_mode: If True, use mock LLM responses
            storage_path: Path for runtime storage (defaults to temp)
            model: Model to use (reads from agent config or ~/.hive/configuration.json if None)
            enable_tui: If True, forces use of AgentRuntime with EventBus
            intro_message: Optional greeting shown to user on TUI load
            runtime_config: Optional AgentRuntimeConfig (webhook settings, etc.)
        """
        self.agent_path = agent_path
        self.graph = graph
        self.goal = goal
        self.mock_mode = mock_mode
        self.model = model or self._resolve_default_model()
        self.enable_tui = enable_tui
        self.intro_message = intro_message
        self.runtime_config = runtime_config

        # Set up storage
        if storage_path:
@@ -310,15 +287,17 @@ class AgentRunner:

        # Initialize components
        self._tool_registry = ToolRegistry()
        self._runtime: Runtime | None = None
        self._llm: LLMProvider | None = None
        self._executor: GraphExecutor | None = None
        self._approval_callback: Callable | None = None

        # Multi-entry-point support (AgentRuntime)
        # AgentRuntime — unified execution path for all agents
        self._agent_runtime: AgentRuntime | None = None
        self._uses_async_entry_points = self.graph.has_async_entry_points()

        # Validate credentials before spawning MCP servers.
        # Fails fast with actionable guidance — no MCP noise on screen.
        self._validate_credentials()

        # Auto-discover tools from tools.py
        tools_path = agent_path / "tools.py"
        if tools_path.exists():
@@ -329,6 +308,13 @@ class AgentRunner:
        if mcp_config_path.exists():
            self._load_mcp_servers_from_config(mcp_config_path)

    def _validate_credentials(self) -> None:
        """Check that required credentials are available before spawning MCP servers.

        Raises CredentialError with actionable guidance if any are missing.
        """
        validate_agent_credentials(self.graph.nodes)

    @staticmethod
    def _import_agent_module(agent_path: Path):
        """Import an agent package from its directory path.
@@ -372,7 +358,6 @@ class AgentRunner:
        mock_mode: bool = False,
        storage_path: Path | None = None,
        model: str | None = None,
        enable_tui: bool = False,
    ) -> "AgentRunner":
        """
        Load an agent from an export folder.
@@ -386,7 +371,6 @@ class AgentRunner:
            mock_mode: If True, use mock LLM responses
            storage_path: Path for runtime storage (defaults to ~/.hive/agents/{name})
            model: LLM model to use (reads from agent's default_config if None)
            enable_tui: If True, forces use of AgentRuntime with EventBus

        Returns:
            AgentRunner instance ready to run
@@ -420,19 +404,39 @@ class AgentRunner:
        hive_config = get_hive_config()
        max_tokens = hive_config.get("llm", {}).get("max_tokens", DEFAULT_MAX_TOKENS)

        # Read intro_message from agent metadata (shown on TUI load)
        agent_metadata = getattr(agent_module, "metadata", None)
        intro_message = ""
        if agent_metadata and hasattr(agent_metadata, "intro_message"):
            intro_message = agent_metadata.intro_message

        # Build GraphSpec from module-level variables
        graph = GraphSpec(
            id=f"{agent_path.name}-graph",
            goal_id=goal.id,
            version="1.0.0",
            entry_node=getattr(agent_module, "entry_node", nodes[0].id),
            entry_points=getattr(agent_module, "entry_points", {}),
            terminal_nodes=getattr(agent_module, "terminal_nodes", []),
            pause_nodes=getattr(agent_module, "pause_nodes", []),
            nodes=nodes,
            edges=edges,
            max_tokens=max_tokens,
        )
        graph_kwargs: dict = {
            "id": f"{agent_path.name}-graph",
            "goal_id": goal.id,
            "version": "1.0.0",
            "entry_node": getattr(agent_module, "entry_node", nodes[0].id),
            "entry_points": getattr(agent_module, "entry_points", {}),
            "async_entry_points": getattr(agent_module, "async_entry_points", []),
            "terminal_nodes": getattr(agent_module, "terminal_nodes", []),
            "pause_nodes": getattr(agent_module, "pause_nodes", []),
            "nodes": nodes,
            "edges": edges,
            "max_tokens": max_tokens,
            "loop_config": getattr(agent_module, "loop_config", {}),
        }
        # Only pass optional fields if explicitly defined by the agent module
        conversation_mode = getattr(agent_module, "conversation_mode", None)
        if conversation_mode is not None:
            graph_kwargs["conversation_mode"] = conversation_mode
        identity_prompt = getattr(agent_module, "identity_prompt", None)
        if identity_prompt is not None:
            graph_kwargs["identity_prompt"] = identity_prompt

        graph = GraphSpec(**graph_kwargs)

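        # Sketch of the module-level attributes an exported agent can define
        # and how they land in graph_kwargs above (attribute names are real;
        # the values are made up):
        #
        #     # exports/my_agent/__init__.py
        #     entry_node = "intake"
        #     async_entry_points = []          # optional event-driven entries
        #     terminal_nodes = ["done"]
        #     pause_nodes = ["await_approval"]
        #     conversation_mode = True         # forwarded only when defined
        #     identity_prompt = "You are ..."  # forwarded only when defined
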
# Read runtime config (webhook settings, etc.) if defined
|
||||
agent_runtime_config = getattr(agent_module, "runtime_config", None)
|
||||
|
||||
return cls(
|
||||
agent_path=agent_path,
|
||||
@@ -441,7 +445,8 @@ class AgentRunner:
|
||||
mock_mode=mock_mode,
|
||||
storage_path=storage_path,
|
||||
model=model,
|
||||
enable_tui=enable_tui,
|
||||
intro_message=intro_message,
|
||||
runtime_config=agent_runtime_config,
|
||||
)
|
||||
|
||||
# Fallback: load from agent.json (legacy JSON-based agents)
|
||||
@@ -459,7 +464,6 @@ class AgentRunner:
|
||||
mock_mode=mock_mode,
|
||||
storage_path=storage_path,
|
||||
model=model,
|
||||
enable_tui=enable_tui,
|
||||
)
|
||||
|
||||
def register_tool(
|
||||
@@ -549,9 +553,6 @@ class AgentRunner:
|
||||
callback: Function to call for approval (receives node info, returns bool)
|
||||
"""
|
||||
self._approval_callback = callback
|
||||
# If executor already exists, update it
|
||||
if self._executor is not None:
|
||||
self._executor.approval_callback = callback
|
||||
|
||||
def _setup(self) -> None:
|
||||
"""Set up runtime, LLM, and executor."""
|
||||
@@ -600,7 +601,10 @@ class AgentRunner:
|
||||
self._llm = LiteLLMProvider(model=self.model, api_key=api_key)
|
||||
else:
|
||||
# Fall back to environment variable
|
||||
api_key_env = self._get_api_key_env_var(self.model)
|
||||
# First check api_key_env_var from config (set by quickstart)
|
||||
api_key_env = llm_config.get("api_key_env_var") or self._get_api_key_env_var(
|
||||
self.model
|
||||
)
|
||||
if api_key_env and os.environ.get(api_key_env):
|
||||
self._llm = LiteLLMProvider(model=self.model)
|
||||
else:
|
||||
@@ -616,16 +620,11 @@ class AgentRunner:
|
||||
print(f"Warning: {api_key_env} not set. LLM calls will fail.")
|
||||
print(f"Set it with: export {api_key_env}=your-api-key")
|
||||
|
||||
# Get tools for executor/runtime
|
||||
# Get tools for runtime
|
||||
tools = list(self._tool_registry.get_tools().values())
|
||||
tool_executor = self._tool_registry.get_executor()
|
||||
|
||||
if self._uses_async_entry_points or self.enable_tui:
|
||||
# Multi-entry-point mode or TUI mode: use AgentRuntime
|
||||
self._setup_agent_runtime(tools, tool_executor)
|
||||
else:
|
||||
# Single-entry-point mode: use legacy GraphExecutor
|
||||
self._setup_legacy_executor(tools, tool_executor)
|
||||
self._setup_agent_runtime(tools, tool_executor)
|
||||
|
||||
def _get_api_key_env_var(self, model: str) -> str | None:
|
||||
"""Get the environment variable name for the API key based on model name."""
|
||||
@@ -640,7 +639,7 @@ class AgentRunner:
|
||||
elif model_lower.startswith("anthropic/") or model_lower.startswith("claude"):
|
||||
return "ANTHROPIC_API_KEY"
|
||||
elif model_lower.startswith("gemini/") or model_lower.startswith("google/"):
|
||||
return "GOOGLE_API_KEY"
|
||||
return "GEMINI_API_KEY"
|
||||
elif model_lower.startswith("mistral/"):
|
||||
return "MISTRAL_API_KEY"
|
||||
elif model_lower.startswith("groq/"):
|
||||
@@ -686,26 +685,6 @@ class AgentRunner:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _setup_legacy_executor(self, tools: list, tool_executor: Callable | None) -> None:
|
||||
"""Set up legacy single-entry-point execution using GraphExecutor."""
|
||||
# Create runtime
|
||||
self._runtime = Runtime(storage_path=self._storage_path)
|
||||
|
||||
# Create runtime logger
|
||||
log_store = RuntimeLogStore(base_path=self._storage_path / "runtime_logs")
|
||||
runtime_logger = RuntimeLogger(store=log_store, agent_id=self.graph.id)
|
||||
|
||||
# Create executor
|
||||
self._executor = GraphExecutor(
|
||||
runtime=self._runtime,
|
||||
llm=self._llm,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
approval_callback=self._approval_callback,
|
||||
runtime_logger=runtime_logger,
|
||||
loop_config=self.graph.loop_config,
|
||||
)
|
||||
|
||||
def _setup_agent_runtime(self, tools: list, tool_executor: Callable | None) -> None:
|
||||
"""Set up multi-entry-point execution using AgentRuntime."""
|
||||
# Convert AsyncEntryPointSpec to EntryPointSpec for AgentRuntime
|
||||
@@ -723,17 +702,19 @@ class AgentRunner:
|
||||
)
|
||||
entry_points.append(ep)
|
||||
|
||||
# If TUI enabled but no entry points (single-entry agent), create default
|
||||
if not entry_points and self.enable_tui and self.graph.entry_node:
|
        logger.info("Creating default entry point for TUI")
        entry_points.append(
        # Always create a primary entry point for the graph's entry node.
        # For multi-entry-point agents this ensures the primary path (e.g.
        # user-facing rule setup) is reachable alongside async entry points.
        if self.graph.entry_node:
            entry_points.insert(
                0,
                EntryPointSpec(
                    id="default",
                    name="Default",
                    entry_node=self.graph.entry_node,
                    trigger_type="manual",
                    isolation_level="shared",
                )
            ),
        )

        # Create AgentRuntime with all entry points
@@ -760,8 +741,12 @@ class AgentRunner:
            tool_executor=tool_executor,
            runtime_log_store=log_store,
            checkpoint_config=checkpoint_config,
            config=self.runtime_config,
        )

        # Pass intro_message through for TUI display
        self._agent_runtime.intro_message = self.intro_message

    async def run(
        self,
        input_data: dict | None = None,
@@ -801,32 +786,9 @@ class AgentRunner:
                error=error_msg,
            )

        if self._uses_async_entry_points or self.enable_tui:
            # Multi-entry-point mode: use AgentRuntime
            return await self._run_with_agent_runtime(
                input_data=input_data or {},
                entry_point_id=entry_point_id,
            )
        else:
            # Legacy single-entry-point mode
            return await self._run_with_executor(
                input_data=input_data or {},
                session_state=session_state,
            )

    async def _run_with_executor(
        self,
        input_data: dict,
        session_state: dict | None = None,
    ) -> ExecutionResult:
        """Run using legacy GraphExecutor (single entry point)."""
        if self._executor is None:
            self._setup()

        return await self._executor.execute(
            graph=self.graph,
            goal=self.goal,
            input_data=input_data,
        return await self._run_with_agent_runtime(
            input_data=input_data or {},
            entry_point_id=entry_point_id,
            session_state=session_state,
        )

@@ -834,8 +796,11 @@ class AgentRunner:
        self,
        input_data: dict,
        entry_point_id: str | None = None,
        session_state: dict | None = None,
    ) -> ExecutionResult:
        """Run using AgentRuntime (multi-entry-point)."""
        """Run using AgentRuntime."""
        import sys

        if self._agent_runtime is None:
            self._setup()

@@ -843,6 +808,52 @@ class AgentRunner:
        if not self._agent_runtime.is_running:
            await self._agent_runtime.start()

        # Set up stdin-based I/O for client-facing nodes in headless mode.
        # When a client_facing EventLoopNode calls ask_user(), it emits
        # CLIENT_INPUT_REQUESTED on the event bus and blocks. We subscribe
        # a handler that prints the prompt and reads from stdin, then injects
        # the user's response back into the node to unblock it.
        has_client_facing = any(n.client_facing for n in self.graph.nodes)
        sub_ids: list[str] = []

        if has_client_facing and sys.stdin.isatty():
            from framework.runtime.event_bus import EventType

            runtime = self._agent_runtime

            async def _handle_client_output(event):
                """Print agent output to stdout as it streams."""
                content = event.data.get("content", "")
                if content:
                    print(content, end="", flush=True)

            async def _handle_input_requested(event):
                """Read user input from stdin and inject it into the node."""
                import asyncio

                node_id = event.node_id
                try:
                    loop = asyncio.get_event_loop()
                    user_input = await loop.run_in_executor(None, input, "\n>>> ")
                except EOFError:
                    user_input = ""

                # Inject into the waiting EventLoopNode via runtime
                await runtime.inject_input(node_id, user_input)

            sub_ids.append(
                runtime.subscribe_to_events(
                    event_types=[EventType.CLIENT_OUTPUT_DELTA],
                    handler=_handle_client_output,
                )
            )
            sub_ids.append(
                runtime.subscribe_to_events(
                    event_types=[EventType.CLIENT_INPUT_REQUESTED],
                    handler=_handle_input_requested,
                )
            )

        # Determine entry point
        if entry_point_id is None:
            # Use first entry point or "default" if no entry points defined
@@ -852,44 +863,38 @@ class AgentRunner:
            else:
                entry_point_id = "default"

        # Trigger and wait for result
        result = await self._agent_runtime.trigger_and_wait(
            entry_point_id=entry_point_id,
            input_data=input_data,
        )

        # Return result or create error result
        if result is not None:
            return result
        else:
            return ExecutionResult(
                success=False,
                error="Execution timed out or failed to complete",
        try:
            # Trigger and wait for result
            result = await self._agent_runtime.trigger_and_wait(
                entry_point_id=entry_point_id,
                input_data=input_data,
                session_state=session_state,
            )

    # === Multi-Entry-Point API (for agents with async_entry_points) ===
            # Return result or create error result
            if result is not None:
                return result
            else:
                return ExecutionResult(
                    success=False,
                    error="Execution timed out or failed to complete",
                )
        finally:
            # Clean up subscriptions
            for sub_id in sub_ids:
                self._agent_runtime.unsubscribe_from_events(sub_id)

    # === Runtime API ===

    async def start(self) -> None:
        """
        Start the agent runtime (for multi-entry-point agents).

        This starts all registered entry points and allows concurrent execution.
        For single-entry-point agents, this is a no-op.
        """
        if not self._uses_async_entry_points:
            return

        """Start the agent runtime."""
        if self._agent_runtime is None:
            self._setup()

        await self._agent_runtime.start()

    async def stop(self) -> None:
        """
        Stop the agent runtime (for multi-entry-point agents).

        For single-entry-point agents, this is a no-op.
        """
        """Stop the agent runtime."""
        if self._agent_runtime is not None:
            await self._agent_runtime.stop()

@@ -902,7 +907,7 @@ class AgentRunner:
        """
        Trigger execution at a specific entry point (non-blocking).

        For multi-entry-point agents only. Returns execution ID for tracking.
        Returns execution ID for tracking.

        Args:
            entry_point_id: Which entry point to trigger
@@ -911,16 +916,7 @@ class AgentRunner:

        Returns:
            Execution ID for tracking

        Raises:
            RuntimeError: If agent doesn't use async entry points
        """
        if not self._uses_async_entry_points:
            raise RuntimeError(
                "trigger() is only available for multi-entry-point agents. "
                "Use run() for single-entry-point agents."
            )

        if self._agent_runtime is None:
            self._setup()

@@ -937,19 +933,9 @@ class AgentRunner:
        """
        Get goal progress across all execution streams.

        For multi-entry-point agents only.

        Returns:
            Dict with overall_progress, criteria_status, constraint_violations, etc.

        Raises:
            RuntimeError: If agent doesn't use async entry points
        """
        if not self._uses_async_entry_points:
            raise RuntimeError(
                "get_goal_progress() is only available for multi-entry-point agents."
            )

        if self._agent_runtime is None:
            self._setup()

@@ -957,14 +943,11 @@ class AgentRunner:

    def get_entry_points(self) -> list[EntryPointSpec]:
        """
        Get all registered entry points (for multi-entry-point agents).
        Get all registered entry points.

        Returns:
            List of EntryPointSpec objects
        """
        if not self._uses_async_entry_points:
            return []

        if self._agent_runtime is None:
            self._setup()

@@ -1244,7 +1227,7 @@ Respond with JSON only:
}}"""

        try:
            response = eval_llm.complete(
            response = await eval_llm.acomplete(
                messages=[{"role": "user", "content": prompt}],
                system="You are a capability evaluator. Respond with JSON only.",
                max_tokens=256,
@@ -1388,7 +1371,7 @@ Respond with JSON only:
        self._temp_dir = None

    async def cleanup_async(self) -> None:
        """Clean up resources (asynchronous - for multi-entry-point agents)."""
        """Clean up resources (asynchronous)."""
        # Stop agent runtime if running
        if self._agent_runtime is not None and self._agent_runtime.is_running:
            await self._agent_runtime.stop()
@@ -1399,8 +1382,7 @@ Respond with JSON only:
    async def __aenter__(self) -> "AgentRunner":
        """Context manager entry."""
        self._setup()
        # Start runtime for multi-entry-point agents
        if self._uses_async_entry_points and self._agent_runtime is not None:
        if self._agent_runtime is not None:
            await self._agent_runtime.start()
        return self


@@ -0,0 +1,172 @@
# Agent Runtime

Unified execution system for all Hive agents. Every agent — single-entry or multi-entry, headless or TUI — runs through the same runtime stack.

## Topology

```
   AgentRunner.load(agent_path)
               |
          AgentRunner
    (factory + public API)
               |
    _setup_agent_runtime()
               |
          AgentRuntime
  (lifecycle + orchestration)
          /    |    \
  Stream A  Stream B  Stream C   ← one per entry point
      |         |         |
GraphExecutor GraphExecutor GraphExecutor
      |         |         |
  Node → Node → Node  (graph traversal)
```

Single-entry agents get a `"default"` entry point automatically. There is no separate code path.
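
For reference, the implicit spec matches the runner code above (a sketch; `graph` is whatever graph the agent loaded):

```python
from framework.runtime.execution_stream import EntryPointSpec

# What AgentRunner inserts when the graph declares an entry node.
default_ep = EntryPointSpec(
    id="default",
    name="Default",
    entry_node=graph.entry_node,  # the graph's primary entry node
    trigger_type="manual",        # fires only via run()/trigger()
    isolation_level="shared",
)
```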

## Components

| Component | File | Role |
|---|---|---|
| `AgentRunner` | `runner/runner.py` | Load agents, configure tools/LLM, expose high-level API |
| `AgentRuntime` | `runtime/agent_runtime.py` | Lifecycle management, entry point routing, event bus |
| `ExecutionStream` | `runtime/execution_stream.py` | Per-entry-point execution queue, session persistence |
| `GraphExecutor` | `graph/executor.py` | Node traversal, tool dispatch, checkpointing |
| `EventBus` | `runtime/event_bus.py` | Pub/sub for execution events (streaming, I/O) |
| `SharedStateManager` | `runtime/shared_state.py` | Cross-stream state with isolation levels |
| `OutcomeAggregator` | `runtime/outcome_aggregator.py` | Goal progress tracking across streams |
| `SessionStore` | `storage/session_store.py` | Session state persistence (`sessions/{id}/state.json`) |

## Programming Interface

### AgentRunner (high-level)

```python
from framework.runner import AgentRunner

# Load and run
runner = AgentRunner.load("exports/my_agent", model="anthropic/claude-sonnet-4-20250514")
result = await runner.run({"query": "hello"})

# Resume from paused session
result = await runner.run({"query": "continue"}, session_state=saved_state)

# Lifecycle
await runner.start()                           # Start the runtime
await runner.stop()                            # Stop the runtime
exec_id = await runner.trigger("default", {})  # Non-blocking trigger
progress = await runner.get_goal_progress()    # Goal evaluation
entry_points = runner.get_entry_points()       # List entry points

# Context manager
async with AgentRunner.load("exports/my_agent") as runner:
    result = await runner.run({"query": "hello"})

# Cleanup
runner.cleanup()              # Synchronous
await runner.cleanup_async()  # Asynchronous
```

### AgentRuntime (lower-level)

```python
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
from framework.runtime.execution_stream import EntryPointSpec

# Create runtime with entry points
runtime = create_agent_runtime(
    graph=graph,
    goal=goal,
    storage_path=Path("~/.hive/agents/my_agent"),
    entry_points=[
        EntryPointSpec(id="default", name="Default", entry_node="start", trigger_type="manual"),
    ],
    llm=llm,
    tools=tools,
    tool_executor=tool_executor,
    checkpoint_config=checkpoint_config,
)

# Lifecycle
await runtime.start()
await runtime.stop()

# Execution
exec_id = await runtime.trigger("default", {"query": "hello"})                # Non-blocking
result = await runtime.trigger_and_wait("default", {"query": "hello"})        # Blocking
result = await runtime.trigger_and_wait("default", {}, session_state=state)   # Resume

# Client-facing node I/O
await runtime.inject_input(node_id="chat", content="user response")

# Events
sub_id = runtime.subscribe_to_events(
    event_types=[EventType.CLIENT_OUTPUT_DELTA],
    handler=my_handler,
)
runtime.unsubscribe_from_events(sub_id)

# Inspection
runtime.is_running     # bool
runtime.event_bus      # EventBus
runtime.state_manager  # SharedStateManager
runtime.get_stats()    # Runtime statistics
```

## Execution Flow

1. `AgentRunner.run()` calls `AgentRuntime.trigger_and_wait()`
2. `AgentRuntime` routes to the `ExecutionStream` for the entry point
3. `ExecutionStream` creates a `GraphExecutor` and calls `execute()`
4. `GraphExecutor` traverses nodes, dispatches tools, manages checkpoints
5. `ExecutionResult` flows back up through the stack
6. `ExecutionStream` writes session state to disk
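
A minimal end-to-end sketch of this flow, assuming an exported agent at `exports/my_agent` (hypothetical path):

```python
import asyncio

from framework.runner import AgentRunner

async def main() -> None:
    # Steps 1-6 all happen inside run(): trigger_and_wait -> stream -> executor.
    async with AgentRunner.load("exports/my_agent") as runner:
        result = await runner.run({"query": "hello"})
        print(result.success, result.error)

asyncio.run(main())
```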

## Session Resume

All execution paths support session resume:

```python
# First run (agent pauses at a client-facing node)
result = await runner.run({"query": "start task"})
# result.paused_at = "review-node"
# result.session_state = {"memory": {...}, "paused_at": "review-node", ...}

# Resume
result = await runner.run({"input": "approved"}, session_state=result.session_state)
```

Session state flows: `AgentRunner.run()` → `AgentRuntime.trigger_and_wait()` → `ExecutionStream.execute()` → `GraphExecutor.execute()`.

Checkpoints are saved at node boundaries (`sessions/{id}/checkpoints/`) for crash recovery.
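
Agents that pause more than once can be driven in a loop; a sketch assuming `paused_at` is empty once execution finishes and `get_user_reply()` is a hypothetical input source:

```python
result = await runner.run({"query": "start task"})
while result.paused_at:
    reply = get_user_reply(result)  # hypothetical: collect the human's answer
    result = await runner.run({"input": reply}, session_state=result.session_state)
```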

## Event Bus

The `EventBus` provides real-time execution visibility:

| Event | When |
|---|---|
| `NODE_STARTED` | Node begins execution |
| `NODE_COMPLETED` | Node finishes |
| `TOOL_CALL_STARTED` | Tool invocation begins |
| `TOOL_CALL_COMPLETED` | Tool invocation finishes |
| `CLIENT_OUTPUT_DELTA` | Agent streams text to user |
| `CLIENT_INPUT_REQUESTED` | Agent needs user input |
| `EXECUTION_COMPLETED` | Full execution finishes |

In headless mode, `AgentRunner` subscribes to `CLIENT_OUTPUT_DELTA` and `CLIENT_INPUT_REQUESTED` to print output and read stdin. In TUI mode, `AdenTUI` subscribes to route events to UI widgets.
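
A custom consumer follows the same pattern (a sketch; `runtime` is an already-started `AgentRuntime`):

```python
from framework.runtime.event_bus import EventType

async def on_output(event):
    # CLIENT_OUTPUT_DELTA carries the streamed text in event.data["content"]
    print(event.data.get("content", ""), end="", flush=True)

sub_id = runtime.subscribe_to_events(
    event_types=[EventType.CLIENT_OUTPUT_DELTA],
    handler=on_output,
)
# ... later ...
runtime.unsubscribe_from_events(sub_id)
```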

## Storage Layout

```
~/.hive/agents/{agent_name}/
    sessions/
        session_YYYYMMDD_HHMMSS_{uuid}/
            state.json       # Session state (status, memory, progress)
            checkpoints/     # Node-boundary snapshots
            logs/
                summary.json     # Execution summary
                details.jsonl    # Detailed event log
                tool_logs.jsonl  # Tool call log
    runtime_logs/            # Cross-session runtime logs
```
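
A session's `state.json` can be inspected directly; a sketch assuming at least one session directory exists under a hypothetical agent name:

```python
import json
from pathlib import Path

sessions = Path("~/.hive/agents/my_agent/sessions").expanduser()
# Pick the most recently modified session directory.
latest = max(sessions.iterdir(), key=lambda p: p.stat().st_mtime)
state = json.loads((latest / "state.json").read_text(encoding="utf-8"))
print(state.get("status"), list(state.get("memory", {})))
```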
@@ -478,7 +478,7 @@ async def resume_session(

```bash
# List resumable sessions
hive sessions list --agent twitter_outreach --status failed
hive sessions list --agent deep_research_agent --status failed

# Show checkpoints for a session
hive sessions checkpoints session_20260208_143022_abc12345

@@ -224,7 +224,7 @@ Three MCP tools provide access to the logging system:

```python
query_runtime_logs(
    agent_work_dir: str,  # e.g., "~/.hive/agents/twitter_outreach"
    agent_work_dir: str,  # e.g., "~/.hive/agents/deep_research_agent"
    status: str = "",     # "needs_attention", "success", "failure", "degraded"
    limit: int = 20
) -> dict  # {"runs": [...], "total": int}
@@ -371,14 +371,14 @@ query_runtime_log_raw(agent_work_dir, run_id)
```python
# 1. Find problematic runs (L1)
result = query_runtime_logs(
    agent_work_dir="~/.hive/agents/twitter_outreach",
    agent_work_dir="~/.hive/agents/deep_research_agent",
    status="needs_attention"
)
run_id = result["runs"][0]["run_id"]

# 2. Identify failing nodes (L2)
details = query_runtime_log_details(
    agent_work_dir="~/.hive/agents/twitter_outreach",
    agent_work_dir="~/.hive/agents/deep_research_agent",
    run_id=run_id,
    needs_attention_only=True
)
@@ -386,7 +386,7 @@ problem_node = details["nodes"][0]["node_id"]

# 3. Analyze root cause (L3)
raw = query_runtime_log_raw(
    agent_work_dir="~/.hive/agents/twitter_outreach",
    agent_work_dir="~/.hive/agents/deep_research_agent",
    run_id=run_id,
    node_id=problem_node
)
@@ -496,7 +496,7 @@ logger.start_run(goal_id, session_id=execution_id)
```json
{
  "run_id": "session_20260206_115718_e22339c5",
  "goal_id": "twitter-outreach-multi-loop",
  "goal_id": "deep-research",
  "status": "degraded",
  "started_at": "2026-02-06T11:57:18.593081",
  "ended_at": "2026-02-06T11:58:45.123456",

@@ -7,8 +7,9 @@ while preserving the goal-driven approach.

import asyncio
import logging
import time
from collections.abc import Callable
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any

@@ -39,6 +40,11 @@ class AgentRuntimeConfig:
    max_history: int = 1000
    execution_result_max: int = 1000
    execution_result_ttl_seconds: float | None = None
    # Webhook server config (only starts if webhook_routes is non-empty)
    webhook_host: str = "127.0.0.1"
    webhook_port: int = 8080
    webhook_routes: list[dict] = field(default_factory=list)
    # Each dict: {"source_id": str, "path": str, "methods": ["POST"], "secret": str|None}
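    # A sketch of one verified route (hypothetical values):
    #   AgentRuntimeConfig(
    #       webhook_routes=[
    #           {"source_id": "github", "path": "/webhooks/github",
    #            "methods": ["POST"], "secret": "shared-secret"},
    #       ],
    #   )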


class AgentRuntime:
@@ -150,10 +156,22 @@ class AgentRuntime:
        self._entry_points: dict[str, EntryPointSpec] = {}
        self._streams: dict[str, ExecutionStream] = {}

        # Webhook server (created on start if webhook_routes configured)
        self._webhook_server: Any = None
        # Event-driven entry point subscriptions
        self._event_subscriptions: list[str] = []
        # Timer tasks for scheduled entry points
        self._timer_tasks: list[asyncio.Task] = []
        # Next fire time for each timer entry point (ep_id -> monotonic timestamp)
        self._timer_next_fire: dict[str, float] = {}

        # State
        self._running = False
        self._lock = asyncio.Lock()

        # Optional greeting shown to user on TUI load (set by AgentRunner)
        self.intro_message: str = ""

    def register_entry_point(self, spec: EntryPointSpec) -> None:
        """
        Register a named entry point for the agent.
@@ -231,6 +249,131 @@ class AgentRuntime:
            await stream.start()
            self._streams[ep_id] = stream

        # Start webhook server if routes are configured
        if self._config.webhook_routes:
            from framework.runtime.webhook_server import (
                WebhookRoute,
                WebhookServer,
                WebhookServerConfig,
            )

            wh_config = WebhookServerConfig(
                host=self._config.webhook_host,
                port=self._config.webhook_port,
            )
            self._webhook_server = WebhookServer(self._event_bus, wh_config)

            for rc in self._config.webhook_routes:
                route = WebhookRoute(
                    source_id=rc["source_id"],
                    path=rc["path"],
                    methods=rc.get("methods", ["POST"]),
                    secret=rc.get("secret"),
                )
                self._webhook_server.add_route(route)

            await self._webhook_server.start()

        # Subscribe event-driven entry points to EventBus
        from framework.runtime.event_bus import EventType as _ET

        for ep_id, spec in self._entry_points.items():
            if spec.trigger_type != "event":
                continue

            tc = spec.trigger_config
            event_types = [_ET(et) for et in tc.get("event_types", [])]
            if not event_types:
                logger.warning(
                    f"Entry point '{ep_id}' has trigger_type='event' "
                    "but no event_types in trigger_config"
                )
                continue

            # Capture ep_id in closure
            def _make_handler(entry_point_id: str):
                async def _on_event(event):
                    if self._running and entry_point_id in self._streams:
                        # Run in the same session as the primary entry
                        # point so memory (e.g. user-defined rules) is
                        # shared and logs land in one session directory.
                        session_state = self._get_primary_session_state(
                            exclude_entry_point=entry_point_id
                        )
                        await self.trigger(
                            entry_point_id,
                            {"event": event.to_dict()},
                            session_state=session_state,
                        )

                return _on_event

            sub_id = self._event_bus.subscribe(
                event_types=event_types,
                handler=_make_handler(ep_id),
                filter_stream=tc.get("filter_stream"),
                filter_node=tc.get("filter_node"),
            )
            self._event_subscriptions.append(sub_id)

        # Start timer-driven entry points
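        # A sketch of a timer entry point's trigger_config (hypothetical values):
        #   EntryPointSpec(
        #       id="digest", name="Digest", entry_node="summarize",
        #       trigger_type="timer",
        #       trigger_config={"interval_minutes": 30, "run_immediately": True},
        #   )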
        for ep_id, spec in self._entry_points.items():
            if spec.trigger_type != "timer":
                continue

            tc = spec.trigger_config
            interval = tc.get("interval_minutes")
            if not interval or interval <= 0:
                logger.warning(
                    f"Entry point '{ep_id}' has trigger_type='timer' "
                    "but no valid interval_minutes in trigger_config"
                )
                continue

            run_immediately = tc.get("run_immediately", False)

            def _make_timer(entry_point_id: str, mins: float, immediate: bool):
                async def _timer_loop():
                    interval_secs = mins * 60
                    if not immediate:
                        self._timer_next_fire[entry_point_id] = time.monotonic() + interval_secs
                        await asyncio.sleep(interval_secs)
                    while self._running:
                        self._timer_next_fire.pop(entry_point_id, None)
                        try:
                            session_state = self._get_primary_session_state(
                                exclude_entry_point=entry_point_id
                            )
                            await self.trigger(
                                entry_point_id,
                                {"event": {"source": "timer", "reason": "scheduled"}},
                                session_state=session_state,
                            )
                            logger.info(
                                "Timer fired for entry point '%s' (next in %s min)",
                                entry_point_id,
                                mins,
                            )
                        except Exception:
                            logger.error(
                                "Timer trigger failed for '%s'",
                                entry_point_id,
                                exc_info=True,
                            )
                        self._timer_next_fire[entry_point_id] = time.monotonic() + interval_secs
                        await asyncio.sleep(interval_secs)

                return _timer_loop

            task = asyncio.create_task(_make_timer(ep_id, interval, run_immediately)())
            self._timer_tasks.append(task)
            logger.info(
                "Started timer for entry point '%s' every %s min%s",
                ep_id,
                interval,
                " (immediate first run)" if run_immediately else "",
            )

        self._running = True
        logger.info(f"AgentRuntime started with {len(self._streams)} streams")

@@ -240,6 +383,21 @@ class AgentRuntime:
            return

        async with self._lock:
            # Cancel timer tasks
            for task in self._timer_tasks:
                task.cancel()
            self._timer_tasks.clear()

            # Unsubscribe event-driven entry points
            for sub_id in self._event_subscriptions:
                self._event_bus.unsubscribe(sub_id)
            self._event_subscriptions.clear()

            # Stop webhook server
            if self._webhook_server:
                await self._webhook_server.stop()
                self._webhook_server = None

            # Stop all streams
            for stream in self._streams.values():
                await stream.stop()
@@ -311,6 +469,66 @@ class AgentRuntime:
            raise ValueError(f"Entry point '{entry_point_id}' not found")
        return await stream.wait_for_completion(exec_id, timeout)

    def _get_primary_session_state(self, exclude_entry_point: str) -> dict[str, Any] | None:
        """Build session_state so an async entry point runs in the primary session.

        Looks for an active execution from another stream (the "primary"
        session, e.g. the user-facing intake loop) and returns a
        ``session_state`` dict containing:

        - ``resume_session_id``: reuse the same session directory
        - ``memory``: only the keys that the async entry node declares
          as inputs (e.g. ``rules``, ``max_emails``). Stale outputs
          from previous runs (``emails``, ``actions_taken``, …) are
          excluded so each trigger starts fresh.

        The memory is read from the primary session's ``state.json``
        which is kept up-to-date by ``GraphExecutor._write_progress()``
        at every node transition.

        Returns ``None`` if no primary session is active (the webhook
        execution will just create its own session).
        """
        import json as _json

        # Determine which memory keys the async entry node needs.
        allowed_keys: set[str] | None = None
        ep_spec = self._entry_points.get(exclude_entry_point)
        if ep_spec:
            entry_node = self.graph.get_node(ep_spec.entry_node)
            if entry_node and entry_node.input_keys:
                allowed_keys = set(entry_node.input_keys)

        for ep_id, stream in self._streams.items():
            if ep_id == exclude_entry_point:
                continue
            for exec_id in stream.active_execution_ids:
                state_path = self._storage.base_path / "sessions" / exec_id / "state.json"
                try:
                    if state_path.exists():
                        data = _json.loads(state_path.read_text(encoding="utf-8"))
                        full_memory = data.get("memory", {})
                        if not full_memory:
                            continue
                        # Filter to only input keys so stale outputs
                        # from previous triggers don't leak through.
                        if allowed_keys is not None:
                            memory = {k: v for k, v in full_memory.items() if k in allowed_keys}
                        else:
                            memory = full_memory
                        if memory:
                            return {
                                "resume_session_id": exec_id,
                                "memory": memory,
                            }
                except Exception:
                    logger.debug(
                        "Could not read state.json for %s: skipping",
                        exec_id,
                        exc_info=True,
                    )
        return None
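    # Shape of a non-None return value (illustrative values only):
    #   {"resume_session_id": "session_20260206_115718_e22339c5",
    #    "memory": {"rules": [...], "max_emails": 25}}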

    async def inject_input(self, node_id: str, content: str) -> bool:
        """Inject user input into a running client-facing node.

@@ -445,6 +663,11 @@ class AgentRuntime:
        """Access the outcome aggregator."""
        return self._outcome_aggregator

    @property
    def webhook_server(self) -> Any:
        """Access the webhook server (None if no webhook entry points)."""
        return self._webhook_server

    @property
    def is_running(self) -> bool:
        """Check if runtime is running."""

@@ -56,6 +56,12 @@ class Runtime:
    """

    def __init__(self, storage_path: str | Path):
        # Validate and create storage path if needed
        path = Path(storage_path) if isinstance(storage_path, str) else storage_path
        if not path.exists():
            logger.warning(f"Storage path does not exist, creating: {path}")
            path.mkdir(parents=True, exist_ok=True)

        self.storage = FileStorage(storage_path)
        self._current_run: Run | None = None
        self._current_node: str = "unknown"

@@ -62,6 +62,23 @@ class EventType(StrEnum):
    NODE_INTERNAL_OUTPUT = "node_internal_output"
    NODE_INPUT_BLOCKED = "node_input_blocked"
    NODE_STALLED = "node_stalled"
    NODE_TOOL_DOOM_LOOP = "node_tool_doom_loop"

    # Judge decisions
    JUDGE_VERDICT = "judge_verdict"

    # Output tracking
    OUTPUT_KEY_SET = "output_key_set"

    # Retry / edge tracking
    NODE_RETRY = "node_retry"
    EDGE_TRAVERSED = "edge_traversed"

    # Context management
    CONTEXT_COMPACTED = "context_compacted"

    # External triggers
    WEBHOOK_RECEIVED = "webhook_received"

    # Custom events
    CUSTOM = "custom"
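    # A CUSTOM event can drive an event-triggered entry point directly,
    # mirroring the runtime tests below (a sketch; names are illustrative):
    #   await runtime.event_bus.publish(
    #       AgentEvent(type=EventType.CUSTOM, stream_id="some-source", data={"key": "value"})
    #   )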
@@ -615,6 +632,24 @@ class EventBus:
            )
        )

    async def emit_tool_doom_loop(
        self,
        stream_id: str,
        node_id: str,
        description: str = "",
        execution_id: str | None = None,
    ) -> None:
        """Emit tool doom loop detection event."""
        await self.publish(
            AgentEvent(
                type=EventType.NODE_TOOL_DOOM_LOOP,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"description": description},
            )
        )

    async def emit_node_input_blocked(
        self,
        stream_id: str,
@@ -633,6 +668,158 @@ class EventBus:
            )
        )

    # === JUDGE / OUTPUT / RETRY / EDGE PUBLISHERS ===

    async def emit_judge_verdict(
        self,
        stream_id: str,
        node_id: str,
        action: str,
        feedback: str = "",
        judge_type: str = "implicit",
        iteration: int = 0,
        execution_id: str | None = None,
    ) -> None:
        """Emit judge verdict event."""
        await self.publish(
            AgentEvent(
                type=EventType.JUDGE_VERDICT,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={
                    "action": action,
                    "feedback": feedback,
                    "judge_type": judge_type,
                    "iteration": iteration,
                },
            )
        )

    async def emit_output_key_set(
        self,
        stream_id: str,
        node_id: str,
        key: str,
        execution_id: str | None = None,
    ) -> None:
        """Emit output key set event."""
        await self.publish(
            AgentEvent(
                type=EventType.OUTPUT_KEY_SET,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"key": key},
            )
        )

    async def emit_node_retry(
        self,
        stream_id: str,
        node_id: str,
        retry_count: int,
        max_retries: int,
        error: str = "",
        execution_id: str | None = None,
    ) -> None:
        """Emit node retry event."""
        await self.publish(
            AgentEvent(
                type=EventType.NODE_RETRY,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={
                    "retry_count": retry_count,
                    "max_retries": max_retries,
                    "error": error,
                },
            )
        )

    async def emit_edge_traversed(
        self,
        stream_id: str,
        source_node: str,
        target_node: str,
        edge_condition: str = "",
        execution_id: str | None = None,
    ) -> None:
        """Emit edge traversed event."""
        await self.publish(
            AgentEvent(
                type=EventType.EDGE_TRAVERSED,
                stream_id=stream_id,
                node_id=source_node,
                execution_id=execution_id,
                data={
                    "source_node": source_node,
                    "target_node": target_node,
                    "edge_condition": edge_condition,
                },
            )
        )

    async def emit_execution_paused(
        self,
        stream_id: str,
        node_id: str,
        reason: str = "",
        execution_id: str | None = None,
    ) -> None:
        """Emit execution paused event."""
        await self.publish(
            AgentEvent(
                type=EventType.EXECUTION_PAUSED,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"reason": reason},
            )
        )

    async def emit_execution_resumed(
        self,
        stream_id: str,
        node_id: str,
        execution_id: str | None = None,
    ) -> None:
        """Emit execution resumed event."""
        await self.publish(
            AgentEvent(
                type=EventType.EXECUTION_RESUMED,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={},
            )
        )

    async def emit_webhook_received(
        self,
        source_id: str,
        path: str,
        method: str,
        headers: dict[str, str],
        payload: dict[str, Any],
        query_params: dict[str, str] | None = None,
    ) -> None:
        """Emit webhook received event."""
        await self.publish(
            AgentEvent(
                type=EventType.WEBHOOK_RECEIVED,
                stream_id=source_id,
                data={
                    "path": path,
                    "method": method,
                    "headers": headers,
                    "payload": payload,
                    "query_params": query_params or {},
                },
            )
        )

    # === QUERY OPERATIONS ===

    def get_history(

@@ -196,6 +196,11 @@ class ExecutionStream:
            )
        )

    @property
    def active_execution_ids(self) -> list[str]:
        """Return IDs of all currently active executions."""
        return list(self._active_executions.keys())

    def _record_execution_result(self, execution_id: str, result: ExecutionResult) -> None:
        """Record a completed execution result with retention pruning."""
        self._execution_results[execution_id] = result
@@ -293,8 +298,13 @@ class ExecutionStream:
        if not self._running:
            raise RuntimeError(f"ExecutionStream '{self.stream_id}' is not running")

        # Generate execution ID using unified session format
        if self._session_store:
        # When resuming, reuse the original session ID so the execution
        # continues in the same session directory instead of creating a new one.
        resume_session_id = session_state.get("resume_session_id") if session_state else None

        if resume_session_id:
            execution_id = resume_session_id
        elif self._session_store:
            execution_id = self._session_store.generate_session_id()
        else:
            # Fallback to old format if SessionStore not available (shouldn't happen)
@@ -337,6 +347,11 @@ class ExecutionStream:
        """Run a single execution within the stream."""
        execution_id = ctx.id

        # When sharing a session with another entry point (resume_session_id),
        # skip writing initial/final session state — the primary execution
        # owns the state.json and _write_progress() keeps memory up-to-date.
        _is_shared_session = bool(ctx.session_state and ctx.session_state.get("resume_session_id"))

        # Acquire semaphore to limit concurrency
        async with self._semaphore:
            ctx.status = "running"
@@ -399,7 +414,8 @@ class ExecutionStream:
            self._active_executors[execution_id] = executor

            # Write initial session state
            await self._write_session_state(execution_id, ctx)
            if not _is_shared_session:
                await self._write_session_state(execution_id, ctx)

            # Create modified graph with entry point
            # We need to override the entry_node to use our entry point
@@ -433,8 +449,9 @@ class ExecutionStream:
            if result.paused_at:
                ctx.status = "paused"

            # Write final session state
            await self._write_session_state(execution_id, ctx, result=result)
            # Write final session state (skip for shared-session executions)
            if not _is_shared_session:
                await self._write_session_state(execution_id, ctx, result=result)

            # Emit completion/failure event
            if self._event_bus:
@@ -485,11 +502,14 @@ class ExecutionStream:
            # Store result with retention
            self._record_execution_result(execution_id, result)

            # Write session state
            if has_result and result.paused_at:
                await self._write_session_state(execution_id, ctx, result=result)
            else:
                await self._write_session_state(execution_id, ctx, error="Execution cancelled")
            # Write session state (skip for shared-session executions)
            if not _is_shared_session:
                if has_result and result.paused_at:
                    await self._write_session_state(execution_id, ctx, result=result)
                else:
                    await self._write_session_state(
                        execution_id, ctx, error="Execution cancelled"
                    )

            # Don't re-raise - we've handled it and saved state

@@ -506,8 +526,9 @@ class ExecutionStream:
                ),
            )

            # Write error session state
            await self._write_session_state(execution_id, ctx, error=str(e))
            # Write error session state (skip for shared-session executions)
            if not _is_shared_session:
                await self._write_session_state(execution_id, ctx, error=str(e))

            # End run with failure (for observability)
            try:
@@ -597,10 +618,22 @@ class ExecutionStream:
                entry_point=self.entry_spec.id,
            )
        else:
            # Create initial state
            from framework.schemas.session_state import SessionTimestamps
            # Create initial state — when resuming, preserve the previous
            # execution's progress so crashes don't lose track of state.
            from framework.schemas.session_state import (
                SessionProgress,
                SessionTimestamps,
            )

            now = datetime.now().isoformat()
            ss = ctx.session_state or {}
            progress = SessionProgress(
                current_node=ss.get("paused_at") or ss.get("resume_from"),
                paused_at=ss.get("paused_at"),
                resume_from=ss.get("paused_at") or ss.get("resume_from"),
                path=ss.get("execution_path", []),
                node_visit_counts=ss.get("node_visit_counts", {}),
            )
            state = SessionState(
                session_id=execution_id,
                stream_id=self.stream_id,
@@ -613,6 +646,8 @@ class ExecutionStream:
                started_at=ctx.started_at.isoformat(),
                updated_at=now,
                ),
                progress=progress,
                memory=ss.get("memory", {}),
                input_data=ctx.input_data,
            )

@@ -629,20 +664,35 @@ class ExecutionStream:
            logger.error(f"Failed to write state.json for {execution_id}: {e}")

    def _create_modified_graph(self) -> "GraphSpec":
        """Create a graph with the entry point overridden."""
        # Use the existing graph but override entry_node
        """Create a graph with the entry point overridden.

        Preserves the original graph's entry_points and async_entry_points
        so that validation correctly considers ALL entry nodes reachable.
        Each stream only executes from its own entry_node, but the full
        graph must validate with all entry points accounted for.
        """
        from framework.graph.edge import GraphSpec

        # Create a copy with modified entry node
        # Merge entry points: this stream's entry + original graph's primary
        # entry + any other entry points. This ensures all nodes are
        # reachable during validation even though this stream only starts
        # from self.entry_spec.entry_node.
        merged_entry_points = {
            "start": self.entry_spec.entry_node,
        }
        # Preserve the original graph's primary entry node
        if self.graph.entry_node:
            merged_entry_points["primary"] = self.graph.entry_node
        # Include any explicitly defined entry points from the graph
        merged_entry_points.update(self.graph.entry_points)

        return GraphSpec(
            id=self.graph.id,
            goal_id=self.graph.goal_id,
            version=self.graph.version,
            entry_node=self.entry_spec.entry_node,  # Use our entry point
            entry_points={
                "start": self.entry_spec.entry_node,
                **self.graph.entry_points,
            },
            entry_points=merged_entry_points,
            async_entry_points=self.graph.async_entry_points,
            terminal_nodes=self.graph.terminal_nodes,
            pause_nodes=self.graph.pause_nodes,
            nodes=self.graph.nodes,
@@ -651,6 +701,9 @@ class ExecutionStream:
            max_tokens=self.graph.max_tokens,
            max_steps=self.graph.max_steps,
            cleanup_llm_model=self.graph.cleanup_llm_model,
            loop_config=self.graph.loop_config,
            conversation_mode=self.graph.conversation_mode,
            identity_prompt=self.graph.identity_prompt,
        )

    async def wait_for_completion(

@@ -0,0 +1,717 @@
"""
Tests for WebhookServer and event-driven entry points.
"""

import asyncio
import hashlib
import hmac as hmac_mod
import json
import tempfile
from pathlib import Path
from unittest.mock import patch

import aiohttp
import pytest

from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
from framework.runtime.execution_stream import EntryPointSpec
from framework.runtime.webhook_server import (
    WebhookRoute,
    WebhookServer,
    WebhookServerConfig,
)


def _make_server(event_bus: EventBus, routes: list[WebhookRoute] | None = None):
    """Helper to create a WebhookServer with port=0 for OS-assigned port."""
    config = WebhookServerConfig(host="127.0.0.1", port=0)
    server = WebhookServer(event_bus, config)
    for route in routes or []:
        server.add_route(route)
    return server


def _base_url(server: WebhookServer) -> str:
    """Get the base URL for a running server."""
    return f"http://127.0.0.1:{server.port}"


class TestWebhookServerLifecycle:
    """Tests for server start/stop."""

    @pytest.mark.asyncio
    async def test_start_stop(self):
        bus = EventBus()
        server = _make_server(
            bus,
            [
                WebhookRoute(source_id="test", path="/webhooks/test", methods=["POST"]),
            ],
        )

        await server.start()
        assert server.is_running
        assert server.port is not None

        await server.stop()
        assert not server.is_running
        assert server.port is None

    @pytest.mark.asyncio
    async def test_no_routes_skips_start(self):
        bus = EventBus()
        server = _make_server(bus)  # no routes

        await server.start()
        assert not server.is_running

    @pytest.mark.asyncio
    async def test_stop_when_not_started(self):
        bus = EventBus()
        server = _make_server(bus)

        # Should be a no-op, not raise
        await server.stop()
        assert not server.is_running


class TestWebhookEventPublishing:
    """Tests for HTTP request -> EventBus event publishing."""

    @pytest.mark.asyncio
    async def test_post_publishes_webhook_received(self):
        bus = EventBus()
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)

        server = _make_server(
            bus,
            [
                WebhookRoute(source_id="gh", path="/webhooks/github", methods=["POST"]),
            ],
        )
        await server.start()

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{_base_url(server)}/webhooks/github",
                    json={"action": "opened", "number": 42},
                ) as resp:
                    assert resp.status == 202
                    body = await resp.json()
                    assert body["status"] == "accepted"

            # Give event bus time to dispatch
            await asyncio.sleep(0.05)

            assert len(received) == 1
            event = received[0]
            assert event.type == EventType.WEBHOOK_RECEIVED
            assert event.stream_id == "gh"
            assert event.data["path"] == "/webhooks/github"
            assert event.data["method"] == "POST"
            assert event.data["payload"] == {"action": "opened", "number": 42}
            assert isinstance(event.data["headers"], dict)
            assert event.data["query_params"] == {}
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_query_params_included(self):
        bus = EventBus()
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)

        server = _make_server(
            bus,
            [
                WebhookRoute(source_id="hook", path="/webhooks/hook", methods=["POST"]),
            ],
        )
        await server.start()

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{_base_url(server)}/webhooks/hook?source=test&v=2",
                    json={"data": "hello"},
                ) as resp:
                    assert resp.status == 202

            await asyncio.sleep(0.05)

            assert len(received) == 1
            assert received[0].data["query_params"] == {"source": "test", "v": "2"}
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_non_json_body(self):
        bus = EventBus()
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)

        server = _make_server(
            bus,
            [
                WebhookRoute(source_id="raw", path="/webhooks/raw", methods=["POST"]),
            ],
        )
        await server.start()

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{_base_url(server)}/webhooks/raw",
                    data=b"plain text body",
                    headers={"Content-Type": "text/plain"},
                ) as resp:
                    assert resp.status == 202

            await asyncio.sleep(0.05)

            assert len(received) == 1
            assert received[0].data["payload"] == {"raw_body": "plain text body"}
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_empty_body(self):
        bus = EventBus()
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)

        server = _make_server(
            bus,
            [
                WebhookRoute(source_id="empty", path="/webhooks/empty", methods=["POST"]),
            ],
        )
        await server.start()

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(f"{_base_url(server)}/webhooks/empty") as resp:
                    assert resp.status == 202

            await asyncio.sleep(0.05)

            assert len(received) == 1
            assert received[0].data["payload"] == {}
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_multiple_routes(self):
        bus = EventBus()
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)

        server = _make_server(
            bus,
            [
                WebhookRoute(source_id="a", path="/webhooks/a", methods=["POST"]),
                WebhookRoute(source_id="b", path="/webhooks/b", methods=["POST"]),
            ],
        )
        await server.start()

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{_base_url(server)}/webhooks/a", json={"from": "a"}
                ) as resp:
                    assert resp.status == 202

                async with session.post(
                    f"{_base_url(server)}/webhooks/b", json={"from": "b"}
                ) as resp:
                    assert resp.status == 202

            await asyncio.sleep(0.05)

            assert len(received) == 2
            stream_ids = {e.stream_id for e in received}
            assert stream_ids == {"a", "b"}
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_filter_stream_subscription(self):
        """Subscribers can filter by stream_id (source_id)."""
        bus = EventBus()
        a_events = []
        b_events = []

        async def handle_a(event):
            a_events.append(event)

        async def handle_b(event):
            b_events.append(event)

        bus.subscribe([EventType.WEBHOOK_RECEIVED], handle_a, filter_stream="a")
        bus.subscribe([EventType.WEBHOOK_RECEIVED], handle_b, filter_stream="b")

        server = _make_server(
            bus,
            [
                WebhookRoute(source_id="a", path="/webhooks/a", methods=["POST"]),
                WebhookRoute(source_id="b", path="/webhooks/b", methods=["POST"]),
            ],
        )
        await server.start()

        try:
            async with aiohttp.ClientSession() as session:
                await session.post(f"{_base_url(server)}/webhooks/a", json={"x": 1})
                await session.post(f"{_base_url(server)}/webhooks/b", json={"x": 2})

            await asyncio.sleep(0.05)

            assert len(a_events) == 1
            assert a_events[0].data["payload"] == {"x": 1}
            assert len(b_events) == 1
            assert b_events[0].data["payload"] == {"x": 2}
        finally:
            await server.stop()


class TestHMACVerification:
    """Tests for HMAC-SHA256 signature verification."""

    @pytest.mark.asyncio
    async def test_valid_signature_accepted(self):
        bus = EventBus()
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)

        secret = "test-secret-key"
        server = _make_server(
            bus,
            [
                WebhookRoute(
                    source_id="secure",
                    path="/webhooks/secure",
                    methods=["POST"],
                    secret=secret,
                ),
            ],
        )
        await server.start()

        try:
            body = json.dumps({"event": "push"}).encode()
            sig = hmac_mod.new(secret.encode(), body, hashlib.sha256).hexdigest()

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{_base_url(server)}/webhooks/secure",
                    data=body,
                    headers={
                        "Content-Type": "application/json",
                        "X-Hub-Signature-256": f"sha256={sig}",
                    },
                ) as resp:
                    assert resp.status == 202

            await asyncio.sleep(0.05)
            assert len(received) == 1
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_invalid_signature_rejected(self):
        bus = EventBus()
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)

        server = _make_server(
            bus,
            [
                WebhookRoute(
                    source_id="secure",
                    path="/webhooks/secure",
                    methods=["POST"],
                    secret="real-secret",
                ),
            ],
        )
        await server.start()

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{_base_url(server)}/webhooks/secure",
                    json={"event": "push"},
                    headers={"X-Hub-Signature-256": "sha256=invalidsignature"},
                ) as resp:
                    assert resp.status == 401

            await asyncio.sleep(0.05)
            assert len(received) == 0  # No event published
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_missing_signature_rejected(self):
        bus = EventBus()
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)

        server = _make_server(
            bus,
            [
                WebhookRoute(
                    source_id="secure",
                    path="/webhooks/secure",
                    methods=["POST"],
                    secret="my-secret",
                ),
            ],
        )
        await server.start()

        try:
            async with aiohttp.ClientSession() as session:
                # No X-Hub-Signature-256 header
                async with session.post(
                    f"{_base_url(server)}/webhooks/secure",
                    json={"event": "push"},
                ) as resp:
                    assert resp.status == 401

            await asyncio.sleep(0.05)
            assert len(received) == 0
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_no_secret_skips_verification(self):
        """Routes without a secret accept any request."""
        bus = EventBus()
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)

        server = _make_server(
            bus,
            [
                WebhookRoute(
                    source_id="open",
                    path="/webhooks/open",
                    methods=["POST"],
                    secret=None,
                ),
            ],
        )
        await server.start()

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{_base_url(server)}/webhooks/open",
                    json={"data": "test"},
                ) as resp:
                    assert resp.status == 202

            await asyncio.sleep(0.05)
            assert len(received) == 1
        finally:
            await server.stop()


class TestEventDrivenEntryPoints:
    """Tests for event-driven entry points wired through AgentRuntime."""

    def _make_graph_and_goal(self):
        """Minimal graph + goal for testing entry point triggering."""
        from framework.graph import Goal
        from framework.graph.edge import GraphSpec
        from framework.graph.goal import SuccessCriterion
        from framework.graph.node import NodeSpec

        nodes = [
            NodeSpec(
                id="process-event",
                name="Process Event",
                description="Process incoming event",
                node_type="llm_generate",
                input_keys=["event"],
                output_keys=["result"],
            ),
        ]
        graph = GraphSpec(
            id="test-graph",
            goal_id="test-goal",
            version="1.0.0",
            entry_node="process-event",
            entry_points={"start": "process-event"},
            async_entry_points=[],
            terminal_nodes=[],
            pause_nodes=[],
            nodes=nodes,
            edges=[],
        )
        goal = Goal(
            id="test-goal",
            name="Test Goal",
            description="Test",
            success_criteria=[
                SuccessCriterion(
                    id="sc-1",
                    description="Done",
                    metric="done",
                    target="yes",
                    weight=1.0,
                ),
            ],
        )
        return graph, goal

    @pytest.mark.asyncio
    async def test_event_entry_point_subscribes_to_bus(self):
        """Entry point with trigger_type='event' subscribes and triggers on matching events."""
        graph, goal = self._make_graph_and_goal()

        config = AgentRuntimeConfig(
            webhook_host="127.0.0.1",
            webhook_port=0,
            webhook_routes=[
                {"source_id": "gh", "path": "/webhooks/github"},
            ],
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            runtime = AgentRuntime(
                graph=graph,
                goal=goal,
                storage_path=Path(tmpdir),
                config=config,
            )

            runtime.register_entry_point(
                EntryPointSpec(
                    id="gh-handler",
                    name="GitHub Handler",
                    entry_node="process-event",
                    trigger_type="event",
                    trigger_config={
                        "event_types": ["webhook_received"],
                        "filter_stream": "gh",
                    },
                )
            )

            trigger_calls = []

            async def mock_trigger(ep_id, data, **kwargs):
                trigger_calls.append((ep_id, data))

            with patch.object(runtime, "trigger", side_effect=mock_trigger):
                await runtime.start()

                try:
                    assert runtime.webhook_server is not None
                    assert runtime.webhook_server.is_running

                    port = runtime.webhook_server.port
                    async with aiohttp.ClientSession() as session:
                        async with session.post(
                            f"http://127.0.0.1:{port}/webhooks/github",
                            json={"action": "push", "ref": "main"},
                        ) as resp:
                            assert resp.status == 202

                    await asyncio.sleep(0.1)

                    assert len(trigger_calls) == 1
                    ep_id, data = trigger_calls[0]
                    assert ep_id == "gh-handler"
                    assert "event" in data
                    assert data["event"]["type"] == "webhook_received"
                    assert data["event"]["stream_id"] == "gh"
                    assert data["event"]["data"]["payload"] == {
                        "action": "push",
                        "ref": "main",
                    }
                finally:
                    await runtime.stop()

                assert runtime.webhook_server is None

    @pytest.mark.asyncio
    async def test_event_entry_point_filter_stream(self):
        """Entry point only triggers for matching stream_id (source_id)."""
        graph, goal = self._make_graph_and_goal()

        config = AgentRuntimeConfig(
            webhook_routes=[
                {"source_id": "github", "path": "/webhooks/github"},
                {"source_id": "stripe", "path": "/webhooks/stripe"},
            ],
            webhook_port=0,
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            runtime = AgentRuntime(
                graph=graph,
                goal=goal,
                storage_path=Path(tmpdir),
                config=config,
            )

            runtime.register_entry_point(
                EntryPointSpec(
                    id="gh-only",
                    name="GitHub Only",
                    entry_node="process-event",
                    trigger_type="event",
                    trigger_config={
                        "event_types": ["webhook_received"],
                        "filter_stream": "github",
                    },
                )
            )

            trigger_calls = []

            async def mock_trigger(ep_id, data, **kwargs):
                trigger_calls.append((ep_id, data))

            with patch.object(runtime, "trigger", side_effect=mock_trigger):
                await runtime.start()

                try:
                    port = runtime.webhook_server.port
                    async with aiohttp.ClientSession() as session:
                        # POST to stripe — should NOT trigger
                        await session.post(
                            f"http://127.0.0.1:{port}/webhooks/stripe",
                            json={"type": "payment"},
                        )
                        # POST to github — should trigger
                        await session.post(
                            f"http://127.0.0.1:{port}/webhooks/github",
                            json={"action": "opened"},
                        )

                    await asyncio.sleep(0.1)

                    assert len(trigger_calls) == 1
                    assert trigger_calls[0][0] == "gh-only"
                finally:
                    await runtime.stop()

    @pytest.mark.asyncio
    async def test_no_webhook_routes_skips_server(self):
        """Runtime without webhook_routes does not start a webhook server."""
        graph, goal = self._make_graph_and_goal()

        with tempfile.TemporaryDirectory() as tmpdir:
            runtime = AgentRuntime(
                graph=graph,
                goal=goal,
                storage_path=Path(tmpdir),
            )

            runtime.register_entry_point(
                EntryPointSpec(
                    id="manual",
                    name="Manual",
                    entry_node="process-event",
                    trigger_type="manual",
                )
            )

            await runtime.start()
            try:
                assert runtime.webhook_server is None
            finally:
                await runtime.stop()

    @pytest.mark.asyncio
    async def test_event_entry_point_custom_event(self):
        """Entry point can subscribe to CUSTOM events, not just webhooks."""
        graph, goal = self._make_graph_and_goal()

        with tempfile.TemporaryDirectory() as tmpdir:
            runtime = AgentRuntime(
                graph=graph,
                goal=goal,
                storage_path=Path(tmpdir),
            )

            runtime.register_entry_point(
                EntryPointSpec(
                    id="custom-handler",
                    name="Custom Handler",
                    entry_node="process-event",
                    trigger_type="event",
                    trigger_config={
                        "event_types": ["custom"],
                    },
                )
            )

            trigger_calls = []

            async def mock_trigger(ep_id, data, **kwargs):
                trigger_calls.append((ep_id, data))

            with patch.object(runtime, "trigger", side_effect=mock_trigger):
                await runtime.start()

                try:
                    await runtime.event_bus.publish(
                        AgentEvent(
                            type=EventType.CUSTOM,
                            stream_id="some-source",
                            data={"key": "value"},
                        )
                    )

                    await asyncio.sleep(0.1)

                    assert len(trigger_calls) == 1
                    assert trigger_calls[0][0] == "custom-handler"
                    assert trigger_calls[0][1]["event"]["type"] == "custom"
                    assert trigger_calls[0][1]["event"]["data"]["key"] == "value"
                finally:
                    await runtime.stop()
||||
@@ -0,0 +1,175 @@
"""
Webhook HTTP Server - Receives HTTP requests and publishes them as EventBus events.

Only starts if webhook-type entry points are registered. Uses aiohttp for
a lightweight embedded HTTP server that runs within the existing asyncio loop.
"""

import hashlib
import hmac
import json
import logging
from dataclasses import dataclass

from aiohttp import web

from framework.runtime.event_bus import EventBus

logger = logging.getLogger(__name__)


@dataclass
class WebhookRoute:
    """A registered webhook route derived from an EntryPointSpec."""

    source_id: str
    path: str
    methods: list[str]
    secret: str | None = None  # For HMAC-SHA256 signature verification


@dataclass
class WebhookServerConfig:
    """Configuration for the webhook HTTP server."""

    host: str = "127.0.0.1"
    port: int = 8080


class WebhookServer:
    """
    Embedded HTTP server that receives webhook requests and publishes
    them as WEBHOOK_RECEIVED events on the EventBus.

    The server's only job is: receive HTTP -> publish AgentEvent.
    Subscribers decide what to do with the event.

    Lifecycle:
        server = WebhookServer(event_bus, config)
        server.add_route(WebhookRoute(...))
        await server.start()
        # ... server running ...
        await server.stop()
    """

    def __init__(
        self,
        event_bus: EventBus,
        config: WebhookServerConfig | None = None,
    ):
        self._event_bus = event_bus
        self._config = config or WebhookServerConfig()
        self._routes: dict[str, WebhookRoute] = {}  # path -> route
        self._app: web.Application | None = None
        self._runner: web.AppRunner | None = None
        self._site: web.TCPSite | None = None

    def add_route(self, route: WebhookRoute) -> None:
        """Register a webhook route."""
        self._routes[route.path] = route

    async def start(self) -> None:
        """Start the HTTP server. No-op if no routes registered."""
        if not self._routes:
            logger.debug("No webhook routes registered, skipping server start")
            return

        self._app = web.Application()

        for path, route in self._routes.items():
            for method in route.methods:
                self._app.router.add_route(method, path, self._handle_request)

        self._runner = web.AppRunner(self._app)
        await self._runner.setup()
        self._site = web.TCPSite(
            self._runner,
            self._config.host,
            self._config.port,
        )
        await self._site.start()
        logger.info(
            f"Webhook server started on {self._config.host}:{self._config.port} "
            f"with {len(self._routes)} route(s)"
        )

    async def stop(self) -> None:
        """Stop the HTTP server gracefully."""
        if self._runner:
            await self._runner.cleanup()
        self._runner = None
        self._app = None
        self._site = None
        logger.info("Webhook server stopped")

    async def _handle_request(self, request: web.Request) -> web.Response:
        """Handle an incoming webhook request."""
        path = request.path
        route = self._routes.get(path)

        if route is None:
            return web.json_response({"error": "Not found"}, status=404)

        # Read body
        try:
            body = await request.read()
        except Exception:
            return web.json_response(
                {"error": "Failed to read request body"},
                status=400,
            )

        # Verify HMAC signature if secret is configured
        if route.secret:
            if not self._verify_signature(request, body, route.secret):
                return web.json_response({"error": "Invalid signature"}, status=401)

        # Parse body as JSON (fall back to raw text for non-JSON)
        try:
            payload = json.loads(body) if body else {}
        except (json.JSONDecodeError, ValueError):
            payload = {"raw_body": body.decode("utf-8", errors="replace")}

        # Publish event to bus
        await self._event_bus.emit_webhook_received(
            source_id=route.source_id,
            path=path,
            method=request.method,
            headers=dict(request.headers),
            payload=payload,
            query_params=dict(request.query),
        )

        return web.json_response({"status": "accepted"}, status=202)

    def _verify_signature(
        self,
        request: web.Request,
        body: bytes,
        secret: str,
    ) -> bool:
        """Verify HMAC-SHA256 signature from X-Hub-Signature-256 header."""
        signature_header = request.headers.get("X-Hub-Signature-256", "")
        if not signature_header.startswith("sha256="):
            return False

        expected_sig = signature_header[7:]  # strip "sha256="
        computed_sig = hmac.new(
            secret.encode("utf-8"),
            body,
            hashlib.sha256,
        ).hexdigest()

        return hmac.compare_digest(expected_sig, computed_sig)

    @property
    def is_running(self) -> bool:
        """Check if the server is running."""
        return self._site is not None

    @property
    def port(self) -> int | None:
        """Return the actual listening port (useful when configured with port=0)."""
        if self._site and self._site._server and self._site._server.sockets:
            return self._site._server.sockets[0].getsockname()[1]
        return None
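
A note on the signature check above: the server expects a GitHub-style X-Hub-Signature-256 header over the raw request body. A minimal sketch of the matching sender side, using only the stdlib (the sign_payload helper is hypothetical, not part of this change):

    import hashlib
    import hmac
    import json

    def sign_payload(secret: str, payload: dict) -> tuple[bytes, str]:
        # Serialize once and sign exactly those bytes; re-serializing later
        # would change whitespace and break hmac.compare_digest on the server.
        body = json.dumps(payload).encode("utf-8")
        digest = hmac.new(secret.encode("utf-8"), body, hashlib.sha256).hexdigest()
        return body, f"sha256={digest}"

    # body, sig = sign_payload("shared-secret", {"action": "push"})
    # POST `body` with headers={"X-Hub-Signature-256": sig} to a route whose
    # WebhookRoute.secret is "shared-secret".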
@@ -156,8 +156,13 @@ class SessionState(BaseModel):
    @computed_field
    @property
    def is_resumable(self) -> bool:
        """Can this session be resumed?"""
        return self.status == SessionStatus.PAUSED and self.progress.resume_from is not None
        """Can this session be resumed?

        Every non-completed session is resumable. If resume_from/paused_at
        aren't set, the executor falls back to the graph entry point —
        so we don't gate on those. Even catastrophic failures are resumable.
        """
        return self.status != SessionStatus.COMPLETED

    @computed_field
    @property
@@ -279,9 +284,16 @@ class SessionState(BaseModel):

    def to_session_state_dict(self) -> dict[str, Any]:
        """Convert to session_state format for GraphExecutor.execute()."""
        # Derive resume target: explicit > last node in path > entry point
        resume_from = (
            self.progress.resume_from
            or self.progress.paused_at
            or (self.progress.path[-1] if self.progress.path else None)
        )
        return {
            "paused_at": self.progress.paused_at,
            "resume_from": self.progress.resume_from,
            "paused_at": resume_from,
            "resume_from": resume_from,
            "memory": self.memory,
            "next_node": None,
            "execution_path": self.progress.path,
            "node_visit_counts": self.progress.node_visit_counts,
        }
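
The fallback chain above means a session that died without recording a pause point still resumes from its last visited node. A sketch of the precedence, with hypothetical progress values:

    # resume_from="node-b", paused_at="node-a"  -> "node-b"  (explicit target wins)
    # resume_from=None,     paused_at="node-a"  -> "node-a"
    # both None, path=["start", "node-a"]       -> "node-a"  (last node in path)
    # both None, path=[]                        -> None      (executor falls back
    #                                                         to the graph entry point)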
@@ -37,7 +37,7 @@ class SessionStore:
        Initialize session store.

        Args:
            base_path: Base path for storage (e.g., ~/.hive/agents/twitter_outreach)
            base_path: Base path for storage (e.g., ~/.hive/agents/deep_research_agent)
        """
        self.base_path = Path(base_path)
        self.sessions_dir = self.base_path / "sessions"
@@ -3,6 +3,10 @@ Pytest templates for test file generation.

These templates provide headers and fixtures for pytest-compatible async tests.
Tests are written to exports/{agent}/tests/ as Python files and run with pytest.

Tests use AgentRunner.load() — the canonical runtime path — which creates
AgentRuntime, ExecutionStream, and proper session/log storage. For agents
with client-facing nodes, an auto_responder fixture handles input injection.
"""

# Template for the test file header (imports and fixtures)
@@ -11,17 +15,19 @@ PYTEST_TEST_FILE_HEADER = '''"""

{description}

REQUIRES: API_KEY (OpenAI or Anthropic) for real testing.
REQUIRES: API_KEY for execution tests. Structure tests run without keys.
"""

import os
import pytest
from {agent_module} import default_agent
from pathlib import Path

# Agent path resolved from this test file's location
AGENT_PATH = Path(__file__).resolve().parents[1]


def _get_api_key():
    """Get API key from CredentialStoreAdapter or environment."""
    # 1. Try CredentialStoreAdapter for Anthropic
    try:
        from aden_tools.credentials import CredentialStoreAdapter
        creds = CredentialStoreAdapter.default()
@@ -29,28 +35,43 @@ def _get_api_key():
        return creds.get("anthropic")
    except (ImportError, KeyError):
        pass

    # 2. Fallback to standard environment variables for OpenAI and others
    return (
        os.environ.get("OPENAI_API_KEY") or
        os.environ.get("ANTHROPIC_API_KEY") or
        os.environ.get("CEREBRAS_API_KEY") or
        os.environ.get("GROQ_API_KEY")
        os.environ.get("GROQ_API_KEY") or
        os.environ.get("GEMINI_API_KEY")
    )


# Skip all tests if no API key and not in mock mode
pytestmark = pytest.mark.skipif(
    not _get_api_key() and not os.environ.get("MOCK_MODE"),
    reason="API key required. Please set OPENAI_API_KEY, ANTHROPIC_API_KEY, or use MOCK_MODE=1."
    reason="API key required. Set ANTHROPIC_API_KEY or use MOCK_MODE=1 for structure tests."
)
'''

# Template for conftest.py with shared fixtures
PYTEST_CONFTEST_TEMPLATE = '''"""Shared test fixtures for {agent_name} tests."""

import json
import os
import re
import sys
from pathlib import Path

# Add exports/ and core/ to sys.path so the agent package and framework are importable
_repo_root = Path(__file__).resolve().parents[3]
for _p in ["exports", "core"]:
    _path = str(_repo_root / _p)
    if _path not in sys.path:
        sys.path.insert(0, _path)

import pytest
from framework.runner.runner import AgentRunner
from framework.runtime.event_bus import EventType

AGENT_PATH = Path(__file__).resolve().parents[1]


def _get_api_key():
@@ -62,19 +83,80 @@ def _get_api_key():
        return creds.get("anthropic")
    except (ImportError, KeyError):
        pass

    return (
        os.environ.get("OPENAI_API_KEY") or
        os.environ.get("ANTHROPIC_API_KEY") or
        os.environ.get("CEREBRAS_API_KEY") or
        os.environ.get("GROQ_API_KEY")
        os.environ.get("GROQ_API_KEY") or
        os.environ.get("GEMINI_API_KEY")
    )


@pytest.fixture
@pytest.fixture(scope="session")
def mock_mode():
    """Check if running in mock mode."""
    return bool(os.environ.get("MOCK_MODE"))
    """Return True if running in mock mode (no API key or MOCK_MODE=1)."""
    if os.environ.get("MOCK_MODE"):
        return True
    return not bool(_get_api_key())


@pytest.fixture(scope="session")
async def runner(tmp_path_factory, mock_mode):
    """Create an AgentRunner using the canonical runtime path.

    Uses tmp_path_factory for storage so tests don't pollute ~/.hive/agents/.
    Goes through AgentRunner.load() -> _setup() -> AgentRuntime, the same
    path as ``hive run``.
    """
    storage = tmp_path_factory.mktemp("agent_storage")
    r = AgentRunner.load(
        AGENT_PATH,
        mock_mode=mock_mode,
        storage_path=storage,
    )
    r._setup()
    yield r
    await r.cleanup_async()


@pytest.fixture
def auto_responder(runner):
    """Auto-respond to client-facing node input requests.

    Subscribes to CLIENT_INPUT_REQUESTED events and injects a response
    to unblock the node. Customize the response before calling start():

        auto_responder.response = "approve the report"
        await auto_responder.start()
    """
    class AutoResponder:
        def __init__(self, runner_instance):
            self._runner = runner_instance
            self.response = "yes, proceed"
            self.interactions = []
            self._sub_id = None

        async def start(self):
            runtime = self._runner._agent_runtime
            if runtime is None:
                return

            async def _handle(event):
                self.interactions.append(event.node_id)
                await runtime.inject_input(event.node_id, self.response)

            self._sub_id = runtime.subscribe_to_events(
                event_types=[EventType.CLIENT_INPUT_REQUESTED],
                handler=_handle,
            )

        async def stop(self):
            runtime = self._runner._agent_runtime
            if self._sub_id and runtime:
                runtime.unsubscribe_from_events(self._sub_id)
            self._sub_id = None

    return AutoResponder(runner)


@pytest.fixture(scope="session", autouse=True)
@@ -82,19 +164,51 @@ def check_api_key():
    """Ensure API key is set for real testing."""
    if not _get_api_key():
        if os.environ.get("MOCK_MODE"):
            print("\\n⚠️  Running in MOCK MODE - structure validation only")
            print("   This does NOT test LLM behavior or agent quality")
            print("   Set OPENAI_API_KEY or ANTHROPIC_API_KEY for real testing\\n")
            print("\\n  Running in MOCK MODE - structure validation only")
            print("  Set ANTHROPIC_API_KEY for real testing\\n")
        else:
            pytest.fail(
                "\\n❌ No API key found!\\n\\n"
                "Real testing requires an API key. Choose one:\\n"
                "1. Set OpenAI key:\\n"
                "   export OPENAI_API_KEY='your-key-here'\\n"
                "2. Set Anthropic key:\\n"
                "   export ANTHROPIC_API_KEY='your-key-here'\\n"
                "3. Run structure validation only:\\n"
                "   MOCK_MODE=1 pytest exports/{agent_name}/tests/\\n\\n"
                "Note: Mock mode does NOT validate agent behavior or quality."
                "\\nNo API key found!\\n"
                "Set ANTHROPIC_API_KEY or use MOCK_MODE=1 for structure tests.\\n"
            )


def parse_json_from_output(result, key):
    """Parse JSON from agent output (framework may store full LLM response as string)."""
    val = result.output.get(key, "")
    if isinstance(val, (dict, list)):
        return val
    if isinstance(val, str):
        json_text = re.sub(r"```json\\s*|\\s*```", "", val).strip()
        try:
            return json.loads(json_text)
        except (json.JSONDecodeError, TypeError):
            return val
    return val


def safe_get_nested(result, key_path, default=None):
    """Safely get nested value from result.output."""
    output = result.output or {{}}
    current = output
    for key in key_path:
        if isinstance(current, dict):
            current = current.get(key)
        elif isinstance(current, str):
            try:
                json_text = re.sub(r"```json\\s*|\\s*```", "", current).strip()
                parsed = json.loads(json_text)
                if isinstance(parsed, dict):
                    current = parsed.get(key)
                else:
                    return default
            except json.JSONDecodeError:
                return default
        else:
            return default
    return current if current is not None else default


pytest.parse_json_from_output = parse_json_from_output
pytest.safe_get_nested = safe_get_nested
'''
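
A sketch of how a generated test might drive these fixtures; the agent-triggering step is elided because it depends on AgentRunner's execution API, and the interplay of session-scoped async fixtures with pytest-asyncio depends on the project's loop-scope configuration:

    @pytest.mark.asyncio
    async def test_client_facing_flow(runner, auto_responder):
        auto_responder.response = "approve the report"
        await auto_responder.start()
        # ... trigger the agent through `runner` here ...
        await auto_responder.stop()
        # Every node that requested input is recorded by node_id.
        assert auto_responder.interactions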
@@ -1,18 +1,18 @@
import logging
import platform
import subprocess
import threading
import time

from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, Vertical
from textual.widgets import Footer, Input, Label
from textual.containers import Container, Horizontal
from textual.widgets import Footer, Label

from framework.runtime.agent_runtime import AgentRuntime
from framework.runtime.event_bus import AgentEvent, EventType
from framework.tui.widgets.chat_repl import ChatRepl
from framework.tui.widgets.graph_view import GraphOverview
from framework.tui.widgets.log_pane import LogPane
from framework.tui.widgets.selectable_rich_log import SelectableRichLog


@@ -136,28 +136,15 @@ class AdenTUI(App):
        background: $surface;
    }

    #left-pane {
        width: 60%;
        height: 100%;
        layout: vertical;
        background: $surface;
    }

    GraphOverview {
        height: 40%;
        width: 40%;
        height: 100%;
        background: $panel;
        padding: 0;
    }

    LogPane {
        height: 60%;
        background: $surface;
        padding: 0;
        margin-bottom: 1;
    }

    ChatRepl {
        width: 40%;
        width: 60%;
        height: 100%;
        background: $panel;
        border-left: tall $primary;
@@ -180,13 +167,13 @@ class AdenTUI(App):
        scrollbar-color: $primary;
    }

    Input {
    ChatTextArea {
        background: $surface;
        border: tall $primary;
        margin-top: 1;
    }

    Input:focus {
    ChatTextArea:focus {
        border: tall $accent;
    }

@@ -208,8 +195,10 @@ class AdenTUI(App):
        Binding("ctrl+c", "ctrl_c", "Interrupt", show=False, priority=True),
        Binding("super+c", "ctrl_c", "Copy", show=False, priority=True),
        Binding("ctrl+s", "screenshot", "Screenshot (SVG)", show=True, priority=True),
        Binding("ctrl+l", "toggle_logs", "Toggle Logs", show=True, priority=True),
        Binding("ctrl+z", "pause_execution", "Pause", show=True, priority=True),
        Binding("ctrl+r", "show_sessions", "Sessions", show=True, priority=True),
        Binding("ctrl+p", "attach_pdf", "Attach PDF", show=True, priority=True),
        Binding("tab", "focus_next", "Next Panel", show=True),
        Binding("shift+tab", "focus_previous", "Previous Panel", show=False),
    ]
@@ -223,7 +212,6 @@ class AdenTUI(App):
        super().__init__()

        self.runtime = runtime
        self.log_pane = LogPane()
        self.graph_view = GraphOverview(runtime)
        self.chat_repl = ChatRepl(runtime, resume_session, resume_checkpoint)
        self.status_bar = StatusBar(graph_id=runtime.graph.id)
@@ -253,11 +241,7 @@ class AdenTUI(App):
        yield self.status_bar

        yield Horizontal(
            Vertical(
                self.log_pane,
                self.graph_view,
                id="left-pane",
            ),
            self.graph_view,
            self.chat_repl,
        )

@@ -328,7 +312,7 @@ class AdenTUI(App):
                if record.name.startswith(("textual", "LiteLLM", "litellm")):
                    continue

                self.log_pane.write_python_log(record)
                self.chat_repl.write_python_log(record)
            except Exception:
                pass

@@ -350,6 +334,14 @@ class AdenTUI(App):
        EventType.CONSTRAINT_VIOLATION,
        EventType.STATE_CHANGED,
        EventType.NODE_INPUT_BLOCKED,
        EventType.CONTEXT_COMPACTED,
        EventType.NODE_INTERNAL_OUTPUT,
        EventType.JUDGE_VERDICT,
        EventType.OUTPUT_KEY_SET,
        EventType.NODE_RETRY,
        EventType.EDGE_TRAVERSED,
        EventType.EXECUTION_PAUSED,
        EventType.EXECUTION_RESUMED,
    ]

    _LOG_PANE_EVENTS = frozenset(_EVENT_TYPES) - {
@@ -368,15 +360,36 @@ class AdenTUI(App):
            pass

    async def _handle_event(self, event: AgentEvent) -> None:
        """Called from the agent thread — bridge to Textual's main thread."""
        """Bridge events to Textual's main thread for UI updates.

        Events may arrive from the agent-execution thread (normal LLM/tool
        work) or from the Textual thread itself (e.g. webhook server events).
        ``call_from_thread`` requires a *different* thread, so we detect
        which thread we're on and act accordingly.
        """
        try:
            self.call_from_thread(self._route_event, event)
        except Exception:
            pass
            if threading.get_ident() == self._thread_id:
                # Already on Textual's thread — call directly.
                self._route_event(event)
            else:
                # On a different thread — bridge via call_from_thread.
                self.call_from_thread(self._route_event, event)
        except Exception as e:
            logging.getLogger("tui.events").error(
                "call_from_thread failed for %s (node=%s): %s",
                event.type.value,
                event.node_id or "?",
                e,
            )

    def _route_event(self, event: AgentEvent) -> None:
        """Route incoming events to widgets. Runs on Textual's main thread."""
        if not self.is_ready:
            logging.getLogger("tui.events").warning(
                "Event dropped (not ready): %s node=%s",
                event.type.value,
                event.node_id or "?",
            )
            return

        try:
@@ -407,6 +420,35 @@ class AdenTUI(App):
                self.chat_repl.handle_input_requested(
                    event.node_id or event.data.get("node_id", ""),
                )
            elif et == EventType.NODE_LOOP_STARTED:
                self.chat_repl.handle_node_started(event.node_id or "")
            elif et == EventType.NODE_LOOP_ITERATION:
                self.chat_repl.handle_loop_iteration(event.data.get("iteration", 0))

            # Track active node in chat_repl for mid-execution input
            if et == EventType.NODE_LOOP_STARTED:
                self.chat_repl.handle_node_started(event.node_id or "")
            elif et == EventType.NODE_LOOP_COMPLETED:
                self.chat_repl.handle_node_completed(event.node_id or "")

            # Non-client-facing node output → chat repl
            if et == EventType.NODE_INTERNAL_OUTPUT:
                content = event.data.get("content", "")
                if content.strip():
                    self.chat_repl.handle_internal_output(event.node_id or "", content)

            # Execution paused/resumed → chat repl
            if et == EventType.EXECUTION_PAUSED:
                reason = event.data.get("reason", "")
                self.chat_repl.handle_execution_paused(event.node_id or "", reason)
            elif et == EventType.EXECUTION_RESUMED:
                self.chat_repl.handle_execution_resumed(event.node_id or "")

            # Goal achieved / constraint violation → chat repl
            if et == EventType.GOAL_ACHIEVED:
                self.chat_repl.handle_goal_achieved(event.data)
            elif et == EventType.CONSTRAINT_VIOLATION:
                self.chat_repl.handle_constraint_violation(event.data)

            # --- Graph view events ---
            if et in (
@@ -444,6 +486,13 @@ class AdenTUI(App):
                    started=False,
                )

            # Edge traversal → graph view
            if et == EventType.EDGE_TRAVERSED:
                self.graph_view.handle_edge_traversed(
                    event.data.get("source_node", ""),
                    event.data.get("target_node", ""),
                )

            # --- Status bar events ---
            if et == EventType.EXECUTION_STARTED:
                entry_node = event.data.get("entry_node") or (
@@ -464,12 +513,36 @@ class AdenTUI(App):
                self.status_bar.set_node_detail("thinking...")
            elif et == EventType.NODE_STALLED:
                self.status_bar.set_node_detail(f"stalled: {event.data.get('reason', '')}")
            elif et == EventType.CONTEXT_COMPACTED:
                before = event.data.get("usage_before", "?")
                after = event.data.get("usage_after", "?")
                self.status_bar.set_node_detail(f"compacted: {before}% \u2192 {after}%")
            elif et == EventType.JUDGE_VERDICT:
                action = event.data.get("action", "?")
                self.status_bar.set_node_detail(f"judge: {action}")
            elif et == EventType.OUTPUT_KEY_SET:
                key = event.data.get("key", "?")
                self.status_bar.set_node_detail(f"set: {key}")
            elif et == EventType.NODE_RETRY:
                retry = event.data.get("retry_count", "?")
                max_r = event.data.get("max_retries", "?")
                self.status_bar.set_node_detail(f"retry {retry}/{max_r}")
            elif et == EventType.EXECUTION_PAUSED:
                self.status_bar.set_node_detail("paused")
            elif et == EventType.EXECUTION_RESUMED:
                self.status_bar.set_node_detail("resumed")

            # --- Log pane events ---
            # --- Log events (inline in chat) ---
            if et in self._LOG_PANE_EVENTS:
                self.log_pane.write_event(event)
        except Exception:
            pass
                self.chat_repl.write_log_event(event)
        except Exception as e:
            logging.getLogger("tui.events").error(
                "Route failed for %s (node=%s): %s",
                event.type.value,
                event.node_id or "?",
                e,
                exc_info=True,
            )

    def save_screenshot(self, filename: str | None = None) -> str:
        """Save a screenshot of the current screen as SVG (viewable in browsers).
@@ -504,8 +577,8 @@ class AdenTUI(App):
        original_chat_border = chat_widget.styles.border_left
        chat_widget.styles.border_left = ("none", "transparent")

        # Hide all Input widget borders
        input_widgets = self.query("Input")
        # Hide all TextArea widget borders
        input_widgets = self.query("ChatTextArea")
        original_input_borders = []
        for input_widget in input_widgets:
            original_input_borders.append(input_widget.styles.border)
@@ -535,6 +608,12 @@ class AdenTUI(App):
        except Exception as e:
            self.notify(f"Screenshot failed: {e}", severity="error", timeout=5)

    def action_toggle_logs(self) -> None:
        """Toggle inline log display in chat (bound to Ctrl+L)."""
        self.chat_repl.toggle_logs()
        mode = "ON" if self.chat_repl._show_logs else "OFF"
        self.notify(f"Logs {mode}", severity="information", timeout=2)

    def action_pause_execution(self) -> None:
        """Immediately pause execution by cancelling task (bound to Ctrl+Z)."""
        try:
@@ -575,19 +654,12 @@ class AdenTUI(App):
                timeout=5,
            )

    def action_show_sessions(self) -> None:
    async def action_show_sessions(self) -> None:
        """Show sessions list (bound to Ctrl+R)."""
        # Send /sessions command to chat input
        try:
            chat_repl = self.query_one(ChatRepl)
            chat_input = chat_repl.query_one("#chat-input", Input)
            chat_input.value = "/sessions"
            # Trigger submission
            self.notify(
                "💡 Type /sessions in the chat to see all sessions",
                severity="information",
                timeout=3,
            )
            await chat_repl._submit_input("/sessions")
        except Exception:
            self.notify(
                "Use /sessions command to see all sessions",
@@ -595,6 +667,24 @@ class AdenTUI(App):
                timeout=3,
            )

    async def action_attach_pdf(self) -> None:
        """Open native OS file dialog for PDF selection (bound to Ctrl+P)."""
        from framework.tui.widgets.file_browser import _has_gui, pick_pdf_file

        if not _has_gui():
            self.notify(
                "No GUI available. Use /attach <path> instead.",
                severity="warning",
                timeout=5,
            )
            return

        self.notify("Opening file dialog...", severity="information", timeout=2)
        path = await pick_pdf_file()

        if path is not None:
            self.chat_repl.attach_pdf(path)

    async def on_unmount(self) -> None:
        """Cleanup on app shutdown - cancel execution which will save state."""
        self.is_ready = False
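
The guard in _handle_event exists because Textual's call_from_thread raises a RuntimeError when invoked from the app's own thread. The pattern in isolation (a sketch; deliver and _apply are hypothetical names, and self._thread_id is the UI thread's ident as used in the diff above):

    import threading
    from textual.app import App

    class BridgeApp(App):
        def deliver(self, payload) -> None:
            if threading.get_ident() == self._thread_id:
                self._apply(payload)  # already on the UI thread: call directly
            else:
                # worker thread: marshal the call onto the UI thread
                self.call_from_thread(self._apply, payload)

        def _apply(self, payload) -> None:
            ...  # widget updates; always executed on the UI thread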
@@ -15,19 +15,49 @@ Client-facing input:
"""

import asyncio
import logging
import re
import shutil
import threading
from pathlib import Path
from typing import Any

from textual.app import ComposeResult
from textual.containers import Vertical
from textual.widgets import Input, Label
from textual.message import Message
from textual.widgets import Label, TextArea

from framework.runtime.agent_runtime import AgentRuntime
from framework.runtime.event_bus import AgentEvent
from framework.tui.widgets.log_pane import format_event, format_python_log
from framework.tui.widgets.selectable_rich_log import SelectableRichLog as RichLog


class ChatTextArea(TextArea):
    """TextArea that submits on Enter and inserts newlines on Shift+Enter."""

    class Submitted(Message):
        """Posted when the user presses Enter."""

        def __init__(self, text: str) -> None:
            super().__init__()
            self.text = text

    async def _on_key(self, event) -> None:
        if event.key == "enter":
            text = self.text.strip()
            self.clear()
            if text:
                self.post_message(self.Submitted(text))
            event.stop()
            event.prevent_default()
        elif event.key == "shift+enter":
            event.key = "enter"
            await super()._on_key(event)
        else:
            await super()._on_key(event)


class ChatRepl(Vertical):
    """Widget for interactive chat/REPL."""

@@ -56,16 +86,17 @@ class ChatRepl(Vertical):
        display: none;
    }

    ChatRepl > Input {
    ChatRepl > ChatTextArea {
        width: 100%;
        height: auto;
        max-height: 7;
        dock: bottom;
        background: $surface;
        border: tall $primary;
        margin-top: 1;
    }

    ChatRepl > Input:focus {
    ChatRepl > ChatTextArea:focus {
        border: tall $accent;
    }
    """
@@ -82,8 +113,14 @@ class ChatRepl(Vertical):
        self._streaming_snapshot: str = ""
        self._waiting_for_input: bool = False
        self._input_node_id: str | None = None
        self._pending_ask_question: str = ""
        self._active_node_id: str | None = None  # Currently executing node
        self._resume_session = resume_session
        self._resume_checkpoint = resume_checkpoint
        self._session_index: list[str] = []  # IDs from last listing
        self._show_logs: bool = False  # Clean mode by default
        self._log_buffer: list[str] = []  # Buffered log lines for backfill on toggle ON
        self._attached_pdf: dict | None = None  # Pending PDF attachment for next message

        # Dedicated event loop for agent execution.
        # Keeps blocking runtime code (LLM calls, MCP tools) off
@@ -106,7 +143,7 @@ class ChatRepl(Vertical):
            min_width=0,
        )
        yield Label("Agent is processing...", id="processing-indicator")
        yield Input(placeholder="Enter input for agent...", id="chat-input")
        yield ChatTextArea(id="chat-input", placeholder="Enter input for agent...")

    # Regex for file:// URIs that are NOT already inside Rich [link=...] markup
    _FILE_URI_RE = re.compile(r"(?<!\[link=)(file://[^\s)\]>*]+)")
@@ -129,6 +166,31 @@ class ChatRepl(Vertical):
        if was_at_bottom:
            history.scroll_end(animate=False)

    def toggle_logs(self) -> None:
        """Toggle inline log display on/off. Backfills buffered logs on toggle ON."""
        self._show_logs = not self._show_logs
        if self._show_logs and self._log_buffer:
            self._write_history("[dim]--- Backfilling logs ---[/dim]")
            for line in self._log_buffer:
                self._write_history(line)
            self._write_history("[dim]--- Live logs ---[/dim]")
        mode = "ON (dirty)" if self._show_logs else "OFF (clean)"
        self._write_history(f"[dim]Logs {mode}[/dim]")

    def write_log_event(self, event: AgentEvent) -> None:
        """Buffer a formatted agent event. Display inline if logs are ON."""
        formatted = format_event(event)
        self._log_buffer.append(formatted)
        if self._show_logs:
            self._write_history(formatted)

    def write_python_log(self, record: logging.LogRecord) -> None:
        """Buffer a formatted Python log record. Display inline if logs are ON."""
        formatted = format_python_log(record)
        self._log_buffer.append(formatted)
        if self._show_logs:
            self._write_history(formatted)

    async def _handle_command(self, command: str) -> None:
        """Handle slash commands for session and checkpoint operations."""
        parts = command.split(maxsplit=2)
@@ -136,35 +198,49 @@ class ChatRepl(Vertical):

        if cmd == "/help":
            self._write_history("""[bold cyan]Available Commands:[/bold cyan]
[bold]/attach[/bold] - Open file dialog to attach a PDF
[bold]/attach[/bold] <file_path> - Attach a PDF from a specific path
[bold]/detach[/bold] - Remove the currently attached PDF
[bold]/sessions[/bold] - List all sessions for this agent
[bold]/sessions[/bold] <session_id> - Show session details and checkpoints
[bold]/resume[/bold] - Resume latest paused/failed session
[bold]/resume[/bold] <session_id> - Resume session from where it stopped
[bold]/resume[/bold] - List sessions and pick one to resume
[bold]/resume[/bold] <number> - Resume session by list number
[bold]/resume[/bold] <session_id> - Resume session by ID
[bold]/recover[/bold] <session_id> <cp_id> - Recover from specific checkpoint
[bold]/pause[/bold] - Pause current execution (Ctrl+Z)
[bold]/help[/bold] - Show this help message

[dim]Examples:[/dim]
  /attach                                  [dim]# Open file picker dialog[/dim]
  /attach ~/Documents/report.pdf           [dim]# Attach a specific PDF[/dim]
  /detach                                  [dim]# Remove attached PDF[/dim]
  /sessions                                [dim]# List all sessions[/dim]
  /sessions session_20260208_143022        [dim]# Show session details[/dim]
  /resume                                  [dim]# Resume latest session (from state)[/dim]
  /resume session_20260208_143022          [dim]# Resume specific session (from state)[/dim]
  /recover session_20260208_143022 cp_xxx  [dim]# Recover from specific checkpoint[/dim]
  /resume 1                                [dim]# Resume first listed session[/dim]
  /pause                                   [dim]# Pause (or Ctrl+Z)[/dim]
""")
        elif cmd == "/sessions":
            session_id = parts[1].strip() if len(parts) > 1 else None
            await self._cmd_sessions(session_id)
        elif cmd == "/resume":
            # Resume from session state (not checkpoint-based)
            if len(parts) < 2:
                session_id = await self._find_latest_resumable_session()
                if not session_id:
                    self._write_history("[bold red]No resumable sessions found[/bold red]")
                    self._write_history("  Tip: Use [bold]/sessions[/bold] to see all sessions")
                # No arg → show session list so user can pick one
                await self._cmd_sessions(None)
                return

            arg = parts[1].strip()

            # Numeric index → resolve from last listing
            if arg.isdigit():
                idx = int(arg) - 1  # 1-based to 0-based
                if 0 <= idx < len(self._session_index):
                    session_id = self._session_index[idx]
                else:
                    self._write_history(f"[bold red]Error:[/bold red] No session at index {arg}")
                    self._write_history("  Use [bold]/resume[/bold] to see available sessions")
                    return
            else:
                session_id = parts[1].strip()
                session_id = arg

            await self._cmd_resume(session_id)
        elif cmd == "/recover":
            # Recover from specific checkpoint
@@ -180,6 +256,16 @@ class ChatRepl(Vertical):
            session_id = parts[1].strip()
            checkpoint_id = parts[2].strip()
            await self._cmd_recover(session_id, checkpoint_id)
        elif cmd == "/attach":
            file_path = parts[1].strip() if len(parts) > 1 else None
            await self._cmd_attach(file_path)
        elif cmd == "/detach":
            if self._attached_pdf:
                name = self._attached_pdf["filename"]
                self._attached_pdf = None
                self._write_history(f"[dim]Detached: {name}[/dim]")
            else:
                self._write_history("[dim]No PDF attached.[/dim]")
        elif cmd == "/pause":
            await self._cmd_pause()
        else:
@@ -188,6 +274,63 @@ class ChatRepl(Vertical):
                "Type [bold]/help[/bold] for available commands"
            )

    def attach_pdf(self, path: Path) -> None:
        """Validate and stage a PDF file for the next message.

        Copies the PDF to ~/.hive/assets/ and stores the path. The agent's
        pdf_read tool handles text extraction at runtime.

        Called by /attach <path> or by the native file dialog.
        """
        path = Path(path).expanduser().resolve()

        if not path.exists():
            self._write_history(f"[bold red]Error:[/bold red] File not found: {path}")
            return
        if path.suffix.lower() != ".pdf":
            self._write_history("[bold red]Error:[/bold red] Only PDF files are supported")
            return

        # Copy to ~/.hive/assets/, deduplicating like a normal filesystem:
        # resume.pdf → resume(1).pdf → resume(2).pdf
        assets_dir = Path.home() / ".hive" / "assets"
        assets_dir.mkdir(parents=True, exist_ok=True)
        dest = assets_dir / path.name
        counter = 1
        while dest.exists():
            dest = assets_dir / f"{path.stem}({counter}){path.suffix}"
            counter += 1
        shutil.copy2(path, dest)

        self._attached_pdf = {
            "filename": path.name,
            "path": str(dest),
        }

        self._write_history(f"[green]Attached:[/green] {path.name}")
        self._write_history("[dim]PDF will be read by the agent on your next message.[/dim]")

    async def _cmd_attach(self, file_path: str | None = None) -> None:
        """Attach a PDF file for context injection into the next message."""
        if file_path is None:
            from framework.tui.widgets.file_browser import _has_gui, pick_pdf_file

            if not _has_gui():
                self._write_history(
                    "[bold yellow]No GUI available.[/bold yellow] "
                    "Provide a path: [bold]/attach /path/to/file.pdf[/bold]"
                )
                return

            self._write_history("[dim]Opening file dialog...[/dim]")
            path = await pick_pdf_file()

            if path is not None:
                self.attach_pdf(path)
            return

        self.attach_pdf(Path(file_path))

    async def _cmd_sessions(self, session_id: str | None) -> None:
        """List sessions or show details of a specific session."""
        try:
@@ -241,6 +384,15 @@ class ChatRepl(Vertical):
        except Exception:
            return None

    def _get_session_label(self, state: dict) -> str:
        """Extract the first user message from input_data as a human-readable label."""
        input_data = state.get("input_data", {})
        for value in input_data.values():
            if isinstance(value, str) and value.strip():
                label = value.strip()
                return label[:60] + "..." if len(label) > 60 else label
        return "(no input)"

    async def _list_sessions(self, storage_path: Path) -> None:
        """List all sessions for the agent."""
        self._write_history("[bold cyan]Available Sessions:[/bold cyan]")
@@ -264,6 +416,11 @@ class ChatRepl(Vertical):

        self._write_history(f"[dim]Found {len(session_dirs)} session(s)[/dim]\n")

        # Reset the session index for numeric lookups
        self._session_index = []

        import json

        for session_dir in session_dirs[:10]:  # Show last 10 sessions
            session_id = session_dir.name
            state_file = session_dir / "state.json"
@@ -273,12 +430,15 @@ class ChatRepl(Vertical):

            # Read session state
            try:
                import json

                with open(state_file) as f:
                    state = json.load(f)

                # Track this session for /resume <number> lookup
                self._session_index.append(session_id)
                index = len(self._session_index)

                status = state.get("status", "unknown").upper()
                label = self._get_session_label(state)

                # Status with color
                if status == "COMPLETED":
@@ -292,25 +452,17 @@ class ChatRepl(Vertical):
                else:
                    status_colored = f"[dim]{status}[/dim]"

                # Check for checkpoints
                checkpoint_dir = session_dir / "checkpoints"
                checkpoint_count = 0
                if checkpoint_dir.exists():
                    checkpoint_files = list(checkpoint_dir.glob("cp_*.json"))
                    checkpoint_count = len(checkpoint_files)

                # Session line
                self._write_history(f"📋 [bold]{session_id}[/bold]")
                self._write_history(f"   Status: {status_colored}  Checkpoints: {checkpoint_count}")

                if checkpoint_count > 0:
                    self._write_history(f"   [dim]Resume: /resume {session_id}[/dim]")

                # Session line with index and label
                self._write_history(f"  [bold]{index}.[/bold] {label}  {status_colored}")
                self._write_history(f"     [dim]{session_id}[/dim]")
                self._write_history("")  # Blank line

            except Exception as e:
                self._write_history(f"   [dim red]Error reading: {e}[/dim red]")

        if self._session_index:
            self._write_history("[dim]Use [bold]/resume <number>[/bold] to resume a session[/dim]")

    async def _show_session_details(self, storage_path: Path, session_id: str) -> None:
        """Show detailed information about a specific session."""
        self._write_history(f"[bold cyan]Session Details:[/bold cyan] {session_id}\n")
@@ -428,6 +580,7 @@ class ChatRepl(Vertical):
        if paused_at:
            # Has paused_at - resume from there
            resume_session_state = {
                "resume_session_id": session_id,
                "paused_at": paused_at,
                "memory": state.get("memory", {}),
                "execution_path": progress.get("path", []),
@@ -435,8 +588,13 @@ class ChatRepl(Vertical):
            }
            resume_info = f"From node: [cyan]{paused_at}[/cyan]"
        else:
            # No paused_at - just retry with same input
            resume_session_state = {}
            # No paused_at - retry with same input but reuse session directory
            resume_session_state = {
                "resume_session_id": session_id,
                "memory": state.get("memory", {}),
                "execution_path": progress.get("path", []),
                "node_visit_counts": progress.get("node_visit_counts", {}),
            }
            resume_info = "Retrying with same input"

        # Display resume info
@@ -462,7 +620,7 @@ class ChatRepl(Vertical):
        indicator.display = True

        # Update placeholder
        chat_input = self.query_one("#chat-input", Input)
        chat_input = self.query_one("#chat-input", ChatTextArea)
        chat_input.placeholder = "Commands: /pause, /sessions (agent resuming...)"

        # Trigger execution with resume state
@@ -540,6 +698,7 @@ class ChatRepl(Vertical):

        # Create session_state for checkpoint recovery
        recover_session_state = {
            "resume_session_id": session_id,
            "resume_from_checkpoint": checkpoint_id,
        }

@@ -549,7 +708,7 @@ class ChatRepl(Vertical):
        indicator.display = True

        # Update placeholder
        chat_input = self.query_one("#chat-input", Input)
        chat_input = self.query_one("#chat-input", ChatTextArea)
        chat_input.placeholder = "Commands: /pause, /sessions (agent recovering...)"

        # Trigger execution with checkpoint recovery
@@ -638,10 +797,14 @@ class ChatRepl(Vertical):
        # Check for resumable sessions
        self._check_and_show_resumable_sessions()

        history.write(
            "[dim]Quick start: /sessions to see previous sessions, "
            "/pause to pause execution[/dim]\n"
        )
        # Show agent intro message if available
        if self.runtime.intro_message:
            history.write(f"[bold blue]Agent:[/bold blue] {self.runtime.intro_message}\n")
        else:
            history.write(
                "[dim]Quick start: /sessions to see previous sessions, "
                "/pause to pause execution[/dim]\n"
            )

    def _check_and_show_resumable_sessions(self) -> None:
        """Check for non-terminated sessions and prompt user."""
@@ -678,16 +841,20 @@ class ChatRepl(Vertical):
                        {
                            "session_id": session_dir.name,
                            "status": status.upper(),
                            "label": self._get_session_label(state),
                        }
                    )
                except Exception:
                    continue

            if resumable:
                self._write_history("\n[bold yellow]⚠ Non-terminated sessions found:[/bold yellow]")
                # Populate session index so /resume <number> works immediately
                self._session_index = [s["session_id"] for s in resumable[:3]]

                self._write_history("\n[bold yellow]Non-terminated sessions found:[/bold yellow]")
                for i, session in enumerate(resumable[:3], 1):  # Show top 3
                    status = session["status"]
                    session_id = session["session_id"]
                    label = session["label"]

                    # Color code status
                    if status == "PAUSED":
@@ -699,23 +866,21 @@ class ChatRepl(Vertical):
                    else:
                        status_colored = f"[dim]{status}[/dim]"

                    self._write_history(f"  {i}. {session_id[:32]}... [{status_colored}]")
                    self._write_history(f"  [bold]{i}.[/bold] {label}  {status_colored}")

                self._write_history("\n[bold cyan]What would you like to do?[/bold cyan]")
                self._write_history("  • Type [bold]/resume[/bold] to continue the latest session")
                self._write_history(
                    f"  • Type [bold]/resume {resumable[0]['session_id']}[/bold] "
                    "for specific session"
                )
                self._write_history("  • Or just type your input to start a new session\n")
                self._write_history("\n  Type [bold]/resume <number>[/bold] to continue a session")
                self._write_history("  Or just type your input to start a new session\n")

        except Exception:
            # Silently fail - don't block TUI startup
            pass

    async def on_input_submitted(self, message: Input.Submitted) -> None:
        """Handle input submission — either start new execution or inject input."""
        user_input = message.value.strip()
    async def on_chat_text_area_submitted(self, message: ChatTextArea.Submitted) -> None:
        """Handle chat input submission."""
        await self._submit_input(message.text)

    async def _submit_input(self, user_input: str) -> None:
        """Handle submitted text — either start new execution or inject input."""
        if not user_input:
            return

@@ -723,16 +888,14 @@ class ChatRepl(Vertical):
        # Commands work during execution, during client-facing input, anytime
        if user_input.startswith("/"):
            await self._handle_command(user_input)
            message.input.value = ""
            return

        # Client-facing input: route to the waiting node
        if self._waiting_for_input and self._input_node_id:
            self._write_history(f"[bold green]You:[/bold green] {user_input}")
            message.input.value = ""

            # Keep input enabled for commands (but change placeholder)
            chat_input = self.query_one("#chat-input", Input)
            chat_input = self.query_one("#chat-input", ChatTextArea)
            chat_input.placeholder = "Commands: /pause, /sessions (agent processing...)"
            self._waiting_for_input = False

@@ -752,16 +915,29 @@ class ChatRepl(Vertical):
                self._write_history(f"[bold red]Error delivering input:[/bold red] {e}")
            return

        # Double-submit guard: reject input while an execution is in-flight
        # Mid-execution input: inject into the active node's conversation
        if self._current_exec_id is not None and self._active_node_id:
            self._write_history(f"[bold green]You:[/bold green] {user_input}")
            node_id = self._active_node_id
            try:
                future = asyncio.run_coroutine_threadsafe(
                    self.runtime.inject_input(node_id, user_input),
                    self._agent_loop,
                )
                await asyncio.wrap_future(future)
            except Exception as e:
                self._write_history(f"[bold red]Error delivering input:[/bold red] {e}")
            return

        # Double-submit guard: no active node to inject into
        if self._current_exec_id is not None:
            self._write_history("[dim]Agent is still running — please wait.[/dim]")
            return

        indicator = self.query_one("#processing-indicator", Label)

        # Append user message and clear input
        # Append user message
        self._write_history(f"[bold green]You:[/bold green] {user_input}")
        message.input.value = ""

        try:
            # Get entry point
@@ -787,9 +963,16 @@ class ChatRepl(Vertical):
            indicator.display = True

            # Keep input enabled for commands during execution
            chat_input = self.query_one("#chat-input", Input)
            chat_input = self.query_one("#chat-input", ChatTextArea)
            chat_input.placeholder = "Commands available: /pause, /sessions, /help"

            # Build input data, injecting attached PDF file path if present
            input_data = {input_key: user_input}
            if self._attached_pdf:
                input_data["pdf_file_path"] = self._attached_pdf["path"]
                self._write_history(f"[dim]Including PDF: {self._attached_pdf['filename']}[/dim]")
                self._attached_pdf = None

            # Submit execution to the dedicated agent loop so blocking
            # runtime code (LLM, MCP tools) never touches Textual's loop.
            # trigger() returns immediately with an exec_id; the heavy
@@ -797,7 +980,7 @@ class ChatRepl(Vertical):
            future = asyncio.run_coroutine_threadsafe(
                self.runtime.trigger(
                    entry_point_id=entry_point.id,
                    input_data={input_key: user_input},
                    input_data=input_data,
                ),
                self._agent_loop,
            )
@@ -808,12 +991,32 @@ class ChatRepl(Vertical):
            indicator.display = False
            self._current_exec_id = None
            # Re-enable input on error
            chat_input = self.query_one("#chat-input", Input)
            chat_input = self.query_one("#chat-input", ChatTextArea)
            chat_input.disabled = False
            self._write_history(f"[bold red]Error:[/bold red] {e}")

    # -- Event handlers called by app.py _handle_event --

    def handle_node_started(self, node_id: str) -> None:
        """Reset streaming state and track active node when a new node begins.

        Flushes any stale ``_streaming_snapshot`` left over from the
        previous node and resets the processing indicator so the user
        sees a clean transition between graph nodes.
        """
        self._active_node_id = node_id
        if self._streaming_snapshot:
            self._write_history(f"[bold blue]Agent:[/bold blue] {self._streaming_snapshot}")
            self._streaming_snapshot = ""
        indicator = self.query_one("#processing-indicator", Label)
        indicator.update("Thinking...")

    def handle_loop_iteration(self, iteration: int) -> None:
        """Flush accumulated streaming text when a new loop iteration starts."""
        if self._streaming_snapshot:
            self._write_history(f"[bold blue]Agent:[/bold blue] {self._streaming_snapshot}")
            self._streaming_snapshot = ""

    def handle_text_delta(self, content: str, snapshot: str) -> None:
        """Handle a streaming text token from the LLM."""
        self._streaming_snapshot = snapshot
@@ -829,23 +1032,42 @@ class ChatRepl(Vertical):

    def handle_tool_started(self, tool_name: str, tool_input: dict[str, Any]) -> None:
        """Handle a tool call starting."""
        # Update indicator to show tool activity
        indicator = self.query_one("#processing-indicator", Label)

        if tool_name == "ask_user":
            # Stash the question for handle_input_requested() to display.
            # Suppress the generic "Tool: ask_user" line.
            self._pending_ask_question = tool_input.get("question", "")
            indicator.update("Preparing question...")
            return

        # Update indicator to show tool activity
        indicator.update(f"Using tool: {tool_name}...")

        # Write a discrete status line to history
        self._write_history(f"[dim]Tool: {tool_name}[/dim]")
        # Buffer and conditionally display tool status line
        line = f"[dim]Tool: {tool_name}[/dim]"
        self._log_buffer.append(line)
        if self._show_logs:
            self._write_history(line)

    def handle_tool_completed(self, tool_name: str, result: str, is_error: bool) -> None:
        """Handle a tool call completing."""
        if tool_name == "ask_user":
            # Suppress the synthetic "Waiting for user input..." result.
            # The actual question is displayed by handle_input_requested().
            return

        result_str = str(result)
        preview = result_str[:200] + "..." if len(result_str) > 200 else result_str
        preview = preview.replace("\n", " ")

        if is_error:
            self._write_history(f"[dim red]Tool {tool_name} error: {preview}[/dim red]")
            line = f"[dim red]Tool {tool_name} error: {preview}[/dim red]"
        else:
            self._write_history(f"[dim]Tool {tool_name} result: {preview}[/dim]")
            line = f"[dim]Tool {tool_name} result: {preview}[/dim]"
        self._log_buffer.append(line)
        if self._show_logs:
            self._write_history(line)

        # Restore thinking indicator
        indicator = self.query_one("#processing-indicator", Label)
@@ -868,9 +1090,12 @@ class ChatRepl(Vertical):
        self._streaming_snapshot = ""
        self._waiting_for_input = False
        self._input_node_id = None
        self._active_node_id = None
        self._pending_ask_question = ""
        self._log_buffer.clear()

        # Re-enable input
        chat_input = self.query_one("#chat-input", Input)
        chat_input = self.query_one("#chat-input", ChatTextArea)
        chat_input.disabled = False
        chat_input.placeholder = "Enter input for agent..."
        chat_input.focus()
@@ -886,10 +1111,13 @@ class ChatRepl(Vertical):
        self._current_exec_id = None
        self._streaming_snapshot = ""
        self._waiting_for_input = False
        self._pending_ask_question = ""
        self._input_node_id = None
        self._active_node_id = None
        self._log_buffer.clear()

        # Re-enable input
        chat_input = self.query_one("#chat-input", Input)
        chat_input = self.query_one("#chat-input", ChatTextArea)
        chat_input.disabled = False
        chat_input.placeholder = "Enter input for agent..."
        chat_input.focus()
@@ -902,17 +1130,54 @@ class ChatRepl(Vertical):
        and sets a flag so the next submission routes to inject_input().
        """
        # Flush accumulated streaming text as agent output
        if self._streaming_snapshot:
            self._write_history(f"[bold blue]Agent:[/bold blue] {self._streaming_snapshot}")
        flushed_snapshot = self._streaming_snapshot
        if flushed_snapshot:
            self._write_history(f"[bold blue]Agent:[/bold blue] {flushed_snapshot}")
        self._streaming_snapshot = ""

        # Display the ask_user question if stashed and not already
        # present in the streaming snapshot (avoids double-display).
        question = self._pending_ask_question
        self._pending_ask_question = ""
        if question and question not in flushed_snapshot:
            self._write_history(f"[bold blue]Agent:[/bold blue] {question}")

        self._waiting_for_input = True
        self._input_node_id = node_id or None

        indicator = self.query_one("#processing-indicator", Label)
        indicator.update("Waiting for your input...")

        chat_input = self.query_one("#chat-input", Input)
        chat_input = self.query_one("#chat-input", ChatTextArea)
        chat_input.disabled = False
        chat_input.placeholder = "Type your response..."
        chat_input.focus()

    def handle_node_completed(self, node_id: str) -> None:
        """Clear active node when it finishes."""
        if self._active_node_id == node_id:
            self._active_node_id = None

    def handle_internal_output(self, node_id: str, content: str) -> None:
        """Show output from non-client-facing nodes."""
        self._write_history(f"[dim cyan]⟨{node_id}⟩[/dim cyan] {content}")

    def handle_execution_paused(self, node_id: str, reason: str) -> None:
        """Show that execution has been paused."""
        msg = f"[bold yellow]⏸ Paused[/bold yellow] at [cyan]{node_id}[/cyan]"
        if reason:
            msg += f" [dim]({reason})[/dim]"
        self._write_history(msg)

    def handle_execution_resumed(self, node_id: str) -> None:
        """Show that execution has been resumed."""
        self._write_history(f"[bold green]▶ Resumed[/bold green] from [cyan]{node_id}[/cyan]")

    def handle_goal_achieved(self, data: dict[str, Any]) -> None:
        """Show goal achievement prominently."""
        self._write_history("[bold green]★ Goal achieved![/bold green]")

    def handle_constraint_violation(self, data: dict[str, Any]) -> None:
        """Show constraint violation as a warning."""
        desc = data.get("description", "Unknown constraint")
        self._write_history(f"[bold red]⚠ Constraint violation:[/bold red] {desc}")
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Native OS file dialog for PDF selection.
|
||||
|
||||
Launches the platform's native file picker (macOS: NSOpenPanel via osascript,
|
||||
Linux: zenity/kdialog, Windows: PowerShell OpenFileDialog) in a background
|
||||
thread so Textual's event loop stays responsive.
|
||||
|
||||
Falls back to None when no GUI is available (SSH, headless).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _has_gui() -> bool:
|
||||
"""Detect whether a GUI display is available."""
|
||||
if sys.platform == "darwin":
|
||||
# macOS: GUI is available unless running over SSH without display forwarding.
|
||||
return "SSH_CONNECTION" not in os.environ or "DISPLAY" in os.environ
|
||||
elif sys.platform == "win32":
|
||||
return True
|
||||
else:
|
||||
# Linux/BSD: Need X11 or Wayland.
|
||||
return bool(os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"))
|
||||
|
||||
|
||||
def _linux_file_dialog() -> subprocess.CompletedProcess | None:
|
||||
"""Try zenity, then kdialog, on Linux. Returns CompletedProcess or None."""
|
||||
# Try zenity (GTK)
|
||||
try:
|
||||
return subprocess.run(
|
||||
[
|
||||
"zenity",
|
||||
"--file-selection",
|
||||
"--title=Select a PDF file",
|
||||
"--file-filter=PDF files (*.pdf)|*.pdf",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
# Try kdialog (KDE)
|
||||
try:
|
||||
return subprocess.run(
|
||||
[
|
||||
"kdialog",
|
||||
"--getopenfilename",
|
||||
".",
|
||||
"PDF files (*.pdf)",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _pick_pdf_subprocess() -> Path | None:
|
||||
"""Run the native file dialog. BLOCKS until user picks or cancels.
|
||||
|
||||
Returns a Path on success, None on cancel or error.
|
||||
Must be called from a non-main thread (via asyncio.to_thread).
|
||||
"""
|
||||
try:
|
||||
if sys.platform == "darwin":
|
||||
result = subprocess.run(
|
||||
[
|
||||
"osascript",
|
||||
"-e",
|
||||
'POSIX path of (choose file of type {"com.adobe.pdf"} '
|
||||
'with prompt "Select a PDF file")',
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
elif sys.platform == "win32":
|
||||
ps_script = (
|
||||
"Add-Type -AssemblyName System.Windows.Forms; "
|
||||
"$f = New-Object System.Windows.Forms.OpenFileDialog; "
|
||||
"$f.Filter = 'PDF files (*.pdf)|*.pdf'; "
|
||||
"$f.Title = 'Select a PDF file'; "
|
||||
"if ($f.ShowDialog() -eq 'OK') { $f.FileName }"
|
||||
)
|
||||
result = subprocess.run(
|
||||
["powershell", "-NoProfile", "-Command", ps_script],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
else:
|
||||
result = _linux_file_dialog()
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
|
||||
path_str = result.stdout.strip()
|
||||
if not path_str:
|
||||
return None
|
||||
|
||||
path = Path(path_str)
|
||||
if path.is_file() and path.suffix.lower() == ".pdf":
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
async def pick_pdf_file() -> Path | None:
|
||||
"""Open a native OS file dialog to pick a PDF file.
|
||||
|
||||
Non-blocking: runs the dialog subprocess in a background thread via
|
||||
asyncio.to_thread(), so the calling event loop stays responsive.
|
||||
|
||||
Returns:
|
||||
Path to the selected PDF, or None if the user cancelled,
|
||||
no GUI is available, or the dialog command was not found.
|
||||
"""
|
||||
if not _has_gui():
|
||||
return None
|
||||
|
||||
return await asyncio.to_thread(_pick_pdf_subprocess)
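As a usage sketch, a Textual action can await this helper directly; the enclosing app and the `action_attach_pdf` name here are assumptions, not part of this diff:

    async def action_attach_pdf(self) -> None:
        # Dialog runs in a worker thread; None means cancel or a headless session.
        path = await pick_pdf_file()
        if path is not None:
            self.notify(f"Selected: {path.name}")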
@@ -1,7 +1,15 @@
"""
Graph/Tree Overview Widget - Displays real agent graph structure.

Supports rendering loops (back-edges) via right-side return channels:
arrows drawn on the right margin that visually point back up to earlier nodes.
"""

from __future__ import annotations

import re
import time

from textual.app import ComposeResult
from textual.containers import Vertical

@@ -9,6 +17,17 @@ from framework.runtime.agent_runtime import AgentRuntime
from framework.runtime.event_bus import EventType
from framework.tui.widgets.selectable_rich_log import SelectableRichLog as RichLog

# Width of each return-channel column (padding + │ + gap)
_CHANNEL_WIDTH = 5

# Regex to strip Rich markup tags for measuring visible width
_MARKUP_RE = re.compile(r"\[/?[^\]]*\]")


def _plain_len(s: str) -> int:
    """Return the visible character length of a Rich-markup string."""
    return len(_MARKUP_RE.sub("", s))
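For example, markup tags contribute no visible width: _plain_len("[bold cyan]Agent[/bold cyan]") returns 5, while plain len() would report 28.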


class GraphOverview(Vertical):
    """Widget to display Agent execution graph/tree with real data."""
@@ -46,6 +65,13 @@ class GraphOverview(Vertical):
    def on_mount(self) -> None:
        """Display initial graph structure."""
        self._display_graph()
        # Refresh every 1s so timer countdowns stay current
        if self.runtime._timer_next_fire is not None:
            self.set_interval(1.0, self._display_graph)

    # ------------------------------------------------------------------
    # Graph analysis helpers
    # ------------------------------------------------------------------

    def _topo_order(self) -> list[str]:
        """BFS from entry_node following edges."""
@@ -68,6 +94,39 @@ class GraphOverview(Vertical):
            visited.append(node.id)
        return visited

    def _detect_back_edges(self, ordered: list[str]) -> list[dict]:
        """Find edges where target appears before (or equal to) source in topo order.

        Returns a list of dicts with keys: edge, source, target, source_idx, target_idx.
        """
        order_idx = {nid: i for i, nid in enumerate(ordered)}
        back_edges: list[dict] = []
        for node_id in ordered:
            for edge in self.runtime.graph.get_outgoing_edges(node_id):
                target_idx = order_idx.get(edge.target, -1)
                source_idx = order_idx.get(node_id, -1)
                if target_idx != -1 and target_idx <= source_idx:
                    back_edges.append(
                        {
                            "edge": edge,
                            "source": node_id,
                            "target": edge.target,
                            "source_idx": source_idx,
                            "target_idx": target_idx,
                        }
                    )
        return back_edges

    def _is_back_edge(self, source: str, target: str, order_idx: dict[str, int]) -> bool:
        """Check whether an edge from *source* to *target* is a back-edge."""
        si = order_idx.get(source, -1)
        ti = order_idx.get(target, -1)
        return ti != -1 and ti <= si
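As a worked example (node ids hypothetical): with BFS order plan → act → review, order_idx is {"plan": 0, "act": 1, "review": 2}, so an edge review → plan has ti=0 <= si=2 and is classified as a back-edge, while plan → act (ti=1 > si=0) stays a forward edge; a self-loop act → act also qualifies because the two indices are equal.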

    # ------------------------------------------------------------------
    # Line rendering (Pass 1)
    # ------------------------------------------------------------------

    def _render_node_line(self, node_id: str) -> str:
        """Render a single node with status symbol and optional status text."""
        graph = self.runtime.graph
@@ -95,43 +154,349 @@ class GraphOverview(Vertical):
        suffix = f" [italic]{status}[/italic]" if status else ""
        return f" {sym} {name}{suffix}"

    def _render_edges(self, node_id: str) -> list[str]:
        """Render edge connectors from this node to its targets."""
        edges = self.runtime.graph.get_outgoing_edges(node_id)
        if not edges:
    def _render_edges(self, node_id: str, order_idx: dict[str, int]) -> list[str]:
        """Render forward-edge connectors from *node_id*.

        Back-edges are excluded here — they are drawn by the return-channel
        overlay in Pass 2.
        """
        all_edges = self.runtime.graph.get_outgoing_edges(node_id)
        if not all_edges:
            return []
        if len(edges) == 1:

        # Split into forward and back
        forward = [e for e in all_edges if not self._is_back_edge(node_id, e.target, order_idx)]

        if not forward:
            # All edges are back-edges — nothing to render here
            return []

        if len(forward) == 1:
            return [" │", " ▼"]

        # Fan-out: show branches
        lines: list[str] = []
        for i, edge in enumerate(edges):
            connector = "└" if i == len(edges) - 1 else "├"
        for i, edge in enumerate(forward):
            connector = "└" if i == len(forward) - 1 else "├"
            cond = ""
            if edge.condition.value not in ("always", "on_success"):
                cond = f" [dim]({edge.condition.value})[/dim]"
            lines.append(f" {connector}──▶ {edge.target}{cond}")
        return lines
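For a node with two forward edges this yields connector lines roughly like the following (node names hypothetical, markup omitted):

 ├──▶ summarize
 └──▶ escalate (on_failure)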

    # ------------------------------------------------------------------
    # Return-channel overlay (Pass 2)
    # ------------------------------------------------------------------

    def _overlay_return_channels(
        self,
        lines: list[str],
        node_line_map: dict[str, int],
        back_edges: list[dict],
        available_width: int,
    ) -> list[str]:
        """Overlay right-side return channels onto the line buffer.

        Each back-edge gets a vertical channel on the right margin. Channels
        are allocated left-to-right by increasing span length so that shorter
        (inner) loops are closer to the graph body and longer (outer) loops are
        further right.

        If the terminal is too narrow to fit even one channel, we fall back to
        simple inline ``↺`` annotations instead.
        """
        if not back_edges:
            return lines

        num_channels = len(back_edges)

        # Sort by span length ascending → inner loops get nearest channel
        sorted_be = sorted(back_edges, key=lambda b: b["source_idx"] - b["target_idx"])

        # --- Insert dedicated connector lines for back-edge sources ---
        # Each back-edge source gets a blank line inserted after its node
        # section (after any forward-edge lines). We process insertions in
        # reverse order so that earlier indices remain valid.
        all_node_lines_set = set(node_line_map.values())

        insertions: list[tuple[int, int]] = []  # (insert_after_line, be_index)
        for be_idx, be in enumerate(sorted_be):
            source_node_line = node_line_map.get(be["source"])
            if source_node_line is None:
                continue
            # Walk forward to find the last line in this node's section
            last_section_line = source_node_line
            for li in range(source_node_line + 1, len(lines)):
                if li in all_node_lines_set:
                    break
                last_section_line = li
            insertions.append((last_section_line, be_idx))

        source_line_for_be: dict[int, int] = {}
        for insert_after, be_idx in sorted(insertions, reverse=True):
            insert_at = insert_after + 1
            lines.insert(insert_at, "")  # placeholder for connector
            source_line_for_be[be_idx] = insert_at
            # Shift node_line_map entries that come after the insertion point
            for nid in node_line_map:
                if node_line_map[nid] > insert_after:
                    node_line_map[nid] += 1
            # Also shift already-assigned source lines
            for prev_idx in source_line_for_be:
                if prev_idx != be_idx and source_line_for_be[prev_idx] > insert_after:
                    source_line_for_be[prev_idx] += 1

        # Recompute max content width after insertions
        max_content_w = max(_plain_len(ln) for ln in lines) if lines else 0

        # Check if we have room for channels
        channels_total_w = num_channels * _CHANNEL_WIDTH
        if max_content_w + channels_total_w + 2 > available_width:
            return self._inline_back_edge_fallback(lines, node_line_map, back_edges)

        content_pad = max_content_w + 3  # gap between content and first channel

        # Build channel info with final line positions
        channel_info: list[dict] = []
        for ch_idx, be in enumerate(sorted_be):
            target_line = node_line_map.get(be["target"])
            source_line = source_line_for_be.get(ch_idx)
            if target_line is None or source_line is None:
                continue
            col = content_pad + ch_idx * _CHANNEL_WIDTH
            channel_info.append(
                {
                    "target_line": target_line,
                    "source_line": source_line,
                    "col": col,
                }
            )

        if not channel_info:
            return lines

        # Build overlay grid — one row per line, columns for channel area
        total_width = content_pad + num_channels * _CHANNEL_WIDTH + 1
        overlay_width = total_width - max_content_w
        overlays: list[list[str]] = [[" "] * overlay_width for _ in range(len(lines))]

        for ci in channel_info:
            tl = ci["target_line"]
            sl = ci["source_line"]
            col_offset = ci["col"] - max_content_w

            if col_offset < 0 or col_offset >= overlay_width:
                continue

            # Target line: ◄──...──┐
            if 0 <= tl < len(overlays):
                for c in range(col_offset):
                    if overlays[tl][c] == " ":
                        overlays[tl][c] = "─"
                overlays[tl][col_offset] = "┐"

            # Source line: ──...──┘
            if 0 <= sl < len(overlays):
                for c in range(col_offset):
                    if overlays[sl][c] == " ":
                        overlays[sl][c] = "─"
                overlays[sl][col_offset] = "┘"

            # Vertical lines between target+1 and source-1
            for li in range(tl + 1, sl):
                if 0 <= li < len(overlays) and overlays[li][col_offset] == " ":
                    overlays[li][col_offset] = "│"

        # Merge overlays into the line strings
        result: list[str] = []
        for i, line in enumerate(lines):
            pw = _plain_len(line)
            pad = max_content_w - pw
            overlay_chars = overlays[i] if i < len(overlays) else []
            overlay_str = "".join(overlay_chars)
            overlay_trimmed = overlay_str.rstrip()
            if overlay_trimmed:
                is_target_line = any(ci["target_line"] == i for ci in channel_info)
                if is_target_line:
                    overlay_trimmed = "◄" + overlay_trimmed[1:]

                is_source_line = any(ci["source_line"] == i for ci in channel_info)
                if is_source_line and not line.strip():
                    # Inserted blank line → build └───┘ connector.
                    # "  └" = 3 chars of content prefix, so remaining pad = max_content_w - 3
                    remaining_pad = max_content_w - 3
                    full = list(" " * remaining_pad + overlay_trimmed)
                    # Find the ┘ corner for this source connector
                    corner_pos = -1
                    for ci_s in channel_info:
                        if ci_s["source_line"] == i:
                            corner_pos = remaining_pad + (ci_s["col"] - max_content_w)
                            break
                    # Fill everything up to the corner with ─
                    if corner_pos >= 0:
                        for c in range(corner_pos):
                            if full[c] not in ("│", "┘", "┐"):
                                full[c] = "─"
                    connector = "  └" + "".join(full).rstrip()
                    result.append(f"[dim]{connector}[/dim]")
                    continue

                colored_overlay = f"[dim]{' ' * pad}{overlay_trimmed}[/dim]"
                result.append(f"{line}{colored_overlay}")
            else:
                result.append(line)

        return result
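Put together, a single back-edge from review back to plan renders roughly like this (node names and symbols hypothetical, markup omitted):

 ● plan       ◄──────┐
 │                   │
 ▼                   │
 ● review            │
  └──────────────────┘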

    def _inline_back_edge_fallback(
        self,
        lines: list[str],
        node_line_map: dict[str, int],
        back_edges: list[dict],
    ) -> list[str]:
        """Fallback: add inline ↺ annotations when terminal is too narrow for channels."""
        # Group back-edges by source node
        source_to_be: dict[str, list[dict]] = {}
        for be in back_edges:
            source_to_be.setdefault(be["source"], []).append(be)

        result = list(lines)
        # Insert annotation lines after each source node's section
        offset = 0
        all_node_lines = sorted(node_line_map.values())
        for source, bes in source_to_be.items():
            source_line = node_line_map.get(source)
            if source_line is None:
                continue
            # Find end of source node section
            end_line = source_line
            for nl in all_node_lines:
                if nl > source_line:
                    end_line = nl - 1
                    break
            else:
                end_line = len(lines) - 1
            # Insert after last content line of this node's section
            insert_at = end_line + offset + 1
            for be in bes:
                cond = ""
                edge = be["edge"]
                if edge.condition.value not in ("always", "on_success"):
                    cond = f" [dim]({edge.condition.value})[/dim]"
                annotation = f" [yellow]↺[/yellow] {be['target']}{cond}"
                result.insert(insert_at, annotation)
                insert_at += 1
                offset += 1

        return result
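In this narrow-terminal fallback the same loop collapses to an inline annotation under its source node, e.g. (markup omitted):

 ● review
  ↺ plan (on_failure)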

    # ------------------------------------------------------------------
    # Main display
    # ------------------------------------------------------------------

    def _display_graph(self) -> None:
        """Display the graph as an ASCII DAG with edge connectors."""
        """Display the graph as an ASCII DAG with edge connectors and loop channels."""
        display = self.query_one("#graph-display", RichLog)
        display.clear()

        graph = self.runtime.graph
        display.write(f"[bold cyan]Agent Graph:[/bold cyan] {graph.id}\n")

        # Render each node in topological order with edges
        ordered = self._topo_order()
        order_idx = {nid: i for i, nid in enumerate(ordered)}

        # --- Pass 1: Build line buffer ---
        lines: list[str] = []
        node_line_map: dict[str, int] = {}

        for node_id in ordered:
            display.write(self._render_node_line(node_id))
            for edge_line in self._render_edges(node_id):
                display.write(edge_line)
            node_line_map[node_id] = len(lines)
            lines.append(self._render_node_line(node_id))
            for edge_line in self._render_edges(node_id, order_idx):
                lines.append(edge_line)

        # --- Pass 2: Overlay return channels for back-edges ---
        back_edges = self._detect_back_edges(ordered)
        if back_edges:
            # Try to get actual widget width; default to a reasonable value
            try:
                available_width = self.size.width or 60
            except Exception:
                available_width = 60
            lines = self._overlay_return_channels(lines, node_line_map, back_edges, available_width)

        # Write all lines
        for line in lines:
            display.write(line)

        # Execution path footer
        if self.execution_path:
            display.write("")
            display.write(f"[dim]Path:[/dim] {' → '.join(self.execution_path[-5:])}")

        # Event sources section
        self._render_event_sources(display)

    # ------------------------------------------------------------------
    # Event sources display
    # ------------------------------------------------------------------

    def _render_event_sources(self, display: RichLog) -> None:
        """Render event source info (webhooks, timers) below the graph."""
        entry_points = self.runtime.get_entry_points()

        # Filter to non-manual entry points (webhooks, timers, events)
        event_sources = [ep for ep in entry_points if ep.trigger_type not in ("manual",)]
        if not event_sources:
            return

        display.write("")
        display.write("[bold cyan]Event Sources[/bold cyan]")

        config = self.runtime._config

        for ep in event_sources:
            if ep.trigger_type == "timer":
                interval = ep.trigger_config.get("interval_minutes", "?")
                display.write(f" [green]⏱[/green] {ep.name} [dim]→ {ep.entry_node}[/dim]")
                # Show interval + next fire countdown
                next_fire = self.runtime._timer_next_fire.get(ep.id)
                if next_fire is not None:
                    remaining = max(0, next_fire - time.monotonic())
                    mins, secs = divmod(int(remaining), 60)
                    display.write(
                        f" [dim]every {interval} min — next in {mins}m {secs:02d}s[/dim]"
                    )
                else:
                    display.write(f" [dim]every {interval} min[/dim]")

            elif ep.trigger_type in ("event", "webhook"):
                display.write(f" [yellow]⚡[/yellow] {ep.name} [dim]→ {ep.entry_node}[/dim]")
                # Show webhook endpoint if configured
                route = None
                for r in config.webhook_routes:
                    src = r.get("source_id", "")
                    if src and src in ep.id:
                        route = r
                        break
                if not route and config.webhook_routes:
                    # Fall back to first route
                    route = config.webhook_routes[0]

                if route:
                    host = config.webhook_host
                    port = config.webhook_port
                    path = route.get("path", "/webhook")
                    display.write(f" [dim]{host}:{port}{path}[/dim]")
                else:
                    event_types = ep.trigger_config.get("event_types", [])
                    if event_types:
                        display.write(f" [dim]events: {', '.join(event_types)}[/dim]")

    # ------------------------------------------------------------------
    # Public API (called by app.py)
    # ------------------------------------------------------------------

    def update_active_node(self, node_id: str) -> None:
        """Update the currently active node."""
        self.active_node = node_id
@@ -177,6 +542,8 @@ class GraphOverview(Vertical):
    def handle_node_loop_completed(self, node_id: str) -> None:
        """A node's event loop completed."""
        self._node_status.pop(node_id, None)
        if self.active_node == node_id:
            self.active_node = None
        self._display_graph()

    def handle_tool_call(self, node_id: str, tool_name: str, *, started: bool) -> None:
@@ -192,3 +559,8 @@ class GraphOverview(Vertical):
        """Highlight a stalled node."""
        self._node_status[node_id] = f"[red]stalled: {reason}[/red]"
        self._display_graph()

    def handle_edge_traversed(self, source_node: str, target_node: str) -> None:
        """Highlight an edge being traversed."""
        self._node_status[source_node] = f"[dim]→ {target_node}[/dim]"
        self._display_graph()

@@ -1,5 +1,8 @@
"""
Log Pane Widget - Uses RichLog for reliable rendering.
Log formatting utilities and LogPane widget.

The module-level functions (format_event, extract_event_text, format_python_log)
can be used by any widget that needs to render log lines without instantiating LogPane.
"""

import logging
@@ -11,36 +14,108 @@ from textual.containers import Container
from framework.runtime.event_bus import AgentEvent, EventType
from framework.tui.widgets.selectable_rich_log import SelectableRichLog as RichLog

# --- Module-level formatting constants ---

EVENT_FORMAT: dict[EventType, tuple[str, str]] = {
    EventType.EXECUTION_STARTED: (">>", "bold cyan"),
    EventType.EXECUTION_COMPLETED: ("<<", "bold green"),
    EventType.EXECUTION_FAILED: ("!!", "bold red"),
    EventType.TOOL_CALL_STARTED: ("->", "yellow"),
    EventType.TOOL_CALL_COMPLETED: ("<-", "green"),
    EventType.NODE_LOOP_STARTED: ("@@", "cyan"),
    EventType.NODE_LOOP_ITERATION: ("..", "dim"),
    EventType.NODE_LOOP_COMPLETED: ("@@", "dim"),
    EventType.NODE_STALLED: ("!!", "bold yellow"),
    EventType.NODE_INPUT_BLOCKED: ("!!", "yellow"),
    EventType.GOAL_PROGRESS: ("%%", "blue"),
    EventType.GOAL_ACHIEVED: ("**", "bold green"),
    EventType.CONSTRAINT_VIOLATION: ("!!", "bold red"),
    EventType.STATE_CHANGED: ("~~", "dim"),
    EventType.CLIENT_INPUT_REQUESTED: ("??", "magenta"),
}

LOG_LEVEL_COLORS: dict[int, str] = {
    logging.DEBUG: "dim",
    logging.INFO: "",
    logging.WARNING: "yellow",
    logging.ERROR: "red",
    logging.CRITICAL: "bold red",
}


# --- Module-level formatting functions ---


def extract_event_text(event: AgentEvent) -> str:
    """Extract human-readable text from an event's data dict."""
    et = event.type
    data = event.data

    if et == EventType.EXECUTION_STARTED:
        return "Execution started"
    elif et == EventType.EXECUTION_COMPLETED:
        return "Execution completed"
    elif et == EventType.EXECUTION_FAILED:
        return f"Execution FAILED: {data.get('error', 'unknown')}"
    elif et == EventType.TOOL_CALL_STARTED:
        return f"Tool call: {data.get('tool_name', 'unknown')}"
    elif et == EventType.TOOL_CALL_COMPLETED:
        name = data.get("tool_name", "unknown")
        if data.get("is_error"):
            preview = str(data.get("result", ""))[:80]
            return f"Tool error: {name} - {preview}"
        return f"Tool done: {name}"
    elif et == EventType.NODE_LOOP_STARTED:
        return f"Node started: {event.node_id or 'unknown'}"
    elif et == EventType.NODE_LOOP_ITERATION:
        return f"{event.node_id or 'unknown'} iteration {data.get('iteration', '?')}"
    elif et == EventType.NODE_LOOP_COMPLETED:
        return f"Node done: {event.node_id or 'unknown'}"
    elif et == EventType.NODE_STALLED:
        reason = data.get("reason", "")
        node = event.node_id or "unknown"
        return f"Node stalled: {node} - {reason}" if reason else f"Node stalled: {node}"
    elif et == EventType.NODE_INPUT_BLOCKED:
        return f"Node input blocked: {event.node_id or 'unknown'}"
    elif et == EventType.GOAL_PROGRESS:
        return f"Goal progress: {data.get('progress', '?')}"
    elif et == EventType.GOAL_ACHIEVED:
        return "Goal achieved"
    elif et == EventType.CONSTRAINT_VIOLATION:
        return f"Constraint violated: {data.get('description', 'unknown')}"
    elif et == EventType.STATE_CHANGED:
        return f"State changed: {data.get('key', 'unknown')}"
    elif et == EventType.CLIENT_INPUT_REQUESTED:
        return "Waiting for user input"
    else:
        return f"{et.value}: {data}"


def format_event(event: AgentEvent) -> str:
    """Format an AgentEvent as a Rich markup string with timestamp + symbol."""
    ts = event.timestamp.strftime("%H:%M:%S")
    symbol, color = EVENT_FORMAT.get(event.type, ("--", "dim"))
    text = extract_event_text(event)
    return f"[dim]{ts}[/dim] [{color}]{symbol} {text}[/{color}]"


def format_python_log(record: logging.LogRecord) -> str:
    """Format a Python log record as a Rich markup string with timestamp and severity color."""
    ts = datetime.fromtimestamp(record.created).strftime("%H:%M:%S")
    color = LOG_LEVEL_COLORS.get(record.levelno, "")
    msg = record.getMessage()
    if color:
        return f"[dim]{ts}[/dim] [{color}]{record.levelname}[/{color}] {msg}"
    else:
        return f"[dim]{ts}[/dim] {record.levelname} {msg}"


# --- LogPane widget (kept for backward compatibility) ---


class LogPane(Container):
    """Widget to display logs with reliable rendering."""

    _EVENT_FORMAT: dict[EventType, tuple[str, str]] = {
        EventType.EXECUTION_STARTED: (">>", "bold cyan"),
        EventType.EXECUTION_COMPLETED: ("<<", "bold green"),
        EventType.EXECUTION_FAILED: ("!!", "bold red"),
        EventType.TOOL_CALL_STARTED: ("->", "yellow"),
        EventType.TOOL_CALL_COMPLETED: ("<-", "green"),
        EventType.NODE_LOOP_STARTED: ("@@", "cyan"),
        EventType.NODE_LOOP_ITERATION: ("..", "dim"),
        EventType.NODE_LOOP_COMPLETED: ("@@", "dim"),
        EventType.NODE_STALLED: ("!!", "bold yellow"),
        EventType.NODE_INPUT_BLOCKED: ("!!", "yellow"),
        EventType.GOAL_PROGRESS: ("%%", "blue"),
        EventType.GOAL_ACHIEVED: ("**", "bold green"),
        EventType.CONSTRAINT_VIOLATION: ("!!", "bold red"),
        EventType.STATE_CHANGED: ("~~", "dim"),
        EventType.CLIENT_INPUT_REQUESTED: ("??", "magenta"),
    }

    _LOG_LEVEL_COLORS = {
        logging.DEBUG: "dim",
        logging.INFO: "",
        logging.WARNING: "yellow",
        logging.ERROR: "red",
        logging.CRITICAL: "bold red",
    }

    DEFAULT_CSS = """
    LogPane {
        width: 100%;
@@ -58,84 +133,27 @@ class LogPane(Container):
    """

    def compose(self) -> ComposeResult:
        # RichLog is designed for log display and doesn't have TextArea's rendering issues
        yield RichLog(id="main-log", highlight=True, markup=True, auto_scroll=False)

    def write_event(self, event: AgentEvent) -> None:
        """Format an AgentEvent with timestamp + symbol and write to the log."""
        ts = event.timestamp.strftime("%H:%M:%S")
        symbol, color = self._EVENT_FORMAT.get(event.type, ("--", "dim"))
        text = self._extract_event_text(event)
        self.write_log(f"[dim]{ts}[/dim] [{color}]{symbol} {text}[/{color}]")

    def _extract_event_text(self, event: AgentEvent) -> str:
        """Extract human-readable text from an event's data dict."""
        et = event.type
        data = event.data

        if et == EventType.EXECUTION_STARTED:
            return "Execution started"
        elif et == EventType.EXECUTION_COMPLETED:
            return "Execution completed"
        elif et == EventType.EXECUTION_FAILED:
            return f"Execution FAILED: {data.get('error', 'unknown')}"
        elif et == EventType.TOOL_CALL_STARTED:
            return f"Tool call: {data.get('tool_name', 'unknown')}"
        elif et == EventType.TOOL_CALL_COMPLETED:
            name = data.get("tool_name", "unknown")
            if data.get("is_error"):
                preview = str(data.get("result", ""))[:80]
                return f"Tool error: {name} - {preview}"
            return f"Tool done: {name}"
        elif et == EventType.NODE_LOOP_STARTED:
            return f"Node started: {event.node_id or 'unknown'}"
        elif et == EventType.NODE_LOOP_ITERATION:
            return f"{event.node_id or 'unknown'} iteration {data.get('iteration', '?')}"
        elif et == EventType.NODE_LOOP_COMPLETED:
            return f"Node done: {event.node_id or 'unknown'}"
        elif et == EventType.NODE_STALLED:
            reason = data.get("reason", "")
            node = event.node_id or "unknown"
            return f"Node stalled: {node} - {reason}" if reason else f"Node stalled: {node}"
        elif et == EventType.NODE_INPUT_BLOCKED:
            return f"Node input blocked: {event.node_id or 'unknown'}"
        elif et == EventType.GOAL_PROGRESS:
            return f"Goal progress: {data.get('progress', '?')}"
        elif et == EventType.GOAL_ACHIEVED:
            return "Goal achieved"
        elif et == EventType.CONSTRAINT_VIOLATION:
            return f"Constraint violated: {data.get('description', 'unknown')}"
        elif et == EventType.STATE_CHANGED:
            return f"State changed: {data.get('key', 'unknown')}"
        elif et == EventType.CLIENT_INPUT_REQUESTED:
            return "Waiting for user input"
        else:
            return f"{et.value}: {data}"
        self.write_log(format_event(event))

    def write_python_log(self, record: logging.LogRecord) -> None:
        """Format a Python log record with timestamp and severity color."""
        ts = datetime.fromtimestamp(record.created).strftime("%H:%M:%S")
        color = self._LOG_LEVEL_COLORS.get(record.levelno, "")
        msg = record.getMessage()
        if color:
            self.write_log(f"[dim]{ts}[/dim] [{color}]{record.levelname}[/{color}] {msg}")
        else:
            self.write_log(f"[dim]{ts}[/dim] {record.levelname} {msg}")
        self.write_log(format_python_log(record))

    def write_log(self, message: str) -> None:
        """Write a log message to the log pane."""
        try:
            # Check if widget is mounted
            if not self.is_mounted:
                return

            log = self.query_one("#main-log", RichLog)

            # Check if log is mounted
            if not log.is_mounted:
                return

            # Only auto-scroll if user is already at the bottom
            was_at_bottom = log.is_vertical_scroll_end

            log.write(message)

@@ -195,12 +195,27 @@ def _copy_to_clipboard(text: str) -> None:
    try:
        if sys.platform == "darwin":
            subprocess.run(["pbcopy"], input=text.encode(), check=True, timeout=5)
        elif sys.platform.startswith("linux"):
        elif sys.platform == "win32":
            subprocess.run(
                ["xclip", "-selection", "clipboard"],
                input=text.encode(),
                ["clip.exe"],
                input=text.encode("utf-16le"),
                check=True,
                timeout=5,
            )
        elif sys.platform.startswith("linux"):
            try:
                subprocess.run(
                    ["xclip", "-selection", "clipboard"],
                    input=text.encode(),
                    check=True,
                    timeout=5,
                )
            except (subprocess.SubprocessError, FileNotFoundError):
                subprocess.run(
                    ["xsel", "--clipboard", "--input"],
                    input=text.encode(),
                    check=True,
                    timeout=5,
                )
    except (subprocess.SubprocessError, FileNotFoundError):
        pass

@@ -0,0 +1,5 @@
"""Utility functions for the Hive framework."""

from framework.utils.io import atomic_write

__all__ = ["atomic_write"]

@@ -20,6 +20,7 @@ dependencies = [

[project.optional-dependencies]
tui = ["textual>=0.75.0"]
webhook = ["aiohttp>=3.9.0"]

[project.scripts]
hive = "framework.cli:main"

@@ -1,342 +0,0 @@
"""Tests for the BuilderQuery interface - how Builder analyzes agent runs.

DEPRECATED: These tests rely on the deprecated FileStorage backend.
BuilderQuery and Runtime both use FileStorage which is deprecated.
New code should use unified session storage instead.
"""

from pathlib import Path

import pytest

from framework import BuilderQuery, Runtime
from framework.schemas.run import RunStatus

# Mark all tests in this module as skipped - they rely on deprecated FileStorage
pytestmark = pytest.mark.skip(reason="Tests rely on deprecated FileStorage backend")


def create_successful_run(runtime: Runtime, goal_id: str = "test_goal") -> str:
    """Helper to create a successful run with decisions."""
    run_id = runtime.start_run(goal_id, f"Test goal: {goal_id}")

    runtime.set_node("search-node")
    d1 = runtime.decide(
        intent="Search for data",
        options=[
            {"id": "web", "description": "Web search", "pros": ["Fresh data"]},
            {"id": "cache", "description": "Use cache", "pros": ["Fast"]},
        ],
        chosen="web",
        reasoning="Need fresh data",
    )
    runtime.record_outcome(d1, success=True, result={"items": 3}, tokens_used=50)

    runtime.set_node("process-node")
    d2 = runtime.decide(
        intent="Process results",
        options=[{"id": "filter", "description": "Filter and transform"}],
        chosen="filter",
        reasoning="Standard processing",
    )
    runtime.record_outcome(d2, success=True, result={"processed": 3}, tokens_used=30)

    runtime.end_run(success=True, narrative="Successfully processed data")
    return run_id


def create_failed_run(runtime: Runtime, goal_id: str = "test_goal") -> str:
    """Helper to create a failed run."""
    run_id = runtime.start_run(goal_id, f"Test goal: {goal_id}")

    runtime.set_node("search-node")
    d1 = runtime.decide(
        intent="Search for data",
        options=[{"id": "web", "description": "Web search"}],
        chosen="web",
        reasoning="Need data",
    )
    runtime.record_outcome(d1, success=True, result={"items": 0})

    runtime.set_node("process-node")
    d2 = runtime.decide(
        intent="Process results",
        options=[{"id": "process", "description": "Process data"}],
        chosen="process",
        reasoning="Continue pipeline",
    )
    runtime.record_outcome(d2, success=False, error="No data to process")

    runtime.report_problem(
        severity="critical",
        description="Processing failed due to empty input",
        decision_id=d2,
        suggested_fix="Add empty input handling",
    )

    runtime.end_run(success=False, narrative="Failed to process - no data")
    return run_id


class TestBuilderQueryBasics:
    """Test basic query operations."""

    def test_get_run_summary(self, tmp_path: Path):
        """Test getting a run summary."""
        runtime = Runtime(tmp_path)
        run_id = create_successful_run(runtime)

        query = BuilderQuery(tmp_path)
        summary = query.get_run_summary(run_id)

        assert summary is not None
        assert summary.run_id == run_id
        assert summary.status == RunStatus.COMPLETED
        assert summary.decision_count == 2
        assert summary.success_rate == 1.0

    def test_get_full_run(self, tmp_path: Path):
        """Test getting the full run details."""
        runtime = Runtime(tmp_path)
        run_id = create_successful_run(runtime)

        query = BuilderQuery(tmp_path)
        run = query.get_full_run(run_id)

        assert run is not None
        assert run.id == run_id
        assert len(run.decisions) == 2
        assert run.decisions[0].node_id == "search-node"
        assert run.decisions[1].node_id == "process-node"

    def test_list_runs_for_goal(self, tmp_path: Path):
        """Test listing all runs for a goal."""
        runtime = Runtime(tmp_path)
        create_successful_run(runtime, "goal_a")
        create_successful_run(runtime, "goal_a")
        create_successful_run(runtime, "goal_b")

        query = BuilderQuery(tmp_path)
        summaries = query.list_runs_for_goal("goal_a")

        assert len(summaries) == 2
        for s in summaries:
            assert s.goal_id == "goal_a"

    def test_get_recent_failures(self, tmp_path: Path):
        """Test getting recent failed runs."""
        runtime = Runtime(tmp_path)
        create_successful_run(runtime)
        create_failed_run(runtime)
        create_failed_run(runtime)

        query = BuilderQuery(tmp_path)
        failures = query.get_recent_failures()

        assert len(failures) == 2
        for f in failures:
            assert f.status == RunStatus.FAILED


class TestFailureAnalysis:
    """Test failure analysis capabilities."""

    def test_analyze_failure(self, tmp_path: Path):
        """Test analyzing why a run failed."""
        runtime = Runtime(tmp_path)
        run_id = create_failed_run(runtime)

        query = BuilderQuery(tmp_path)
        analysis = query.analyze_failure(run_id)

        assert analysis is not None
        assert analysis.run_id == run_id
        assert "No data to process" in analysis.root_cause
        assert len(analysis.decision_chain) >= 2
        assert len(analysis.problems) == 1
        assert "critical" in analysis.problems[0].lower()

    def test_analyze_failure_returns_none_for_success(self, tmp_path: Path):
        """analyze_failure returns None for successful runs."""
        runtime = Runtime(tmp_path)
        run_id = create_successful_run(runtime)

        query = BuilderQuery(tmp_path)
        analysis = query.analyze_failure(run_id)

        assert analysis is None

    def test_failure_analysis_has_suggestions(self, tmp_path: Path):
        """Failure analysis should include suggestions."""
        runtime = Runtime(tmp_path)
        run_id = create_failed_run(runtime)

        query = BuilderQuery(tmp_path)
        analysis = query.analyze_failure(run_id)

        assert len(analysis.suggestions) > 0
        # Should include the suggested fix from the problem
        assert any("empty input" in s.lower() for s in analysis.suggestions)

    def test_get_decision_trace(self, tmp_path: Path):
        """Test getting a readable decision trace."""
        runtime = Runtime(tmp_path)
        run_id = create_successful_run(runtime)

        query = BuilderQuery(tmp_path)
        trace = query.get_decision_trace(run_id)

        assert len(trace) == 2
        assert "search-node" in trace[0]
        assert "process-node" in trace[1]


class TestPatternAnalysis:
    """Test pattern detection across runs."""

    def test_find_patterns_basic(self, tmp_path: Path):
        """Test basic pattern finding."""
        runtime = Runtime(tmp_path)
        create_successful_run(runtime, "goal_x")
        create_successful_run(runtime, "goal_x")
        create_failed_run(runtime, "goal_x")

        query = BuilderQuery(tmp_path)
        patterns = query.find_patterns("goal_x")

        assert patterns is not None
        assert patterns.goal_id == "goal_x"
        assert patterns.run_count == 3
        assert 0 < patterns.success_rate < 1  # 2/3 success

    def test_find_patterns_common_failures(self, tmp_path: Path):
        """Test finding common failures."""
        runtime = Runtime(tmp_path)
        # Create multiple runs with the same failure
        for _ in range(3):
            create_failed_run(runtime, "failing_goal")

        query = BuilderQuery(tmp_path)
        patterns = query.find_patterns("failing_goal")

        assert len(patterns.common_failures) > 0
        # "No data to process" should be a common failure
        failure_messages = [f[0] for f in patterns.common_failures]
        assert any("No data to process" in msg for msg in failure_messages)

    def test_find_patterns_problematic_nodes(self, tmp_path: Path):
        """Test finding problematic nodes."""
        runtime = Runtime(tmp_path)
        # Create runs where process-node always fails
        for _ in range(3):
            create_failed_run(runtime, "node_test")

        query = BuilderQuery(tmp_path)
        patterns = query.find_patterns("node_test")

        # process-node should be flagged as problematic
        problematic_node_ids = [n[0] for n in patterns.problematic_nodes]
        assert "process-node" in problematic_node_ids

    def test_compare_runs(self, tmp_path: Path):
        """Test comparing two runs."""
        runtime = Runtime(tmp_path)
        run1 = create_successful_run(runtime)
        run2 = create_failed_run(runtime)

        query = BuilderQuery(tmp_path)
        comparison = query.compare_runs(run1, run2)

        assert comparison["run_1"]["status"] == "completed"
        assert comparison["run_2"]["status"] == "failed"
        assert len(comparison["differences"]) > 0


class TestImprovementSuggestions:
    """Test improvement suggestion generation."""

    def test_suggest_improvements(self, tmp_path: Path):
        """Test generating improvement suggestions."""
        runtime = Runtime(tmp_path)
        # Create runs with failures to trigger suggestions
        for _ in range(3):
            create_failed_run(runtime, "improve_goal")

        query = BuilderQuery(tmp_path)
        suggestions = query.suggest_improvements("improve_goal")

        assert len(suggestions) > 0
        # Should suggest improving the problematic node
        node_suggestions = [s for s in suggestions if s["type"] == "node_improvement"]
        assert len(node_suggestions) > 0

    def test_suggest_improvements_for_low_success_rate(self, tmp_path: Path):
        """Should suggest architecture review for low success rate."""
        runtime = Runtime(tmp_path)
        # 4 failures, 1 success = 20% success rate
        for _ in range(4):
            create_failed_run(runtime, "low_success")
        create_successful_run(runtime, "low_success")

        query = BuilderQuery(tmp_path)
        suggestions = query.suggest_improvements("low_success")

        arch_suggestions = [s for s in suggestions if s["type"] == "architecture"]
        assert len(arch_suggestions) > 0
        assert arch_suggestions[0]["priority"] == "high"

    def test_get_node_performance(self, tmp_path: Path):
        """Test getting performance metrics for a node."""
        runtime = Runtime(tmp_path)
        create_successful_run(runtime)
        create_successful_run(runtime)

        query = BuilderQuery(tmp_path)
        perf = query.get_node_performance("search-node")

        assert perf["node_id"] == "search-node"
        assert perf["total_decisions"] == 2
        assert perf["success_rate"] == 1.0
        assert perf["total_tokens"] == 100  # 50 tokens per run


class TestBuilderWorkflow:
    """Test complete Builder workflows."""

    def test_builder_investigation_workflow(self, tmp_path: Path):
        """Test a complete investigation workflow as Builder would use it."""
        runtime = Runtime(tmp_path)

        # Set up scenario: some successes, some failures
        for _ in range(2):
            create_successful_run(runtime, "customer_goal")
        for _ in range(2):
            create_failed_run(runtime, "customer_goal")

        query = BuilderQuery(tmp_path)

        # Step 1: Get overview of the goal
        summaries = query.list_runs_for_goal("customer_goal")
        assert len(summaries) == 4

        # Step 2: Find patterns
        patterns = query.find_patterns("customer_goal")
        assert patterns.success_rate == 0.5  # 2/4

        # Step 3: Get recent failures
        failures = query.get_recent_failures()
        assert len(failures) == 2

        # Step 4: Analyze a specific failure
        failure_id = failures[0].run_id
        analysis = query.analyze_failure(failure_id)
        assert analysis is not None
        assert len(analysis.suggestions) > 0

        # Step 5: Generate improvement suggestions
        suggestions = query.suggest_improvements("customer_goal")
        assert len(suggestions) > 0

        # Step 6: Check node performance
        perf = query.get_node_performance("process-node")
        assert perf["success_rate"] < 1.0  # process-node fails in failed runs
@@ -1,185 +0,0 @@
"""Tests for ConcurrentStorage race condition and cache invalidation fixes."""

import asyncio
from pathlib import Path

import pytest

from framework.schemas.run import Run, RunMetrics, RunStatus
from framework.storage.concurrent import ConcurrentStorage


def create_test_run(
    run_id: str, goal_id: str = "test-goal", status: RunStatus = RunStatus.RUNNING
) -> Run:
    """Create a minimal test Run object."""
    return Run(
        id=run_id,
        goal_id=goal_id,
        status=status,
        narrative="Test run",
        metrics=RunMetrics(
            nodes_executed=[],
        ),
        decisions=[],
        problems=[],
    )


@pytest.mark.skip(
    reason="FileStorage.save_run() is deprecated and now a no-op. "
    "ConcurrentStorage wraps FileStorage, so these tests no longer work. "
    "New sessions use unified storage at sessions/{session_id}/state.json"
)
@pytest.mark.asyncio
async def test_cache_invalidation_on_save(tmp_path: Path):
    """Test that summary cache is invalidated when a run is saved.

    This tests the fix for the cache invalidation bug where load_summary()
    would return stale data after a run was updated.
    """
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    try:
        run_id = "test-run-1"

        # Create and save initial run
        run = create_test_run(run_id, status=RunStatus.RUNNING)
        await storage.save_run(run, immediate=True)

        # Load summary to populate the cache
        summary = await storage.load_summary(run_id)
        assert summary is not None
        assert summary.status == RunStatus.RUNNING

        # Update run with new status
        run.status = RunStatus.COMPLETED
        await storage.save_run(run, immediate=True)

        # Load summary again - should get fresh data, not cached stale data
        summary = await storage.load_summary(run_id)
        assert summary is not None
        assert summary.status == RunStatus.COMPLETED, (
            "Summary cache should be invalidated on save - got stale data"
        )
    finally:
        await storage.stop()


@pytest.mark.skip(
    reason="FileStorage.save_run() is deprecated and now a no-op. "
    "ConcurrentStorage wraps FileStorage, so these tests no longer work. "
    "New sessions use unified storage at sessions/{session_id}/state.json"
)
@pytest.mark.asyncio
async def test_batched_write_cache_consistency(tmp_path: Path):
    """Test that cache is only updated after successful batched write.

    This tests the fix for the race condition where cache was updated
    before the batched write completed.
    """
    storage = ConcurrentStorage(tmp_path, batch_interval=0.05)
    await storage.start()

    try:
        run_id = "test-run-2"

        # Save via batching (immediate=False)
        run = create_test_run(run_id, status=RunStatus.RUNNING)
        await storage.save_run(run, immediate=False)

        # Before batch flush, cache should NOT contain the run
        # (This is the fix - previously cache was updated immediately)
        cache_key = f"run:{run_id}"
        assert cache_key not in storage._cache, (
            "Cache should not be updated before batch is flushed"
        )

        # Wait for batch to flush (poll instead of fixed sleep for CI reliability)
        for _ in range(500):  # 500 * 0.01s = 5s max
            if cache_key in storage._cache:
                break
            await asyncio.sleep(0.01)

        # After batch flush, cache should contain the run
        assert cache_key in storage._cache, "Cache should be updated after batch flush"

        # Verify data on disk matches cache
        loaded_run = await storage.load_run(run_id, use_cache=False)
        assert loaded_run is not None
        assert loaded_run.id == run_id
        assert loaded_run.status == RunStatus.RUNNING
    finally:
        await storage.stop()


@pytest.mark.skip(
    reason="FileStorage.save_run() is deprecated and now a no-op. "
    "ConcurrentStorage wraps FileStorage, so these tests no longer work. "
    "New sessions use unified storage at sessions/{session_id}/state.json"
)
@pytest.mark.asyncio
async def test_immediate_write_updates_cache(tmp_path: Path):
    """Test that immediate writes still update cache correctly."""
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    try:
        run_id = "test-run-3"

        # Save with immediate=True
        run = create_test_run(run_id, status=RunStatus.COMPLETED)
        await storage.save_run(run, immediate=True)

        # Cache should be updated immediately for immediate writes
        cache_key = f"run:{run_id}"
        assert cache_key in storage._cache, "Cache should be updated after immediate write"

        # Verify cached value is correct
        cached_run = storage._cache[cache_key].value
        assert cached_run.id == run_id
        assert cached_run.status == RunStatus.COMPLETED
    finally:
        await storage.stop()


@pytest.mark.skip(
    reason="FileStorage.save_run() is deprecated and now a no-op. "
    "ConcurrentStorage wraps FileStorage, so these tests no longer work. "
    "New sessions use unified storage at sessions/{session_id}/state.json"
)
@pytest.mark.asyncio
async def test_summary_cache_invalidated_on_multiple_saves(tmp_path: Path):
    """Test that summary cache is invalidated on each save, not just the first."""
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    try:
        run_id = "test-run-4"

        # First save
        run = create_test_run(run_id, status=RunStatus.RUNNING)
        await storage.save_run(run, immediate=True)

        # Load summary to cache it
        summary1 = await storage.load_summary(run_id)
        assert summary1.status == RunStatus.RUNNING

        # Second save (status unchanged; exercises cache invalidation on repeat saves)
        run.status = RunStatus.RUNNING
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
# Load summary - should be fresh
|
||||
summary2 = await storage.load_summary(run_id)
|
||||
assert summary2.status == RunStatus.RUNNING
|
||||
|
||||
# Third save with final status
|
||||
run.status = RunStatus.COMPLETED
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
# Load summary - should be fresh again
|
||||
summary3 = await storage.load_summary(run_id)
|
||||
assert summary3.status == RunStatus.COMPLETED
|
||||
finally:
|
||||
await storage.stop()
|
||||
@@ -0,0 +1,538 @@
"""Tests for the Continuous Agent architecture (conversation threading + cumulative tools).

Validates:
- conversation_mode="isolated" preserves existing behavior
- conversation_mode="continuous" threads one conversation across nodes
- Transition markers are inserted at phase boundaries
- System prompt updates at each transition (layered prompt composition)
- Tools accumulate across nodes in continuous mode
- prompt_composer functions work correctly
"""

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import MagicMock

import pytest

from framework.graph.conversation import NodeConversation
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.executor import GraphExecutor
from framework.graph.goal import Goal
from framework.graph.node import NodeResult, NodeSpec, SharedMemory
from framework.graph.prompt_composer import (
    build_narrative,
    build_transition_marker,
    compose_system_prompt,
)
from framework.llm.provider import LLMProvider, LLMResponse, Tool
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
from framework.runtime.core import Runtime

# ---------------------------------------------------------------------------
# Mock LLM
# ---------------------------------------------------------------------------


class MockStreamingLLM(LLMProvider):
    """Mock LLM that yields pre-programmed StreamEvent sequences."""

    def __init__(self, scenarios: list[list] | None = None):
        self.scenarios = scenarios or []
        self._call_index = 0
        self.stream_calls: list[dict] = []

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator:
        self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
        if not self.scenarios:
            return
        events = self.scenarios[self._call_index % len(self.scenarios)]
        self._call_index += 1
        for event in events:
            yield event

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(content="Summary.", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _set_output_scenario(key: str, value: str) -> list:
    """LLM calls set_output then finishes."""
    return [
        ToolCallEvent(
            tool_use_id=f"call_{key}",
            tool_name="set_output",
            tool_input={"key": key, "value": value},
        ),
        FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
    ]


def _text_then_set_output(text: str, key: str, value: str) -> list:
    """LLM produces text, then calls set_output, then finishes (2 turns needed)."""
    return [
        TextDeltaEvent(content=text, snapshot=text),
        ToolCallEvent(
            tool_use_id=f"call_{key}",
            tool_name="set_output",
            tool_input={"key": key, "value": value},
        ),
        FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
    ]


def _text_finish(text: str) -> list:
    """LLM produces text and stops (triggers judge)."""
    return [
        TextDeltaEvent(content=text, snapshot=text),
        FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"),
    ]


def _make_runtime():
    rt = MagicMock(spec=Runtime)
    rt.start_run = MagicMock(return_value="run_1")
    rt.end_run = MagicMock()
    rt.report_problem = MagicMock()
    rt.decide = MagicMock(return_value="dec_1")
    rt.record_outcome = MagicMock()
    rt.set_node = MagicMock()
    return rt


def _make_goal():
    return Goal(id="g1", name="test", description="test goal")


def _make_tool(name: str) -> Tool:
    return Tool(
        name=name,
        description=f"Tool {name}",
        parameters={"type": "object", "properties": {}},
    )


# ===========================================================================
# prompt_composer unit tests
# ===========================================================================


class TestComposeSystemPrompt:
    def test_all_layers(self):
        result = compose_system_prompt(
            identity_prompt="I am a research agent.",
            focus_prompt="Focus on writing the report.",
            narrative="We found 5 sources on topic X.",
        )
        assert "I am a research agent." in result
        assert "Focus on writing the report." in result
        assert "We found 5 sources on topic X." in result
        # Identity comes first
        assert result.index("I am a research agent.") < result.index("Focus on writing")

    def test_identity_only(self):
        result = compose_system_prompt(identity_prompt="I am an agent.", focus_prompt=None)
        assert result == "I am an agent."

    def test_focus_only(self):
        result = compose_system_prompt(identity_prompt=None, focus_prompt="Do the thing.")
        assert "Current Focus" in result
        assert "Do the thing." in result

    def test_empty(self):
        result = compose_system_prompt(identity_prompt=None, focus_prompt=None)
        assert result == ""


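# A minimal sketch of the layered composition the tests above constrain: identity
# first, then a "Current Focus" section, then the narrative. The exact headings and
# joiner used by compose_system_prompt are assumptions; only the ordering and the
# pass-through cases are pinned by the asserts.
def _compose_sketch(identity: str | None, focus: str | None, narrative: str | None = None) -> str:
    parts: list[str] = []
    if identity:
        parts.append(identity)
    if focus:
        parts.append(f"## Current Focus\n{focus}")
    if narrative:
        parts.append(narrative)
    return "\n\n".join(parts)

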
class TestBuildNarrative:
    def test_with_execution_path(self):
        memory = SharedMemory()
        memory.write("findings", "some findings")

        node_a = NodeSpec(
            id="a", name="Research", description="Research the topic", node_type="event_loop"
        )
        node_b = NodeSpec(id="b", name="Report", description="Write report", node_type="event_loop")
        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="a",
            nodes=[node_a, node_b],
            edges=[],
        )

        result = build_narrative(memory, ["a"], graph)
        assert "Research" in result
        assert "findings" in result

    def test_empty_state(self):
        memory = SharedMemory()
        graph = GraphSpec(id="g1", goal_id="g1", entry_node="a", nodes=[], edges=[])
        result = build_narrative(memory, [], graph)
        assert result == ""


class TestBuildTransitionMarker:
    def test_basic_marker(self):
        prev = NodeSpec(
            id="research", name="Research", description="Find sources", node_type="event_loop"
        )
        next_n = NodeSpec(
            id="report", name="Report", description="Write report", node_type="event_loop"
        )
        memory = SharedMemory()
        memory.write("findings", "important stuff")

        marker = build_transition_marker(
            previous_node=prev,
            next_node=next_n,
            memory=memory,
            cumulative_tool_names=["web_search", "save_data"],
        )

        assert "PHASE TRANSITION" in marker
        assert "Research" in marker
        assert "Report" in marker
        assert "findings" in marker
        assert "web_search" in marker
        assert "reflect" in marker.lower()


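# A sketch of the marker shape implied by the asserts above: a "PHASE TRANSITION"
# banner naming both phases, the shared-memory keys, the cumulative tool list, and a
# prompt to reflect. The exact wording of build_transition_marker is an assumption.
def _marker_sketch(prev_name: str, next_name: str, keys: list[str], tools: list[str]) -> str:
    return (
        f"=== PHASE TRANSITION: {prev_name} -> {next_name} ===\n"
        f"Shared memory so far: {', '.join(keys)}\n"
        f"Tools available: {', '.join(tools)}\n"
        "Take a moment to reflect on the previous phase before continuing."
    )

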
# ===========================================================================
# NodeConversation.update_system_prompt
# ===========================================================================


class TestUpdateSystemPrompt:
    def test_update(self):
        conv = NodeConversation(system_prompt="original")
        assert conv.system_prompt == "original"
        conv.update_system_prompt("updated")
        assert conv.system_prompt == "updated"


# ===========================================================================
# Conversation threading through executor
# ===========================================================================


class TestContinuousConversation:
    """Test that conversation_mode='continuous' threads a single conversation."""

    @pytest.mark.asyncio
    async def test_isolated_mode_no_conversation_in_result(self):
        """In isolated mode, NodeResult.conversation should be None."""
        runtime = _make_runtime()
        llm = MockStreamingLLM(
            scenarios=[
                _set_output_scenario("result", "done"),
                _text_finish("accepted"),
            ]
        )

        spec = NodeSpec(
            id="n1",
            name="Node1",
            description="test",
            node_type="event_loop",
            output_keys=["result"],
        )
        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="n1",
            nodes=[spec],
            edges=[],
            conversation_mode="isolated",
        )

        executor = GraphExecutor(runtime=runtime, llm=llm)
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success

    @pytest.mark.asyncio
    async def test_continuous_threads_conversation(self):
        """In continuous mode, second node sees messages from first node."""
        runtime = _make_runtime()

        # Node A: set_output("brief", "the brief"), then finish (accept)
        # Node B: set_output("report", "the report"), then finish (accept)
        llm = MockStreamingLLM(
            scenarios=[
                _text_then_set_output("I'll research this.", "brief", "the brief"),
                _text_finish(""),  # triggers accept for node A (all keys set)
                _text_then_set_output("Here's the report.", "report", "the report"),
                _text_finish(""),  # triggers accept for node B
            ]
        )

        node_a = NodeSpec(
            id="a",
            name="Intake",
            description="Gather requirements",
            node_type="event_loop",
            output_keys=["brief"],
        )
        node_b = NodeSpec(
            id="b",
            name="Report",
            description="Write report",
            node_type="event_loop",
            input_keys=["brief"],
            output_keys=["report"],
        )

        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="a",
            nodes=[node_a, node_b],
            edges=[EdgeSpec(id="e1", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)],
            terminal_nodes=["b"],
            conversation_mode="continuous",
            identity_prompt="You are a thorough research agent.",
        )

        executor = GraphExecutor(runtime=runtime, llm=llm)
        result = await executor.execute(graph=graph, goal=_make_goal())

        assert result.success
        assert result.path == ["a", "b"]

        # Verify the LLM saw the identity prompt in system messages
        # The second node's system prompt should contain the identity
        if len(llm.stream_calls) >= 3:
            system_at_node_b = llm.stream_calls[2]["system"]
            assert "thorough research agent" in system_at_node_b

    @pytest.mark.asyncio
    async def test_continuous_transition_marker_present(self):
        """Transition marker should appear in messages when switching nodes."""
        runtime = _make_runtime()

        llm = MockStreamingLLM(
            scenarios=[
                _text_then_set_output("Research done.", "brief", "the brief"),
                _text_finish(""),
                _text_then_set_output("Report done.", "report", "the report"),
                _text_finish(""),
            ]
        )

        node_a = NodeSpec(
            id="a",
            name="Research",
            description="Do research",
            node_type="event_loop",
            output_keys=["brief"],
        )
        node_b = NodeSpec(
            id="b",
            name="Report",
            description="Write report",
            node_type="event_loop",
            input_keys=["brief"],
            output_keys=["report"],
        )

        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="a",
            nodes=[node_a, node_b],
            edges=[EdgeSpec(id="e1", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)],
            terminal_nodes=["b"],
            conversation_mode="continuous",
        )

        executor = GraphExecutor(runtime=runtime, llm=llm)
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success

        # When node B's first LLM call happens, its messages should contain
        # the transition marker from the executor
        if len(llm.stream_calls) >= 3:
            node_b_messages = llm.stream_calls[2]["messages"]
            all_content = " ".join(
                m.get("content", "") for m in node_b_messages if isinstance(m.get("content"), str)
            )
            assert "PHASE TRANSITION" in all_content


# ===========================================================================
# Cumulative tools
# ===========================================================================


class TestCumulativeTools:
    """Test that tools accumulate in continuous mode."""

    @pytest.mark.asyncio
    async def test_isolated_mode_tools_scoped(self):
        """In isolated mode, each node only gets its own declared tools."""
        runtime = _make_runtime()
        tool_a = _make_tool("web_search")
        tool_b = _make_tool("save_data")

        llm = MockStreamingLLM(
            scenarios=[
                _text_then_set_output("Done.", "brief", "brief"),
                _text_finish(""),
                _text_then_set_output("Done.", "report", "report"),
                _text_finish(""),
            ]
        )

        node_a = NodeSpec(
            id="a",
            name="Research",
            description="Research",
            node_type="event_loop",
            output_keys=["brief"],
            tools=["web_search"],
        )
        node_b = NodeSpec(
            id="b",
            name="Report",
            description="Report",
            node_type="event_loop",
            input_keys=["brief"],
            output_keys=["report"],
            tools=["save_data"],
        )

        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="a",
            nodes=[node_a, node_b],
            edges=[EdgeSpec(id="e1", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)],
            terminal_nodes=["b"],
            conversation_mode="isolated",
        )

        executor = GraphExecutor(
            runtime=runtime,
            llm=llm,
            tools=[tool_a, tool_b],
        )
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success

        # In isolated mode, node B should NOT have web_search
        if len(llm.stream_calls) >= 3:
            node_b_tools = llm.stream_calls[2].get("tools") or []
            tool_names = [t.name for t in node_b_tools]
            assert "save_data" in tool_names or "set_output" in tool_names
            # web_search should NOT be present (only set_output + save_data)
            real_tools = [n for n in tool_names if n != "set_output"]
            assert "web_search" not in real_tools

    @pytest.mark.asyncio
    async def test_continuous_mode_tools_accumulate(self):
        """In continuous mode, node B should have both web_search and save_data."""
        runtime = _make_runtime()
        tool_a = _make_tool("web_search")
        tool_b = _make_tool("save_data")

        llm = MockStreamingLLM(
            scenarios=[
                _text_then_set_output("Done.", "brief", "brief"),
                _text_finish(""),
                _text_then_set_output("Done.", "report", "report"),
                _text_finish(""),
            ]
        )

        node_a = NodeSpec(
            id="a",
            name="Research",
            description="Research",
            node_type="event_loop",
            output_keys=["brief"],
            tools=["web_search"],
        )
        node_b = NodeSpec(
            id="b",
            name="Report",
            description="Report",
            node_type="event_loop",
            input_keys=["brief"],
            output_keys=["report"],
            tools=["save_data"],
        )

        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="a",
            nodes=[node_a, node_b],
            edges=[EdgeSpec(id="e1", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)],
            terminal_nodes=["b"],
            conversation_mode="continuous",
        )

        executor = GraphExecutor(
            runtime=runtime,
            llm=llm,
            tools=[tool_a, tool_b],
        )
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success

        # In continuous mode, node B should have BOTH tools
        if len(llm.stream_calls) >= 3:
            node_b_tools = llm.stream_calls[2].get("tools") or []
            tool_names = [t.name for t in node_b_tools]
            real_tools = [n for n in tool_names if n != "set_output"]
            assert "web_search" in real_tools
            assert "save_data" in real_tools


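# Sketch of the accumulation rule the two tests above contrast: in "continuous" mode
# the tool set offered to a node is the union of everything declared along the path
# so far; in "isolated" mode it is only the node's own list. The helper name is
# illustrative, not the executor's internals.
def _tools_for_node(path_specs: list[NodeSpec], mode: str) -> list[str]:
    if mode == "isolated":
        return list(path_specs[-1].tools or [])
    seen: list[str] = []
    for spec in path_specs:  # accumulate in path order, de-duplicated
        for name in spec.tools or []:
            if name not in seen:
                seen.append(name)
    return seen

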
# ===========================================================================
# Schema field defaults
# ===========================================================================


class TestSchemaDefaults:
    def test_graphspec_defaults(self):
        """New fields should have safe defaults."""
        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="n1",
            nodes=[],
            edges=[],
        )
        assert graph.conversation_mode == "continuous"
        assert graph.identity_prompt is None

    def test_nodespec_defaults(self):
        """NodeSpec.success_criteria should default to None."""
        spec = NodeSpec(
            id="n1",
            name="test",
            description="test",
            node_type="event_loop",
        )
        assert spec.success_criteria is None

    def test_noderesult_defaults(self):
        """NodeResult.conversation should default to None."""
        result = NodeResult(success=True)
        assert result.conversation is None
@@ -0,0 +1,380 @@
"""Tests for Level 2 conversation-aware judge.

Validates:
- No success_criteria → Level 0 only (existing behavior)
- success_criteria set, good conversation → Level 2 ACCEPT
- success_criteria set, poor conversation → Level 2 RETRY with feedback
- Custom explicit judge takes priority over Level 2
- Level 2 fires only when Level 0 passes (all keys set)
- _parse_verdict correctly parses LLM responses
"""

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import MagicMock

import pytest

from framework.graph.conversation import NodeConversation
from framework.graph.conversation_judge import (
    _parse_verdict,
    evaluate_phase_completion,
)
from framework.graph.edge import GraphSpec
from framework.graph.executor import GraphExecutor
from framework.graph.goal import Goal
from framework.graph.node import NodeSpec
from framework.llm.provider import LLMProvider, LLMResponse, Tool
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
from framework.runtime.core import Runtime

# ---------------------------------------------------------------------------
# Mock LLM
# ---------------------------------------------------------------------------


class MockStreamingLLM(LLMProvider):
    """Mock LLM that yields pre-programmed StreamEvent sequences."""

    def __init__(self, scenarios: list[list] | None = None, complete_response: str = ""):
        self.scenarios = scenarios or []
        self._call_index = 0
        self.stream_calls: list[dict] = []
        self.complete_response = complete_response
        self.complete_calls: list[dict] = []

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator:
        self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
        if not self.scenarios:
            return
        events = self.scenarios[self._call_index % len(self.scenarios)]
        self._call_index += 1
        for event in events:
            yield event

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        self.complete_calls.append({"messages": messages, "system": system})
        return LLMResponse(content=self.complete_response, model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _set_output_scenario(key: str, value: str) -> list:
    return [
        ToolCallEvent(
            tool_use_id=f"call_{key}",
            tool_name="set_output",
            tool_input={"key": key, "value": value},
        ),
        FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
    ]


def _text_then_set_output(text: str, key: str, value: str) -> list:
    return [
        TextDeltaEvent(content=text, snapshot=text),
        ToolCallEvent(
            tool_use_id=f"call_{key}",
            tool_name="set_output",
            tool_input={"key": key, "value": value},
        ),
        FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
    ]


def _text_finish(text: str) -> list:
    return [
        TextDeltaEvent(content=text, snapshot=text),
        FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"),
    ]


def _make_runtime():
    rt = MagicMock(spec=Runtime)
    rt.start_run = MagicMock(return_value="run_1")
    rt.end_run = MagicMock()
    rt.report_problem = MagicMock()
    rt.decide = MagicMock(return_value="dec_1")
    rt.record_outcome = MagicMock()
    rt.set_node = MagicMock()
    return rt


def _make_goal():
    return Goal(id="g1", name="test", description="test goal")


# ===========================================================================
# Unit tests for _parse_verdict
# ===========================================================================


class TestParseVerdict:
    def test_accept(self):
        v = _parse_verdict("ACTION: ACCEPT\nCONFIDENCE: 0.9\nFEEDBACK:")
        assert v.action == "ACCEPT"
        assert v.confidence == 0.9
        assert v.feedback == ""

    def test_retry_with_feedback(self):
        v = _parse_verdict("ACTION: RETRY\nCONFIDENCE: 0.6\nFEEDBACK: Research is too shallow.")
        assert v.action == "RETRY"
        assert v.confidence == 0.6
        assert "shallow" in v.feedback

    def test_defaults_on_garbage(self):
        v = _parse_verdict("some random text\nno structured output")
        assert v.action == "ACCEPT"  # default
        assert v.confidence == 0.8  # default

    def test_invalid_action_defaults_to_accept(self):
        v = _parse_verdict("ACTION: ESCALATE\nCONFIDENCE: 0.5")
        assert v.action == "ACCEPT"  # ESCALATE not valid for Level 2


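# Minimal parser sketch matching the behavior pinned above: line-oriented
# ACTION/CONFIDENCE/FEEDBACK fields, defaulting to ACCEPT/0.8 on garbage, and
# ignoring any action other than ACCEPT or RETRY. _parse_verdict's real return type
# is richer; the tuple here is illustrative.
def _parse_verdict_sketch(text: str) -> tuple[str, float, str]:
    action, confidence, feedback = "ACCEPT", 0.8, ""
    for line in text.splitlines():
        if line.startswith("ACTION:"):
            candidate = line.split(":", 1)[1].strip().upper()
            if candidate in ("ACCEPT", "RETRY"):  # ESCALATE is not valid at Level 2
                action = candidate
        elif line.startswith("CONFIDENCE:"):
            try:
                confidence = float(line.split(":", 1)[1])
            except ValueError:
                pass
        elif line.startswith("FEEDBACK:"):
            feedback = line.split(":", 1)[1].strip()
    return action, confidence, feedback

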
# ===========================================================================
# Unit tests for evaluate_phase_completion
# ===========================================================================


class TestEvaluatePhaseCompletion:
    @pytest.mark.asyncio
    async def test_accept_on_good_response(self):
        """LLM says ACCEPT → verdict is ACCEPT."""
        llm = MockStreamingLLM(complete_response="ACTION: ACCEPT\nCONFIDENCE: 0.95\nFEEDBACK:")
        conv = NodeConversation(system_prompt="test")
        await conv.add_user_message("Do research on topic X")
        await conv.add_assistant_message("I found 5 high-quality sources on X.")

        verdict = await evaluate_phase_completion(
            llm=llm,
            conversation=conv,
            phase_name="Research",
            phase_description="Research the topic",
            success_criteria="Find at least 3 credible sources",
            accumulator_state={"findings": "5 sources found"},
        )
        assert verdict.action == "ACCEPT"
        assert verdict.confidence == 0.95

    @pytest.mark.asyncio
    async def test_retry_on_poor_response(self):
        """LLM says RETRY → verdict is RETRY with feedback."""
        llm = MockStreamingLLM(
            complete_response=(
                "ACTION: RETRY\nCONFIDENCE: 0.4\nFEEDBACK: Only found 1 source, need 3."
            )
        )
        conv = NodeConversation(system_prompt="test")
        await conv.add_user_message("Do research")
        await conv.add_assistant_message("I found 1 source.")

        verdict = await evaluate_phase_completion(
            llm=llm,
            conversation=conv,
            phase_name="Research",
            phase_description="Research the topic",
            success_criteria="Find at least 3 credible sources",
            accumulator_state={"findings": "1 source"},
        )
        assert verdict.action == "RETRY"
        assert "1 source" in verdict.feedback

    @pytest.mark.asyncio
    async def test_llm_failure_defaults_to_accept(self):
        """When LLM fails, Level 2 should not block (Level 0 already passed)."""
        llm = MockStreamingLLM()
        # Make complete() raise an exception
        llm.complete = MagicMock(side_effect=RuntimeError("LLM unavailable"))

        conv = NodeConversation(system_prompt="test")
        await conv.add_assistant_message("Done.")

        verdict = await evaluate_phase_completion(
            llm=llm,
            conversation=conv,
            phase_name="Test",
            phase_description="Test phase",
            success_criteria="Do the thing",
            accumulator_state={"result": "done"},
        )
        assert verdict.action == "ACCEPT"
        assert verdict.confidence == 0.5


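# The fail-open contract from test_llm_failure_defaults_to_accept, sketched: if the
# judge LLM call raises, return ACCEPT at confidence 0.5 rather than blocking a node
# whose Level 0 key check already passed. The verdict shape is simplified here.
def _judge_with_fail_open(call_llm) -> tuple[str, float]:
    try:
        action, confidence, _ = _parse_verdict_sketch(call_llm())
        return (action, confidence)
    except Exception:
        return ("ACCEPT", 0.5)

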
# ===========================================================================
# Integration: Level 2 in EventLoopNode implicit judge
# ===========================================================================


class TestLevel2InImplicitJudge:
    @pytest.mark.asyncio
    async def test_no_success_criteria_level0_only(self):
        """Without success_criteria, Level 0 accepts normally (existing behavior)."""
        runtime = _make_runtime()
        llm = MockStreamingLLM(
            scenarios=[
                _set_output_scenario("result", "done"),
                _text_finish("accepted"),
            ]
        )

        spec = NodeSpec(
            id="n1",
            name="Node1",
            description="test",
            node_type="event_loop",
            output_keys=["result"],
            # No success_criteria!
        )
        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="n1",
            nodes=[spec],
            edges=[],
        )

        executor = GraphExecutor(runtime=runtime, llm=llm)
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success
        # LLM.complete should NOT have been called for Level 2
        assert len(llm.complete_calls) == 0

    @pytest.mark.asyncio
    async def test_success_criteria_accept(self):
        """With success_criteria and good work, Level 2 accepts."""
        runtime = _make_runtime()
        llm = MockStreamingLLM(
            scenarios=[
                _text_then_set_output("I did thorough research.", "result", "done"),
                _text_finish(""),  # triggers judge
            ],
            complete_response="ACTION: ACCEPT\nCONFIDENCE: 0.9\nFEEDBACK:",
        )

        spec = NodeSpec(
            id="n1",
            name="Research",
            description="Do research",
            node_type="event_loop",
            output_keys=["result"],
            success_criteria="Provide thorough research with multiple sources.",
        )
        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="n1",
            nodes=[spec],
            edges=[],
        )

        executor = GraphExecutor(runtime=runtime, llm=llm)
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success
        # LLM.complete should have been called for Level 2
        assert len(llm.complete_calls) >= 1

    @pytest.mark.asyncio
    async def test_success_criteria_retry_then_accept(self):
        """Level 2 rejects first attempt, LLM tries again, Level 2 accepts."""
        runtime = _make_runtime()

        # Track complete calls to alternate responses
        complete_responses = [
            "ACTION: RETRY\nCONFIDENCE: 0.4\nFEEDBACK: Need more detail.",
            "ACTION: ACCEPT\nCONFIDENCE: 0.9\nFEEDBACK:",
        ]
        call_count = [0]

        class SequentialLLM(MockStreamingLLM):
            def complete(self, messages, system="", **kwargs):
                idx = call_count[0]
                call_count[0] += 1
                resp = complete_responses[idx % len(complete_responses)]
                return LLMResponse(content=resp, model="mock", stop_reason="stop")

        llm = SequentialLLM(
            scenarios=[
                # Turn 1: set output, then stop → Level 2 RETRY
                _text_then_set_output("Brief research.", "result", "brief"),
                _text_finish(""),  # triggers judge → Level 2 RETRY
                # Turn 2: after retry feedback, set output again, stop → Level 2 ACCEPT
                _text_then_set_output("Much more detailed research.", "result", "detailed"),
                _text_finish(""),  # triggers judge → Level 2 ACCEPT
            ]
        )

        spec = NodeSpec(
            id="n1",
            name="Research",
            description="Do research",
            node_type="event_loop",
            output_keys=["result"],
            success_criteria="Provide thorough research with multiple sources.",
        )
        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="n1",
            nodes=[spec],
            edges=[],
        )

        executor = GraphExecutor(runtime=runtime, llm=llm)
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success
        # Should have had 2 complete calls (first RETRY, second ACCEPT)
        assert call_count[0] >= 2

    @pytest.mark.asyncio
    async def test_level2_only_fires_when_level0_passes(self):
        """Level 2 should NOT fire when output keys are missing."""
        runtime = _make_runtime()

        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: just text, no set_output → Level 0 RETRY (missing keys)
                _text_finish("I did some thinking."),
                # Turn 2: set output → Level 0 ACCEPT, Level 2 check
                _text_then_set_output("Now I have output.", "result", "done"),
                _text_finish(""),  # triggers judge
            ],
            complete_response="ACTION: ACCEPT\nCONFIDENCE: 0.9\nFEEDBACK:",
        )

        spec = NodeSpec(
            id="n1",
            name="Research",
            description="Do research",
            node_type="event_loop",
            output_keys=["result"],
            success_criteria="Provide results.",
        )
        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="n1",
            nodes=[spec],
            edges=[],
        )

        executor = GraphExecutor(runtime=runtime, llm=llm)
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success
        # Level 2 should only fire once (when Level 0 passes)
        assert len(llm.complete_calls) == 1
@@ -86,6 +86,7 @@ class ScriptableMockLLMProvider(LLMProvider):
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
        max_retries: int | None = None,
    ) -> LLMResponse:
        return LLMResponse(
            content="Conversation summary for compaction.",
@@ -929,6 +930,7 @@ async def test_context_handoff_between_nodes(runtime):


@pytest.mark.asyncio
@pytest.mark.skip(reason="Hangs in non-interactive shells (client-facing blocks on stdin)")
async def test_client_facing_node_streams_output():
    """Client-facing node emits CLIENT_OUTPUT_DELTA events."""
    recorded: list[AgentEvent] = []
@@ -951,7 +953,7 @@ async def test_client_facing_node_streams_output():
        config=LoopConfig(max_iterations=5),
    )

    # Text-only on client_facing does not block (no ask_user called),
    # so the node completes without needing a shutdown workaround.
    result = await node.execute(ctx)


@@ -425,6 +425,7 @@ class TestEventBusLifecycle:
        assert EventType.NODE_LOOP_COMPLETED in received_events

    @pytest.mark.asyncio
    @pytest.mark.skip(reason="Hangs in non-interactive shells (client-facing blocks on stdin)")
    async def test_client_facing_uses_client_output_delta(self, runtime, memory):
        """client_facing=True should emit CLIENT_OUTPUT_DELTA instead of LLM_TEXT_DELTA."""
        spec = NodeSpec(
@@ -475,6 +476,7 @@ class TestClientFacingBlocking:
        )

    @pytest.mark.asyncio
    @pytest.mark.skip(reason="Hangs in non-interactive shells (client-facing blocks on stdin)")
    async def test_text_only_no_blocking(self, runtime, memory, client_spec):
        """client_facing + text-only (no ask_user) should NOT block."""
        llm = MockStreamingLLM(
@@ -630,6 +632,7 @@ class TestClientFacingBlocking:
        assert received[0].type == EventType.CLIENT_INPUT_REQUESTED

    @pytest.mark.asyncio
    @pytest.mark.skip(reason="Hangs in non-interactive shells (client-facing blocks on stdin)")
    async def test_ask_user_with_real_tools(self, runtime, memory):
        """ask_user alongside real tool calls still triggers blocking."""
        spec = NodeSpec(
@@ -993,3 +996,697 @@ class TestOutputAccumulator:
        assert acc.get("key1") == "val1"
        assert acc.get("key2") == "val2"
        assert acc.has_all_keys(["key1", "key2"]) is True


# ===========================================================================
# Transient error retry (ITEM 2)
# ===========================================================================


class ErrorThenSuccessLLM(LLMProvider):
    """LLM that raises on the first N calls, then succeeds.

    Used to test the retry-with-backoff wrapper around _run_single_turn().
    """

    def __init__(self, error: Exception, fail_count: int, success_scenario: list):
        self.error = error
        self.fail_count = fail_count
        self.success_scenario = success_scenario
        self._call_index = 0

    async def stream(self, messages, system="", tools=None, max_tokens=4096):
        call_num = self._call_index
        self._call_index += 1
        if call_num < self.fail_count:
            raise self.error
        for event in self.success_scenario:
            yield event

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(content="ok", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")


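import asyncio


# Sketch of the retry-with-backoff wrapper exercised below, assuming exponential
# backoff (attempt n sleeps stream_retry_backoff_base * 2**n) and a simple
# isinstance classification; the real wrapper around _run_single_turn() also
# keyword-matches RuntimeErrors (see TestIsTransientError below).
async def _retry_transient(run_turn, max_retries: int, backoff_base: float):
    for attempt in range(max_retries + 1):  # 1 initial try + max_retries retries
        try:
            return await run_turn()
        except Exception as exc:
            transient = isinstance(exc, (TimeoutError, ConnectionError, OSError))
            if not transient or attempt == max_retries:
                raise
            await asyncio.sleep(backoff_base * (2**attempt))

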
class TestTransientErrorRetry:
    """Test retry-with-backoff for transient LLM errors in EventLoopNode."""

    @pytest.mark.asyncio
    async def test_transient_error_retries_then_succeeds(self, runtime, node_spec, memory):
        """A transient error on the first try should retry and succeed."""
        node_spec.output_keys = []
        llm = ErrorThenSuccessLLM(
            error=ConnectionError("connection reset"),
            fail_count=1,
            success_scenario=text_scenario("success"),
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            config=LoopConfig(
                max_iterations=5,
                max_stream_retries=3,
                stream_retry_backoff_base=0.01,  # fast for tests
            ),
        )
        result = await node.execute(ctx)
        assert result.success is True
        assert llm._call_index == 2  # 1 failure + 1 success

    @pytest.mark.asyncio
    async def test_permanent_error_no_retry(self, runtime, node_spec, memory):
        """A permanent error (ValueError) should NOT be retried."""
        node_spec.output_keys = []
        llm = ErrorThenSuccessLLM(
            error=ValueError("bad request: invalid model"),
            fail_count=1,
            success_scenario=text_scenario("success"),
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            config=LoopConfig(
                max_iterations=5,
                max_stream_retries=3,
                stream_retry_backoff_base=0.01,
            ),
        )
        with pytest.raises(ValueError, match="bad request"):
            await node.execute(ctx)
        assert llm._call_index == 1  # only tried once

    @pytest.mark.asyncio
    async def test_transient_error_exhausts_retries(self, runtime, node_spec, memory):
        """Transient errors that exhaust retries should raise."""
        node_spec.output_keys = []
        llm = ErrorThenSuccessLLM(
            error=TimeoutError("request timed out"),
            fail_count=100,  # always fails
            success_scenario=text_scenario("unreachable"),
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            config=LoopConfig(
                max_iterations=5,
                max_stream_retries=2,
                stream_retry_backoff_base=0.01,
            ),
        )
        with pytest.raises(TimeoutError, match="request timed out"):
            await node.execute(ctx)
        assert llm._call_index == 3  # 1 initial + 2 retries

    @pytest.mark.asyncio
    async def test_stream_error_event_retried_as_runtime_error(self, runtime, node_spec, memory):
        """StreamErrorEvent(recoverable=False) raises RuntimeError caught by retry."""
        node_spec.output_keys = []

        # Scenario: non-recoverable StreamErrorEvent with transient keywords
        error_scenario = [
            StreamErrorEvent(
                error="Stream error: 503 service unavailable",
                recoverable=False,
            )
        ]
        success_scenario = text_scenario("recovered")

        call_index = 0

        class StreamErrorThenSuccessLLM(LLMProvider):
            async def stream(self, messages, system="", tools=None, max_tokens=4096):
                nonlocal call_index
                idx = call_index
                call_index += 1
                if idx == 0:
                    for event in error_scenario:
                        yield event
                else:
                    for event in success_scenario:
                        yield event

            def complete(self, messages, system="", **kwargs):
                return LLMResponse(
                    content="ok",
                    model="mock",
                    stop_reason="stop",
                )

            def complete_with_tools(
                self,
                messages,
                system,
                tools,
                tool_executor,
                **kwargs,
            ):
                return LLMResponse(
                    content="",
                    model="mock",
                    stop_reason="stop",
                )

        llm = StreamErrorThenSuccessLLM()
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            config=LoopConfig(
                max_iterations=5,
                max_stream_retries=3,
                stream_retry_backoff_base=0.01,
            ),
        )
        result = await node.execute(ctx)
        assert result.success is True
        assert call_index == 2

    @pytest.mark.asyncio
    async def test_retry_emits_event_bus_event(self, runtime, node_spec, memory):
        """Retry should emit NODE_RETRY event on the event bus."""
        node_spec.output_keys = []
        llm = ErrorThenSuccessLLM(
            error=ConnectionError("network down"),
            fail_count=1,
            success_scenario=text_scenario("ok"),
        )
        bus = EventBus()
        retry_events = []
        bus.subscribe(
            event_types=[EventType.NODE_RETRY],
            handler=lambda e: retry_events.append(e),
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            event_bus=bus,
            config=LoopConfig(
                max_iterations=5,
                max_stream_retries=3,
                stream_retry_backoff_base=0.01,
            ),
        )
        result = await node.execute(ctx)
        assert result.success is True
        assert len(retry_events) == 1
        assert retry_events[0].data["retry_count"] == 1

    @pytest.mark.asyncio
    async def test_recoverable_stream_error_retried_not_silent(self, runtime, node_spec, memory):
        """Recoverable StreamErrorEvent with empty response should raise ConnectionError.

        Previously, recoverable stream errors were silently swallowed,
        producing empty responses that the judge retried — creating an
        infinite loop of 50+ empty-response iterations. Now they raise
        ConnectionError so the outer transient-error retry handles them
        with proper backoff.
        """
        node_spec.output_keys = ["result"]

        call_index = 0

        class RecoverableErrorThenSuccessLLM(LLMProvider):
            async def stream(self, messages, system="", tools=None, max_tokens=4096):
                nonlocal call_index
                idx = call_index
                call_index += 1
                if idx == 0:
                    # Recoverable error with no content
                    yield StreamErrorEvent(
                        error="503 service unavailable",
                        recoverable=True,
                    )
                elif idx == 1:
                    # Success: set output
                    for event in tool_call_scenario(
                        "set_output", {"key": "result", "value": "done"}
                    ):
                        yield event
                else:
                    # Subsequent calls: text-only (no more tool calls)
                    for event in text_scenario("done"):
                        yield event

            def complete(self, messages, system="", **kwargs):
                return LLMResponse(content="ok", model="mock", stop_reason="stop")

            def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs):
                return LLMResponse(content="", model="mock", stop_reason="stop")

        llm = RecoverableErrorThenSuccessLLM()
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            config=LoopConfig(
                max_iterations=5,
                max_stream_retries=3,
                stream_retry_backoff_base=0.01,
            ),
        )
        result = await node.execute(ctx)
        assert result.success is True
        assert result.output.get("result") == "done"
        # call 0: recoverable error → ConnectionError raised → outer retry
        # call 1: set_output tool call succeeds
        # call 2: inner tool loop re-invokes LLM after tool result → text "done"
        assert call_index == 3


class TestIsTransientError:
    """Unit tests for _is_transient_error() classification."""

    def test_timeout_error(self):
        assert EventLoopNode._is_transient_error(TimeoutError("timed out")) is True

    def test_connection_error(self):
        assert EventLoopNode._is_transient_error(ConnectionError("reset")) is True

    def test_os_error(self):
        assert EventLoopNode._is_transient_error(OSError("network unreachable")) is True

    def test_value_error_not_transient(self):
        assert EventLoopNode._is_transient_error(ValueError("bad input")) is False

    def test_type_error_not_transient(self):
        assert EventLoopNode._is_transient_error(TypeError("wrong type")) is False

    def test_runtime_error_with_transient_keywords(self):
        check = EventLoopNode._is_transient_error
        assert check(RuntimeError("Stream error: 429 rate limit")) is True
        assert check(RuntimeError("Stream error: 503")) is True
        assert check(RuntimeError("Stream error: connection reset")) is True
        assert check(RuntimeError("Stream error: timeout exceeded")) is True

    def test_runtime_error_without_transient_keywords(self):
        assert EventLoopNode._is_transient_error(RuntimeError("authentication failed")) is False
        assert EventLoopNode._is_transient_error(RuntimeError("invalid JSON in response")) is False


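# The classification rule the unit tests above imply, sketched: timeouts, connection
# failures, and OS-level errors are transient; ValueError/TypeError are permanent;
# RuntimeErrors are transient only when their message carries a transient keyword.
# The keyword list is an assumption beyond the cases asserted.
_TRANSIENT_KEYWORDS = ("429", "503", "rate limit", "connection reset", "timeout")


def _is_transient_sketch(exc: Exception) -> bool:
    if isinstance(exc, (TimeoutError, ConnectionError, OSError)):
        return True
    if isinstance(exc, RuntimeError):
        message = str(exc).lower()
        return any(keyword in message for keyword in _TRANSIENT_KEYWORDS)
    return False

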
# ===========================================================================
# Tool doom loop detection (ITEM 1)
# ===========================================================================


class TestFingerprintToolCalls:
    """Unit tests for _fingerprint_tool_calls()."""

    def test_basic_fingerprint(self):
        results = [
            {"tool_name": "search", "tool_input": {"q": "hello"}},
        ]
        fps = EventLoopNode._fingerprint_tool_calls(results)
        assert len(fps) == 1
        assert fps[0][0] == "search"
        # Args should be JSON with sort_keys
        assert fps[0][1] == '{"q": "hello"}'

    def test_order_sensitive(self):
        r1 = [
            {"tool_name": "search", "tool_input": {"q": "a"}},
            {"tool_name": "fetch", "tool_input": {"url": "b"}},
        ]
        r2 = [
            {"tool_name": "fetch", "tool_input": {"url": "b"}},
            {"tool_name": "search", "tool_input": {"q": "a"}},
        ]
        assert EventLoopNode._fingerprint_tool_calls(r1) != (
            EventLoopNode._fingerprint_tool_calls(r2)
        )

    def test_sort_keys_deterministic(self):
        r1 = [{"tool_name": "t", "tool_input": {"b": 2, "a": 1}}]
        r2 = [{"tool_name": "t", "tool_input": {"a": 1, "b": 2}}]
        assert EventLoopNode._fingerprint_tool_calls(r1) == EventLoopNode._fingerprint_tool_calls(
            r2
        )


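import json


# Fingerprinting sketch consistent with the asserts above: each call becomes a
# (tool_name, canonical JSON of the args with sort_keys=True) pair, in call order,
# so identical work is comparable while key order and call order stay significant.
def _fingerprint_sketch(results: list[dict]) -> list[tuple[str, str]]:
    return [
        (r["tool_name"], json.dumps(r["tool_input"], sort_keys=True))
        for r in results
    ]

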
class TestIsToolDoomLoop:
    """Unit tests for _is_tool_doom_loop()."""

    def test_below_threshold(self):
        node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
        fp = [("search", '{"q": "hello"}')]
        is_doom, _ = node._is_tool_doom_loop([fp, fp])
        assert is_doom is False

    def test_at_threshold_identical(self):
        node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
        fp = [("search", '{"q": "hello"}')]
        is_doom, desc = node._is_tool_doom_loop([fp, fp, fp])
        assert is_doom is True
        assert "search" in desc

    def test_different_args_no_doom(self):
        node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
        fp1 = [("search", '{"q": "a"}')]
        fp2 = [("search", '{"q": "b"}')]
        fp3 = [("search", '{"q": "c"}')]
        is_doom, _ = node._is_tool_doom_loop([fp1, fp2, fp3])
        assert is_doom is False

    def test_disabled_via_config(self):
        node = EventLoopNode(
            config=LoopConfig(tool_doom_loop_enabled=False),
        )
        fp = [("search", '{"q": "hello"}')]
        is_doom, _ = node._is_tool_doom_loop([fp, fp, fp])
        assert is_doom is False

    def test_empty_fingerprints_no_doom(self):
        node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
        is_doom, _ = node._is_tool_doom_loop([[], [], []])
        assert is_doom is False


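# Doom-loop check sketch matching the behavior above: with detection enabled and at
# least `threshold` turns recorded, flag when the last `threshold` non-empty
# fingerprint lists are identical. The description format is an assumption.
def _is_doom_loop_sketch(
    history: list[list[tuple[str, str]]], threshold: int = 3, enabled: bool = True
) -> tuple[bool, str]:
    if not enabled or len(history) < threshold:
        return (False, "")
    window = history[-threshold:]
    if not window[0] or any(turn != window[0] for turn in window):
        return (False, "")
    names = ", ".join(name for name, _ in window[0])
    return (True, f"identical calls to {names} for {threshold} consecutive turns")

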
class ToolRepeatLLM(LLMProvider):
    """LLM that produces identical tool calls across outer iterations.

    Alternates: even calls → tool call, odd calls → text (exits inner loop).
    This ensures each outer iteration = 2 LLM calls with 1 tool executed.
    After tool_turns outer iterations, always returns text.
    """

    def __init__(
        self,
        tool_name: str,
        tool_input: dict,
        tool_turns: int,
        final_text: str = "done",
    ):
        self.tool_name = tool_name
        self.tool_input = tool_input
        self.tool_turns = tool_turns
        self.final_text = final_text
        self._call_index = 0

    async def stream(self, messages, system="", tools=None, max_tokens=4096):
        idx = self._call_index
        self._call_index += 1
        # Which outer iteration we're in (2 calls per iteration)
        outer_iter = idx // 2
        is_tool_call = (idx % 2 == 0) and outer_iter < self.tool_turns
        if is_tool_call:
            yield ToolCallEvent(
                tool_use_id=f"call_{outer_iter}",
                tool_name=self.tool_name,
                tool_input=self.tool_input,
            )
            yield FinishEvent(
                stop_reason="tool_calls",
                input_tokens=10,
                output_tokens=5,
                model="mock",
            )
        else:
            # Unique text per call to avoid stall detection
            text = f"{self.final_text} (call {idx})"
            yield TextDeltaEvent(content=text, snapshot=text)
            yield FinishEvent(
                stop_reason="stop",
                input_tokens=10,
                output_tokens=5,
                model="mock",
            )

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(
            content="ok",
            model="mock",
            stop_reason="stop",
        )

    def complete_with_tools(
        self,
        messages,
        system,
        tools,
        tool_executor,
        **kwargs,
    ) -> LLMResponse:
        return LLMResponse(
            content="",
            model="mock",
            stop_reason="stop",
        )


class TestToolDoomLoopIntegration:
    """Integration tests for doom loop detection in execute().

    Uses ToolRepeatLLM: returns tool calls for first N calls, then text.
    Each outer iteration = 2 LLM calls (tool call + text exit for inner loop).
    logged_tool_calls accumulates across inner iterations.
    """

    @pytest.mark.asyncio
    async def test_doom_loop_injects_warning(
        self,
        runtime,
        node_spec,
        memory,
    ):
        """3 identical tool call turns should inject a warning."""
        node_spec.output_keys = []
        judge = AsyncMock(spec=JudgeProtocol)
        eval_count = 0

        async def judge_eval(*args, **kwargs):
            nonlocal eval_count
            eval_count += 1
            if eval_count >= 4:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY")

        judge.evaluate = judge_eval

        # 3 tool calls (6 LLM calls: tool+text each), then 1 text
        llm = ToolRepeatLLM("search", {"q": "hello"}, tool_turns=3)

        def tool_exec(tool_use: ToolUse) -> ToolResult:
            return ToolResult(
                tool_use_id=tool_use.id,
                content="result",
                is_error=False,
            )

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            tools=[Tool(name="search", description="s", parameters={})],
        )
        node = EventLoopNode(
            judge=judge,
            tool_executor=tool_exec,
            config=LoopConfig(
                max_iterations=10,
                tool_doom_loop_threshold=3,
            ),
        )
        result = await node.execute(ctx)
        assert result.success is True

    @pytest.mark.asyncio
    async def test_doom_loop_emits_event(
        self,
        runtime,
        node_spec,
        memory,
    ):
        """Doom loop should emit NODE_TOOL_DOOM_LOOP event."""
        node_spec.output_keys = []
        judge = AsyncMock(spec=JudgeProtocol)
        eval_count = 0

        async def judge_eval(*args, **kwargs):
            nonlocal eval_count
            eval_count += 1
            if eval_count >= 4:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY")

        judge.evaluate = judge_eval

        llm = ToolRepeatLLM("search", {"q": "hello"}, tool_turns=3)
        bus = EventBus()
        doom_events: list = []
        bus.subscribe(
            event_types=[EventType.NODE_TOOL_DOOM_LOOP],
            handler=lambda e: doom_events.append(e),
        )

        def tool_exec(tool_use: ToolUse) -> ToolResult:
            return ToolResult(
                tool_use_id=tool_use.id,
                content="result",
                is_error=False,
            )

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            tools=[Tool(name="search", description="s", parameters={})],
        )
        node = EventLoopNode(
            judge=judge,
            tool_executor=tool_exec,
            event_bus=bus,
            config=LoopConfig(
                max_iterations=10,
                tool_doom_loop_threshold=3,
            ),
        )
        result = await node.execute(ctx)
        assert result.success is True
        assert len(doom_events) == 1
        assert "search" in doom_events[0].data["description"]

    @pytest.mark.asyncio
    async def test_doom_loop_disabled(
        self,
        runtime,
        node_spec,
        memory,
    ):
        """Disabled doom loop should not trigger with identical calls."""
        node_spec.output_keys = []
        judge = AsyncMock(spec=JudgeProtocol)
        eval_count = 0

        async def judge_eval(*args, **kwargs):
            nonlocal eval_count
            eval_count += 1
            if eval_count >= 4:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY")

        judge.evaluate = judge_eval

        llm = ToolRepeatLLM("search", {"q": "hello"}, tool_turns=4)

        def tool_exec(tool_use: ToolUse) -> ToolResult:
            return ToolResult(
                tool_use_id=tool_use.id,
                content="result",
                is_error=False,
            )

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            tools=[Tool(name="search", description="s", parameters={})],
        )
        node = EventLoopNode(
            judge=judge,
            tool_executor=tool_exec,
            config=LoopConfig(
                max_iterations=10,
                tool_doom_loop_enabled=False,
            ),
        )
        result = await node.execute(ctx)
        assert result.success is True

    @pytest.mark.asyncio
    async def test_different_args_no_doom_loop(
        self,
        runtime,
        node_spec,
        memory,
    ):
        """Different tool args each turn should NOT trigger doom loop."""
        node_spec.output_keys = []
        judge = AsyncMock(spec=JudgeProtocol)
        eval_count = 0

        async def judge_eval(*args, **kwargs):
            nonlocal eval_count
            eval_count += 1
            if eval_count >= 4:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY")

        judge.evaluate = judge_eval

        # LLM that returns different args each call
        call_idx = 0

        class DiffArgsLLM(LLMProvider):
            async def stream(self, messages, **kwargs):
                nonlocal call_idx
                idx = call_idx
                call_idx += 1
                if idx < 3:
                    yield ToolCallEvent(
                        tool_use_id=f"c{idx}",
                        tool_name="search",
                        tool_input={"q": f"query_{idx}"},
                    )
                    yield FinishEvent(
                        stop_reason="tool_calls",
                        input_tokens=10,
                        output_tokens=5,
                        model="mock",
                    )
                else:
                    text = f"done (call {idx})"
                    yield TextDeltaEvent(
                        content=text,
                        snapshot=text,
                    )
                    yield FinishEvent(
                        stop_reason="stop",
                        input_tokens=10,
                        output_tokens=5,
                        model="mock",
                    )

            def complete(self, messages, **kwargs):
                return LLMResponse(
                    content="ok",
                    model="mock",
                    stop_reason="stop",
                )

            def complete_with_tools(
                self,
                messages,
                system,
                tools,
                tool_executor,
                **kw,
            ):
                return LLMResponse(
                    content="",
                    model="mock",
                    stop_reason="stop",
                )

        llm = DiffArgsLLM()

        def tool_exec(tool_use: ToolUse) -> ToolResult:
            return ToolResult(
                tool_use_id=tool_use.id,
                content="result",
                is_error=False,
            )

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            tools=[Tool(name="search", description="s", parameters={})],
        )
        node = EventLoopNode(
            judge=judge,
            tool_executor=tool_exec,
            config=LoopConfig(
                max_iterations=10,
                tool_doom_loop_threshold=3,
            ),
        )
        result = await node.execute(ctx)
        assert result.success is True

@@ -26,6 +26,7 @@ class DummyLLMProvider(LLMProvider):
        max_tokens: int = 1024,
        response_format: dict[str, object] | None = None,
        json_mode: bool = False,
        max_retries: int | None = None,
    ) -> LLMResponse:
        return LLMResponse(content=json.dumps({"result": "ok"}), model="dummy")

@@ -120,3 +121,146 @@ async def test_execution_stream_retention(tmp_path):

    await stream.stop()
    await storage.stop()


@pytest.mark.asyncio
async def test_shared_session_reuses_directory_and_memory(tmp_path):
    """When an async entry point uses resume_session_id, it should:
    1. Run in the same session directory as the primary execution
    2. Have access to the primary session's memory
    3. NOT overwrite the primary session's state.json
    """
    goal = Goal(
        id="test-goal",
        name="Test",
        description="Shared session test",
        success_criteria=[
            SuccessCriterion(
                id="result",
                description="Result present",
                metric="output_contains",
                target="result",
            )
        ],
        constraints=[],
    )

    node = NodeSpec(
        id="hello",
        name="Hello",
        description="Return a result",
        node_type="llm_generate",
        input_keys=["user_name"],
        output_keys=["result"],
        system_prompt='Return JSON: {"result": "ok"}',
    )

    graph = GraphSpec(
        id="test-graph",
        goal_id=goal.id,
        version="1.0.0",
        entry_node="hello",
        entry_points={"start": "hello"},
        terminal_nodes=["hello"],
        pause_nodes=[],
        nodes=[node],
        edges=[],
        default_model="dummy",
        max_tokens=10,
    )

    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    from framework.storage.session_store import SessionStore

    session_store = SessionStore(tmp_path)

    # Primary stream
    primary_stream = ExecutionStream(
        stream_id="primary",
        entry_spec=EntryPointSpec(
            id="primary",
            name="Primary",
            entry_node="hello",
            trigger_type="manual",
            isolation_level="shared",
        ),
        graph=graph,
        goal=goal,
        state_manager=SharedStateManager(),
        storage=storage,
        outcome_aggregator=OutcomeAggregator(goal, EventBus()),
        event_bus=None,
        llm=DummyLLMProvider(),
        tools=[],
        tool_executor=None,
        session_store=session_store,
    )

    await primary_stream.start()

    # Run primary execution — creates session directory and state.json
    primary_exec_id = await primary_stream.execute({"user_name": "alice"})
    primary_result = await primary_stream.wait_for_completion(primary_exec_id, timeout=5)
    assert primary_result is not None
    assert primary_result.success

    # Verify primary session's state.json exists and has the primary entry_point
    primary_state_path = tmp_path / "sessions" / primary_exec_id / "state.json"
    assert primary_state_path.exists()
    primary_state = json.loads(primary_state_path.read_text())
    assert primary_state["entry_point"] == "primary"

    # Async stream — simulates a webhook entry point sharing the session
    async_stream = ExecutionStream(
        stream_id="webhook",
        entry_spec=EntryPointSpec(
            id="webhook",
            name="Webhook",
            entry_node="hello",
            trigger_type="event",
            isolation_level="shared",
        ),
        graph=graph,
        goal=goal,
        state_manager=SharedStateManager(),
        storage=storage,
        outcome_aggregator=OutcomeAggregator(goal, EventBus()),
        event_bus=None,
        llm=DummyLLMProvider(),
        tools=[],
        tool_executor=None,
        session_store=session_store,
    )

    await async_stream.start()

    # Run async execution with resume_session_id pointing to primary session
    session_state = {
        "resume_session_id": primary_exec_id,
        "memory": {"rules": "star important emails"},
    }
    async_exec_id = await async_stream.execute({"event": "new_email"}, session_state=session_state)

    # Should reuse the primary session ID
    assert async_exec_id == primary_exec_id

    async_result = await async_stream.wait_for_completion(async_exec_id, timeout=5)
    assert async_result is not None
    assert async_result.success

    # State.json should NOT have been overwritten by the async execution
    # (it should still show the primary entry point)
    final_state = json.loads(primary_state_path.read_text())
    assert final_state["entry_point"] == "primary"

    # Verify only ONE session directory exists (not two)
    sessions_dir = tmp_path / "sessions"
    session_dirs = [d for d in sessions_dir.iterdir() if d.is_dir()]
    assert len(session_dirs) == 1
    assert session_dirs[0].name == primary_exec_id

    await primary_stream.stop()
    await async_stream.stop()
    await storage.stop()

@@ -143,6 +143,18 @@ class FakeEventBus:
    async def emit_node_loop_completed(self, **kwargs):
        self.events.append(("completed", kwargs))

    async def emit_edge_traversed(self, **kwargs):
        self.events.append(("edge_traversed", kwargs))

    async def emit_execution_paused(self, **kwargs):
        self.events.append(("execution_paused", kwargs))

    async def emit_execution_resumed(self, **kwargs):
        self.events.append(("execution_resumed", kwargs))

    async def emit_node_retry(self, **kwargs):
        self.events.append(("node_retry", kwargs))


@pytest.mark.asyncio
async def test_executor_emits_node_events():

@@ -201,15 +213,19 @@ async def test_executor_emits_node_events():
    assert result.success is True
    assert result.path == ["n1", "n2"]

-   # Should have 4 events: started/completed for n1, then started/completed for n2
-   assert len(event_bus.events) == 4
+   # Should have 5 events: started/completed for n1, edge_traversed, then started/completed for n2
+   assert len(event_bus.events) == 5
    assert event_bus.events[0] == ("started", {"stream_id": "test-stream", "node_id": "n1"})
    assert event_bus.events[1] == (
        "completed",
        {"stream_id": "test-stream", "node_id": "n1", "iterations": 1},
    )
-   assert event_bus.events[2] == ("started", {"stream_id": "test-stream", "node_id": "n2"})
-   assert event_bus.events[3] == (
+   assert event_bus.events[2] == (
+       "edge_traversed",
+       {"stream_id": "test-stream", "source_node": "n1", "target_node": "n2"},
+   )
+   assert event_bus.events[3] == ("started", {"stream_id": "test-stream", "node_id": "n2"})
+   assert event_bus.events[4] == (
        "completed",
        {"stream_id": "test-stream", "node_id": "n2", "iterations": 1},
    )

@@ -9,12 +9,18 @@ For live tests (requires API keys):
    OPENAI_API_KEY=sk-... pytest tests/test_litellm_provider.py -v -m live
"""

import asyncio
import os
import threading
import time
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch

import pytest

from framework.llm.anthropic import AnthropicProvider
-from framework.llm.litellm import LiteLLMProvider
-from framework.llm.provider import LLMProvider, Tool, ToolResult, ToolUse
+from framework.llm.litellm import LiteLLMProvider, _compute_retry_delay
+from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse


class TestLiteLLMProviderInit:

@@ -532,3 +538,291 @@ class TestJsonMode:
        messages = call_kwargs["messages"]
        assert messages[0]["role"] == "system"
        assert "Please respond with a valid JSON object" in messages[0]["content"]


class TestComputeRetryDelay:
    """Test _compute_retry_delay() header parsing and fallback logic."""

    def test_fallback_exponential_backoff(self):
        """No exception -> exponential backoff."""
        assert _compute_retry_delay(0) == 2  # 2 * 2^0
        assert _compute_retry_delay(1) == 4  # 2 * 2^1
        assert _compute_retry_delay(2) == 8  # 2 * 2^2
        assert _compute_retry_delay(3) == 16  # 2 * 2^3

    def test_max_delay_cap(self):
        """Backoff should be capped at RATE_LIMIT_MAX_DELAY."""
        # 2 * 2^10 = 2048, should be capped at 120
        assert _compute_retry_delay(10) == 120

    def test_custom_max_delay(self):
        """Custom max_delay should be respected."""
        assert _compute_retry_delay(5, max_delay=10) == 10

    def test_retry_after_ms_header(self):
        """retry-after-ms header should be parsed as milliseconds."""
        exc = _make_exception_with_headers({"retry-after-ms": "5000"})
        assert _compute_retry_delay(0, exception=exc) == 5.0

    def test_retry_after_ms_fractional(self):
        """retry-after-ms should handle fractional values."""
        exc = _make_exception_with_headers({"retry-after-ms": "1500"})
        assert _compute_retry_delay(0, exception=exc) == 1.5

    def test_retry_after_seconds_header(self):
        """retry-after header as seconds should be parsed."""
        exc = _make_exception_with_headers({"retry-after": "3"})
        assert _compute_retry_delay(0, exception=exc) == 3.0

    def test_retry_after_seconds_fractional(self):
        """retry-after header should handle fractional seconds."""
        exc = _make_exception_with_headers({"retry-after": "2.5"})
        assert _compute_retry_delay(0, exception=exc) == 2.5

    def test_retry_after_ms_takes_priority(self):
        """retry-after-ms should take priority over retry-after."""
        exc = _make_exception_with_headers(
            {
                "retry-after-ms": "2000",
                "retry-after": "10",
            }
        )
        assert _compute_retry_delay(0, exception=exc) == 2.0

    def test_retry_after_http_date(self):
        """retry-after as HTTP-date should be parsed."""
        from email.utils import format_datetime

        future = datetime.now(UTC) + timedelta(seconds=5)
        date_str = format_datetime(future, usegmt=True)
        exc = _make_exception_with_headers({"retry-after": date_str})
        delay = _compute_retry_delay(0, exception=exc)
        assert 3.0 <= delay <= 6.0  # within tolerance

    def test_exception_without_response(self):
        """Exception with response=None should fall back to exponential."""
        exc = Exception("test")
        exc.response = None  # type: ignore[attr-defined]
        assert _compute_retry_delay(0, exception=exc) == 2  # exponential fallback

    def test_exception_without_response_attr(self):
        """Exception without .response attr should fall back to exponential."""
        exc = ValueError("no response attr")
        assert _compute_retry_delay(0, exception=exc) == 2

    def test_negative_retry_after_clamped_to_zero(self):
        """Negative retry-after should be clamped to 0."""
        exc = _make_exception_with_headers({"retry-after": "-5"})
        assert _compute_retry_delay(0, exception=exc) == 0

    def test_negative_retry_after_ms_clamped_to_zero(self):
        """Negative retry-after-ms should be clamped to 0."""
        exc = _make_exception_with_headers({"retry-after-ms": "-1000"})
        assert _compute_retry_delay(0, exception=exc) == 0

    def test_invalid_retry_after_falls_back(self):
        """Non-numeric, non-date retry-after should fall back to exponential."""
        exc = _make_exception_with_headers({"retry-after": "not-a-number-or-date"})
        assert _compute_retry_delay(0, exception=exc) == 2  # exponential fallback

    def test_invalid_retry_after_ms_falls_back_to_retry_after(self):
        """Invalid retry-after-ms should fall through to retry-after."""
        exc = _make_exception_with_headers(
            {
                "retry-after-ms": "garbage",
                "retry-after": "7",
            }
        )
        assert _compute_retry_delay(0, exception=exc) == 7.0

    def test_retry_after_capped_at_max_delay(self):
        """Server-provided delay should be capped at max_delay."""
        exc = _make_exception_with_headers({"retry-after": "3600"})
        assert _compute_retry_delay(0, exception=exc) == 120  # capped

    def test_retry_after_ms_capped_at_max_delay(self):
        """Server-provided ms delay should be capped at max_delay."""
        exc = _make_exception_with_headers({"retry-after-ms": "300000"})  # 300s
        assert _compute_retry_delay(0, exception=exc) == 120  # capped


def _make_exception_with_headers(headers: dict[str, str]) -> BaseException:
    """Create a mock exception with response headers for testing."""
    exc = Exception("rate limited")
    response = MagicMock()
    response.headers = headers
    exc.response = response  # type: ignore[attr-defined]
    return exc
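
# --- Editor's sketch (not part of the diff) --------------------------------
# The tests above pin down _compute_retry_delay()'s contract. A minimal
# reference implementation consistent with them could look like the
# following, assuming a 2s backoff base and RATE_LIMIT_MAX_DELAY = 120;
# the real function lives in framework/llm/litellm.py and may differ.
def _compute_retry_delay_sketch(attempt, exception=None, max_delay=120.0):
    from datetime import UTC, datetime
    from email.utils import parsedate_to_datetime

    headers = getattr(getattr(exception, "response", None), "headers", None) or {}
    # retry-after-ms takes priority; invalid values fall through to retry-after.
    raw_ms = headers.get("retry-after-ms")
    if raw_ms is not None:
        try:
            return min(max(float(raw_ms) / 1000.0, 0.0), max_delay)
        except ValueError:
            pass
    # retry-after: numeric seconds, else an HTTP-date in the future.
    raw = headers.get("retry-after")
    if raw is not None:
        try:
            return min(max(float(raw), 0.0), max_delay)
        except ValueError:
            try:
                dt = parsedate_to_datetime(raw)
                return min(max((dt - datetime.now(UTC)).total_seconds(), 0.0), max_delay)
            except (TypeError, ValueError):
                pass
    # Fallback: exponential backoff 2 * 2^attempt, capped at max_delay.
    return min(2.0 * (2.0 ** attempt), max_delay)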


# ---------------------------------------------------------------------------
# Async LLM methods — non-blocking event loop tests
# ---------------------------------------------------------------------------


class TestAsyncComplete:
    """Test that acomplete/acomplete_with_tools don't block the event loop."""

    @pytest.mark.asyncio
    @patch("litellm.acompletion")
    async def test_acomplete_uses_acompletion(self, mock_acompletion):
        """acomplete() should call litellm.acompletion (async), not litellm.completion."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "async hello"
        mock_response.choices[0].message.tool_calls = None
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5

        # acompletion is async, so mock must return a coroutine
        async def async_return(*args, **kwargs):
            return mock_response

        mock_acompletion.side_effect = async_return

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        result = await provider.acomplete(
            messages=[{"role": "user", "content": "Hello"}],
            system="You are helpful.",
        )

        assert result.content == "async hello"
        assert result.model == "gpt-4o-mini"
        assert result.input_tokens == 10
        assert result.output_tokens == 5
        mock_acompletion.assert_called_once()

    @pytest.mark.asyncio
    @patch("litellm.acompletion")
    async def test_acomplete_does_not_block_event_loop(self, mock_acompletion):
        """Verify event loop stays responsive during acomplete()."""
        heartbeat_ticks = []

        async def heartbeat():
            start = time.monotonic()
            for _ in range(10):
                heartbeat_ticks.append(time.monotonic() - start)
                await asyncio.sleep(0.05)

        async def slow_acompletion(*args, **kwargs):
            # Simulate a 300ms LLM call — async, so event loop should stay free
            await asyncio.sleep(0.3)
            resp = MagicMock()
            resp.choices = [MagicMock()]
            resp.choices[0].message.content = "done"
            resp.choices[0].message.tool_calls = None
            resp.choices[0].finish_reason = "stop"
            resp.model = "gpt-4o-mini"
            resp.usage.prompt_tokens = 5
            resp.usage.completion_tokens = 3
            return resp

        mock_acompletion.side_effect = slow_acompletion

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")

        # Run heartbeat + acomplete concurrently
        _, result = await asyncio.gather(
            heartbeat(),
            provider.acomplete(
                messages=[{"role": "user", "content": "hi"}],
            ),
        )

        assert result.content == "done"
        # Heartbeat should have ticked multiple times during the 300ms LLM call
        # (if the event loop were blocked, we'd see 0-1 ticks)
        assert len(heartbeat_ticks) >= 3, (
            f"Event loop was blocked — only {len(heartbeat_ticks)} heartbeat ticks"
        )

    @pytest.mark.asyncio
    @patch("litellm.acompletion")
    async def test_acomplete_with_tools_uses_acompletion(self, mock_acompletion):
        """acomplete_with_tools() should use litellm.acompletion."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "tool result"
        mock_response.choices[0].message.tool_calls = None
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5

        async def async_return(*args, **kwargs):
            return mock_response

        mock_acompletion.side_effect = async_return

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        tools = [
            Tool(
                name="search",
                description="Search the web",
                parameters={"properties": {"q": {"type": "string"}}, "required": ["q"]},
            )
        ]

        result = await provider.acomplete_with_tools(
            messages=[{"role": "user", "content": "Search for cats"}],
            system="You are helpful.",
            tools=tools,
            tool_executor=lambda tu: ToolResult(tool_use_id=tu.id, content="cats"),
        )

        assert result.content == "tool result"
        mock_acompletion.assert_called_once()

    @pytest.mark.asyncio
    async def test_mock_provider_acomplete(self):
        """MockLLMProvider.acomplete() should work without blocking."""
        from framework.llm.mock import MockLLMProvider

        provider = MockLLMProvider()
        result = await provider.acomplete(
            messages=[{"role": "user", "content": "test"}],
            system="Be helpful.",
        )

        assert result.content  # Should have some mock content
        assert result.model == "mock-model"

    @pytest.mark.asyncio
    async def test_base_provider_acomplete_offloads_to_executor(self):
        """Base LLMProvider.acomplete() should offload sync complete() to thread pool."""
        call_thread_ids = []

        class SlowSyncProvider(LLMProvider):
            def complete(
                self,
                messages,
                system="",
                tools=None,
                max_tokens=1024,
                response_format=None,
                json_mode=False,
                max_retries=None,
            ):
                call_thread_ids.append(threading.current_thread().ident)
                time.sleep(0.1)  # Sync blocking
                return LLMResponse(content="sync done", model="slow")

            def complete_with_tools(
                self, messages, system, tools, tool_executor, max_iterations=10
            ):
                return LLMResponse(content="sync tools done", model="slow")

        provider = SlowSyncProvider()
        main_thread_id = threading.current_thread().ident

        result = await provider.acomplete(
            messages=[{"role": "user", "content": "hi"}],
        )

        assert result.content == "sync done"
        # The sync complete() should have run on a different thread
        assert call_thread_ids[0] != main_thread_id, (
            "Base acomplete() should offload sync complete() to a thread pool"
        )

@@ -143,6 +143,7 @@ def _has_api_key(env_var: str) -> bool:
# ---------------------------------------------------------------------------
# Real API tests — text streaming
# ---------------------------------------------------------------------------
@pytest.mark.skip(reason="Requires valid live API keys — run manually")
class TestRealAPITextStreaming:
    """Stream a simple text response from each provider and dump events."""

@@ -204,6 +205,7 @@ class TestRealAPITextStreaming:
# ---------------------------------------------------------------------------
# Real API tests — tool call streaming
# ---------------------------------------------------------------------------
@pytest.mark.skip(reason="Requires valid live API keys — run manually")
class TestRealAPIToolCallStreaming:
    """Stream a tool call response from each provider and dump events."""


@@ -35,6 +35,7 @@ class MockLLMProvider(LLMProvider):
        max_tokens=1024,
        response_format=None,
        json_mode=False,
        max_retries=None,
    ):
        self.complete_calls.append(
            {

@@ -0,0 +1,280 @@
"""Tests for phase-aware compaction in continuous conversation mode.

Validates:
- Phase tags persist through storage roundtrip
- Transition markers survive compaction
- Current phase messages protected during compaction
- Older phase tool results pruned first
- Phase metadata fields have safe defaults
"""

from __future__ import annotations

import pytest

from framework.graph.conversation import Message, NodeConversation


class TestPhaseMetadata:
    """Phase metadata on Message dataclass."""

    def test_defaults(self):
        msg = Message(seq=0, role="user", content="hello")
        assert msg.phase_id is None
        assert msg.is_transition_marker is False

    def test_set_phase(self):
        msg = Message(seq=0, role="user", content="hello", phase_id="research")
        assert msg.phase_id == "research"

    def test_transition_marker(self):
        msg = Message(
            seq=0,
            role="user",
            content="PHASE TRANSITION",
            is_transition_marker=True,
            phase_id="report",
        )
        assert msg.is_transition_marker is True
        assert msg.phase_id == "report"

    def test_storage_roundtrip(self):
        """Phase metadata should survive to_storage_dict → from_storage_dict."""
        msg = Message(
            seq=5,
            role="user",
            content="transition",
            phase_id="review",
            is_transition_marker=True,
        )
        d = msg.to_storage_dict()
        assert d["phase_id"] == "review"
        assert d["is_transition_marker"] is True

        restored = Message.from_storage_dict(d)
        assert restored.phase_id == "review"
        assert restored.is_transition_marker is True

    def test_storage_roundtrip_no_phase(self):
        """Messages without phase metadata should roundtrip cleanly."""
        msg = Message(seq=0, role="assistant", content="hello")
        d = msg.to_storage_dict()
        assert "phase_id" not in d
        assert "is_transition_marker" not in d

        restored = Message.from_storage_dict(d)
        assert restored.phase_id is None
        assert restored.is_transition_marker is False

    def test_to_llm_dict_no_metadata(self):
        """Phase metadata should NOT appear in LLM-facing dicts."""
        msg = Message(
            seq=0,
            role="user",
            content="hello",
            phase_id="research",
            is_transition_marker=True,
        )
        d = msg.to_llm_dict()
        assert "phase_id" not in d
        assert "is_transition_marker" not in d
        assert d == {"role": "user", "content": "hello"}


class TestPhaseStamping:
    """Messages are stamped with current phase."""

    @pytest.mark.asyncio
    async def test_messages_stamped_with_phase(self):
        conv = NodeConversation(system_prompt="test")
        conv.set_current_phase("research")

        msg1 = await conv.add_user_message("search for X")
        msg2 = await conv.add_assistant_message("Found it.")

        assert msg1.phase_id == "research"
        assert msg2.phase_id == "research"

    @pytest.mark.asyncio
    async def test_phase_changes_stamp(self):
        conv = NodeConversation(system_prompt="test")
        conv.set_current_phase("research")

        msg1 = await conv.add_user_message("research msg")

        conv.set_current_phase("report")
        msg2 = await conv.add_user_message("report msg")

        assert msg1.phase_id == "research"
        assert msg2.phase_id == "report"

    @pytest.mark.asyncio
    async def test_no_phase_no_stamp(self):
        conv = NodeConversation(system_prompt="test")
        msg = await conv.add_user_message("no phase")
        assert msg.phase_id is None

    @pytest.mark.asyncio
    async def test_transition_marker_flag(self):
        conv = NodeConversation(system_prompt="test")
        conv.set_current_phase("report")

        msg = await conv.add_user_message(
            "PHASE TRANSITION: Research → Report",
            is_transition_marker=True,
        )
        assert msg.is_transition_marker is True
        assert msg.phase_id == "report"

    @pytest.mark.asyncio
    async def test_tool_result_stamped(self):
        conv = NodeConversation(system_prompt="test")
        conv.set_current_phase("research")

        msg = await conv.add_tool_result("call_1", "tool output here")
        assert msg.phase_id == "research"


class TestPhaseAwareCompaction:
    """prune_old_tool_results protects current phase and transition markers."""

    @pytest.mark.asyncio
    async def test_transition_marker_survives_compaction(self):
        """Transition markers should never be pruned."""
        conv = NodeConversation(system_prompt="test")

        # Old phase with a big tool result
        conv.set_current_phase("research")
        await conv.add_assistant_message(
            "calling tool",
            tool_calls=[
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "search", "arguments": "{}"},
                }
            ],
        )
        await conv.add_tool_result("call_1", "x" * 20000)  # big tool result

        # Transition marker
        await conv.add_user_message(
            "PHASE TRANSITION: Research → Report",
            is_transition_marker=True,
        )

        # New phase
        conv.set_current_phase("report")
        await conv.add_assistant_message(
            "calling another tool",
            tool_calls=[
                {
                    "id": "call_2",
                    "type": "function",
                    "function": {"name": "save", "arguments": "{}"},
                }
            ],
        )
        await conv.add_tool_result("call_2", "y" * 200)

        pruned = await conv.prune_old_tool_results(protect_tokens=0, min_prune_tokens=100)
        assert pruned >= 1

        # Transition marker should still be intact
        marker_msgs = [m for m in conv.messages if m.is_transition_marker]
        assert len(marker_msgs) == 1
        assert "PHASE TRANSITION" in marker_msgs[0].content

    @pytest.mark.asyncio
    async def test_current_phase_protected(self):
        """Tool results in the current phase should not be pruned."""
        conv = NodeConversation(system_prompt="test")

        # Old phase
        conv.set_current_phase("research")
        await conv.add_assistant_message(
            "tool call",
            tool_calls=[
                {"id": "c1", "type": "function", "function": {"name": "s", "arguments": "{}"}}
            ],
        )
        await conv.add_tool_result("c1", "old_data " * 5000)

        # Current phase
        conv.set_current_phase("report")
        await conv.add_assistant_message(
            "tool call",
            tool_calls=[
                {"id": "c2", "type": "function", "function": {"name": "s", "arguments": "{}"}}
            ],
        )
        await conv.add_tool_result("c2", "current_data " * 5000)

        await conv.prune_old_tool_results(protect_tokens=0, min_prune_tokens=100)

        # Old phase's tool result should be pruned
        msgs = conv.messages
        old_tool = [m for m in msgs if m.role == "tool" and m.phase_id == "research"]
        assert len(old_tool) == 1
        assert old_tool[0].content.startswith("[Pruned tool result")

        # Current phase's tool result should be intact
        current_tool = [m for m in msgs if m.role == "tool" and m.phase_id == "report"]
        assert len(current_tool) == 1
        assert "current_data" in current_tool[0].content

    @pytest.mark.asyncio
    async def test_no_phase_metadata_works_normally(self):
        """Without phase metadata, compaction works as before (no regression)."""
        conv = NodeConversation(system_prompt="test")

        # No phase set — messages have phase_id=None
        await conv.add_assistant_message(
            "tool call",
            tool_calls=[
                {"id": "c1", "type": "function", "function": {"name": "s", "arguments": "{}"}}
            ],
        )
        await conv.add_tool_result("c1", "data " * 5000)  # ~6250 tokens

        await conv.add_assistant_message(
            "another tool call",
            tool_calls=[
                {"id": "c2", "type": "function", "function": {"name": "s", "arguments": "{}"}}
            ],
        )
        await conv.add_tool_result("c2", "more " * 100)  # ~125 tokens

        # protect_tokens=100: c2 (~125 tokens) fills the budget,
        # c1 (~6250 tokens) becomes pruneable
        pruned = await conv.prune_old_tool_results(protect_tokens=100, min_prune_tokens=100)
        assert pruned >= 1

    @pytest.mark.asyncio
    async def test_pruned_message_preserves_phase_metadata(self):
        """Pruned messages should keep their phase_id."""
        conv = NodeConversation(system_prompt="test")
        conv.set_current_phase("research")

        await conv.add_assistant_message(
            "tool call",
            tool_calls=[
                {"id": "c1", "type": "function", "function": {"name": "s", "arguments": "{}"}}
            ],
        )
        await conv.add_tool_result("c1", "data " * 5000)

        # Switch to new phase so research messages become pruneable
        conv.set_current_phase("report")
        await conv.add_assistant_message(
            "recent",
            tool_calls=[
                {"id": "c2", "type": "function", "function": {"name": "s", "arguments": "{}"}}
            ],
        )
        await conv.add_tool_result("c2", "x" * 200)

        await conv.prune_old_tool_results(protect_tokens=0, min_prune_tokens=100)

        pruned_msg = [m for m in conv.messages if m.content.startswith("[Pruned")][0]
        assert pruned_msg.phase_id == "research"

@@ -0,0 +1,172 @@
# Agent Runtime

Unified execution system for all Hive agents. Every agent — single-entry or multi-entry, headless or TUI — runs through the same runtime stack.

## Topology

```
   AgentRunner.load(agent_path)
               |
          AgentRunner
     (factory + public API)
               |
     _setup_agent_runtime()
               |
          AgentRuntime
   (lifecycle + orchestration)
       /       |       \
 Stream A   Stream B   Stream C   ← one per entry point
     |          |          |
GraphExecutor GraphExecutor GraphExecutor
     |          |          |
  Node → Node → Node  (graph traversal)
```

Single-entry agents get a `"default"` entry point automatically. There is no separate code path.

## Components

| Component | File | Role |
| --- | --- | --- |
| `AgentRunner` | `runner/runner.py` | Load agents, configure tools/LLM, expose high-level API |
| `AgentRuntime` | `runtime/agent_runtime.py` | Lifecycle management, entry point routing, event bus |
| `ExecutionStream` | `runtime/execution_stream.py` | Per-entry-point execution queue, session persistence |
| `GraphExecutor` | `graph/executor.py` | Node traversal, tool dispatch, checkpointing |
| `EventBus` | `runtime/event_bus.py` | Pub/sub for execution events (streaming, I/O) |
| `SharedStateManager` | `runtime/shared_state.py` | Cross-stream state with isolation levels |
| `OutcomeAggregator` | `runtime/outcome_aggregator.py` | Goal progress tracking across streams |
| `SessionStore` | `storage/session_store.py` | Session state persistence (`sessions/{id}/state.json`) |

## Programming Interface

### AgentRunner (high-level)

```python
from framework.runner import AgentRunner

# Load and run
runner = AgentRunner.load("exports/my_agent", model="anthropic/claude-sonnet-4-20250514")
result = await runner.run({"query": "hello"})

# Resume from paused session
result = await runner.run({"query": "continue"}, session_state=saved_state)

# Lifecycle
await runner.start()                           # Start the runtime
await runner.stop()                            # Stop the runtime
exec_id = await runner.trigger("default", {})  # Non-blocking trigger
progress = await runner.get_goal_progress()    # Goal evaluation
entry_points = runner.get_entry_points()       # List entry points

# Context manager
async with AgentRunner.load("exports/my_agent") as runner:
    result = await runner.run({"query": "hello"})

# Cleanup
runner.cleanup()              # Synchronous
await runner.cleanup_async()  # Asynchronous
```

### AgentRuntime (lower-level)

```python
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
from framework.runtime.execution_stream import EntryPointSpec

# Create runtime with entry points
runtime = create_agent_runtime(
    graph=graph,
    goal=goal,
    storage_path=Path("~/.hive/agents/my_agent"),
    entry_points=[
        EntryPointSpec(id="default", name="Default", entry_node="start", trigger_type="manual"),
    ],
    llm=llm,
    tools=tools,
    tool_executor=tool_executor,
    checkpoint_config=checkpoint_config,
)

# Lifecycle
await runtime.start()
await runtime.stop()

# Execution
exec_id = await runtime.trigger("default", {"query": "hello"})                # Non-blocking
result = await runtime.trigger_and_wait("default", {"query": "hello"})        # Blocking
result = await runtime.trigger_and_wait("default", {}, session_state=state)   # Resume

# Client-facing node I/O
await runtime.inject_input(node_id="chat", content="user response")

# Events
sub_id = runtime.subscribe_to_events(
    event_types=[EventType.CLIENT_OUTPUT_DELTA],
    handler=my_handler,
)
runtime.unsubscribe_from_events(sub_id)

# Inspection
runtime.is_running     # bool
runtime.event_bus      # EventBus
runtime.state_manager  # SharedStateManager
runtime.get_stats()    # Runtime statistics
```

## Execution Flow

1. `AgentRunner.run()` calls `AgentRuntime.trigger_and_wait()`
2. `AgentRuntime` routes to the `ExecutionStream` for the entry point
3. `ExecutionStream` creates a `GraphExecutor` and calls `execute()`
4. `GraphExecutor` traverses nodes, dispatches tools, manages checkpoints
5. `ExecutionResult` flows back up through the stack
6. `ExecutionStream` writes session state to disk

## Session Resume

All execution paths support session resume:

```python
# First run (agent pauses at a client-facing node)
result = await runner.run({"query": "start task"})
# result.paused_at = "review-node"
# result.session_state = {"memory": {...}, "paused_at": "review-node", ...}

# Resume
result = await runner.run({"input": "approved"}, session_state=result.session_state)
```

Session state flows: `AgentRunner.run()` → `AgentRuntime.trigger_and_wait()` → `ExecutionStream.execute()` → `GraphExecutor.execute()`.

Checkpoints are saved at node boundaries (`sessions/{id}/checkpoints/`) for crash recovery.
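
For example, a recovery path can pick the most recent snapshot for a session like this (a minimal sketch: only the `sessions/{id}/checkpoints/` directory layout above is documented, the file naming inside it is an assumption):

```python
from pathlib import Path

def latest_checkpoint(agent_dir: Path, session_id: str) -> Path | None:
    # sessions/{id}/checkpoints/ per the layout above; snapshot naming is hypothetical.
    ckpt_dir = agent_dir / "sessions" / session_id / "checkpoints"
    if not ckpt_dir.is_dir():
        return None
    snapshots = sorted(ckpt_dir.iterdir(), key=lambda p: p.stat().st_mtime)
    return snapshots[-1] if snapshots else None
```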

## Event Bus

The `EventBus` provides real-time execution visibility:

| Event | When |
| --- | --- |
| `NODE_STARTED` | Node begins execution |
| `NODE_COMPLETED` | Node finishes |
| `TOOL_CALL_STARTED` | Tool invocation begins |
| `TOOL_CALL_COMPLETED` | Tool invocation finishes |
| `CLIENT_OUTPUT_DELTA` | Agent streams text to user |
| `CLIENT_INPUT_REQUESTED` | Agent needs user input |
| `EXECUTION_COMPLETED` | Full execution finishes |

In headless mode, `AgentRunner` subscribes to `CLIENT_OUTPUT_DELTA` and `CLIENT_INPUT_REQUESTED` to print output and read stdin. In TUI mode, `AdenTUI` subscribes to route events to UI widgets.
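
A headless client along those lines might wire up a handler like this (a sketch built only from the calls shown above; the `EventType` import path and the event payload attribute are assumptions):

```python
from framework.runtime.event_bus import EventType  # import path assumed

def print_delta(event):
    # Payload shape is assumed: the streamed text chunk on the event.
    print(getattr(event, "content", ""), end="", flush=True)

sub_id = runtime.subscribe_to_events(
    event_types=[EventType.CLIENT_OUTPUT_DELTA],
    handler=print_delta,
)
# ... trigger executions ...
runtime.unsubscribe_from_events(sub_id)
```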

## Storage Layout

```
~/.hive/agents/{agent_name}/
  sessions/
    session_YYYYMMDD_HHMMSS_{uuid}/
      state.json          # Session state (status, memory, progress)
      checkpoints/        # Node-boundary snapshots
      logs/
        summary.json      # Execution summary
        details.jsonl     # Detailed event log
        tool_logs.jsonl   # Tool call log
  runtime_logs/           # Cross-session runtime logs
```
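
Session state can be inspected directly from this layout, e.g. (a sketch; the `state.json` schema beyond the fields annotated above is an assumption):

```python
import json
from pathlib import Path

agent_dir = Path.home() / ".hive" / "agents" / "my_agent"
for session_dir in (agent_dir / "sessions").iterdir():
    state = json.loads((session_dir / "state.json").read_text())
    # "status" and "memory" follow the annotations in the tree above.
    print(session_dir.name, state.get("status"), state.get("memory"))
```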

@@ -0,0 +1,214 @@
# Antigravity IDE Setup

Use the Hive agent framework (MCP servers and skills) inside [Antigravity IDE](https://antigravity.google/) (Google’s AI IDE).

---

## Quick start (3 steps)

**Repo root** = the folder that contains `core/`, `tools/`, and `.agent/` (where you cloned the project).

1. **Open a terminal** and go to the hive repo root (e.g. `cd ~/hive`).
2. **Run the setup script** (use `./` so the script runs from this repo; don't use `/scripts/...`):
   ```bash
   ./scripts/setup-antigravity-mcp.sh
   ```
3. **Restart Antigravity IDE.** You should see **agent-builder** and **tools** as available MCP servers.

> **Important:** Always restart/refresh Antigravity IDE after running the setup script or making any changes to MCP configuration. The IDE only loads MCP servers on startup.

Done. For details, prerequisites, and troubleshooting, read on.

---

## What you get after setup

- **agent-builder** – Create and manage agents (goals, nodes, edges).
- **tools** – File operations, web search, and other agent tools.
- **Skills** – Guided docs for building and testing agents (in `.agent/skills/` or `.claude/skills/`).

---

## Prerequisites

- [Antigravity IDE](https://antigravity.google/) installed.
- **Python 3.11+** and project dependencies. If you haven’t set up the repo yet, from repo root run:
  ```bash
  ./quickstart.sh
  ```
- **MCP server dependencies** (one-time). From repo root:
  ```bash
  cd core && ./setup_mcp.sh
  ```

---

## Full setup (step by step)

### Step 1: Install MCP dependencies (one-time)

From the **repo root**:

```bash
cd core
./setup_mcp.sh
```

This installs the framework and MCP packages and checks that the server can start.

### Step 2: Register MCP servers with Antigravity

Antigravity reads MCP config from your **user config file** (`~/.gemini/antigravity/mcp_config.json`), not from the project. The easiest way is to run the setup script from the **hive repo folder**:

```bash
./scripts/setup-antigravity-mcp.sh
```

The script finds the repo root and writes `~/.gemini/antigravity/mcp_config.json` with the right paths, so you never have to edit any paths by hand.

> **Important:** Always restart/refresh Antigravity IDE after running the setup script. MCP servers are only loaded on IDE startup.

The **agent-builder** and **tools** servers should show up after restart.

**Using Claude Code instead?** Run:

```bash
./scripts/setup-antigravity-mcp.sh --claude
```

That writes `~/.claude/mcp.json` as well.

**Prefer to do it manually?** See [Manual MCP config](#manual-mcp-config-template) below. You’ll create `~/.gemini/antigravity/mcp_config.json` (or `~/.claude/mcp.json`) with absolute paths to your repo’s `core` and `tools` folders.

### Step 3: Use skills

Skills are guides (workflow, building, testing) that live in `.agent/skills/` (symlinks into `.claude/skills/`). If Antigravity doesn’t show a “skills” UI, open those folders in the project and use the files as reference while you work with the MCP tools.

| Skill | What it's for |
|-------|----------------|
| **hive** | End-to-end workflow for building and testing agents |
| **hive-concepts** | Core ideas for goal-driven agents |
| **hive-create** | Step-by-step agent construction |
| **hive-patterns** | Patterns and best practices |
| **hive-test** | Goal-based evaluation and testing |
| **hive-credentials** | Set up and manage agent credentials |

---

## What’s in the repo (`.agent/`)

```
.agent/
├── mcp_config.json   # Template for MCP servers (agent-builder, tools)
└── skills/           # Symlinks to .claude/skills/
```

The **setup script** writes your **user** config (`~/.gemini/antigravity/mcp_config.json`) using paths from **this repo**. The file in `.agent/` is the template; Antigravity itself uses the file in your home directory.

---

## Troubleshooting

**MCP servers don’t connect**

- Run the setup script again from the hive repo root: `./scripts/setup-antigravity-mcp.sh`, then restart Antigravity.
- Make sure Python and deps are installed: from repo root run `./quickstart.sh`.
- Check that the servers can start: from repo root run
  `cd core && uv run -m framework.mcp.agent_builder_server` (Ctrl+C to stop), and in another terminal
  `cd tools && uv run mcp_server.py --stdio` (Ctrl+C to stop).
  If those fail, fix the errors first (e.g. install deps with `uv sync`).

**"Module not found" or import errors**

- Open the **repo root** as the project in the IDE (the folder that has `core/` and `tools/`).
- If you edited `~/.gemini/antigravity/mcp_config.json` by hand, make sure `--directory` paths are **absolute** (e.g. `/Users/you/hive/core` and `/Users/you/hive/tools`).

**Skills don’t show up in the UI**

- Antigravity may not have a dedicated “skills” panel. Use the files in `.claude/skills/` or `.agent/skills/` as docs; the MCP tools (agent-builder, tools) still work.

---

## Verification prompt (optional)

Paste this into Antigravity to check that MCP and skills are set up. It doesn’t use your machine’s paths; anyone can use it.

```
Check the Hive + Antigravity integration:

1. MCP: List available MCP servers/tools. Confirm that "agent-builder" and "tools" (or equivalent) are connected. If not, tell the user to run ./scripts/setup-antigravity-mcp.sh from the hive repo root, then restart Antigravity (see docs/antigravity-setup.md).

2. Skills: Confirm that the project has .agent/skills/ (or .claude/skills/) with: hive, hive-concepts, hive-create, hive-patterns, hive-test, hive-credentials.

3. Result: Reply with PASS (MCP + skills OK), PARTIAL (only skills or only MCP), or FAIL (neither), and one line on what to fix if not PASS.
```

If you get **PARTIAL** (e.g. MCP not connected), run `./scripts/setup-antigravity-mcp.sh` from the repo root and restart Antigravity.

---

## Manual MCP config template

Use this only if you don’t want to run the setup script. Replace `/path/to/hive` with your actual repo root (e.g. the output of `pwd` when you’re in the hive folder).

Save as `~/.gemini/antigravity/mcp_config.json` (Antigravity) or `~/.claude/mcp.json` (Claude Code), then **restart the IDE** to load the new configuration.

```json
{
  "mcpServers": {
    "agent-builder": {
      "command": "uv",
      "args": ["run", "--directory", "/path/to/hive/core", "-m", "framework.mcp.agent_builder_server"],
      "disabled": false
    },
    "tools": {
      "command": "uv",
      "args": ["run", "--directory", "/path/to/hive/tools", "mcp_server.py", "--stdio"],
      "disabled": false
    }
  }
}
```

Make sure `uv` is installed and available in your PATH. Note: use `--directory` in `args` instead of `cwd` for Antigravity compatibility.

---

## Verify from the command line (optional)

From the **repo root**:

**Check that config and skills exist**

```bash
test -f .agent/mcp_config.json && echo "OK: mcp_config.json" || echo "MISSING"
for s in hive hive-concepts hive-create hive-patterns hive-test hive-credentials; do
  test -L .agent/skills/$s && test -d .agent/skills/$s && echo "OK: $s" || echo "BROKEN: $s"
done
```

**Check that the config is valid JSON**

```bash
python3 -c "import json; json.load(open('.agent/mcp_config.json')); print('OK: valid JSON')"
```

**Test that MCP servers start** (two terminals)

```bash
# Terminal 1
cd core && uv run -m framework.mcp.agent_builder_server

# Terminal 2
cd tools && uv run mcp_server.py --stdio
```

If both start without errors, the config is fine.

---

## See also

- [Cursor IDE support](../README.md#cursor-ide-support) – Same MCP servers and skills for Cursor
- [MCP Integration Guide](../core/MCP_INTEGRATION_GUIDE.md) – How the framework MCP works
- [Environment setup](../ENVIRONMENT_SETUP.md) – Repo and Python setup