Compare commits
169 Commits
| SHA1 |
|---|
| a03b378e9b | |||
| 893053ede7 | |||
| 596ec6fec5 | |||
| 5863b83172 | |||
| 20c92b197a | |||
| ec9c6b4666 | |||
| 8a73e5c119 | |||
| 717f0eee9a | |||
| 09fb47f089 | |||
| b46d943e71 | |||
| 2920b5ab01 | |||
| 81ad0467b0 | |||
| 115ca55ea0 | |||
| f2814a26e6 | |||
| 4d309950b0 | |||
| 39216a4c12 | |||
| c7fa621aeb | |||
| 8c3ad3d70a | |||
| 9eb3fc6285 | |||
| e95f7e7339 | |||
| d949551399 | |||
| a7dbd85ed4 | |||
| 1f288dab1c | |||
| 021754d941 | |||
| 7412904fbf | |||
| 5f3e9379a3 | |||
| 0e565d6cea | |||
| 67b249dcd5 | |||
| bbf1c8c790 | |||
| 44a8b453b5 | |||
| 26511fe962 | |||
| ce5893216a | |||
| 4e821e4dbf | |||
| d11e97de59 | |||
| 4b10d3e360 | |||
| e04479930f | |||
| 8a8c4cc3f5 | |||
| 1e06ff611e | |||
| 1edc7bb9c7 | |||
| 7b1e0af155 | |||
| 7b15616e29 | |||
| 99ed00fd02 | |||
| f7af5f9ee8 | |||
| e5bcc8005f | |||
| 352d285212 | |||
| 3ef60f9d14 | |||
| a103312127 | |||
| 3d0bba4167 | |||
| 3df718cc14 | |||
| c7497a180e | |||
| 3f39039a21 | |||
| 88fbd90fcc | |||
| e0bf09dd78 | |||
| 3e158b07af | |||
| 5319ed7ee1 | |||
| 978904d2a4 | |||
| 4d876ecc54 | |||
| ba327d0b9e | |||
| b69cf3523c | |||
| 4d8c8e9308 | |||
| b70885934c | |||
| 722b087fc0 | |||
| 0c7ea272db | |||
| 5e4f322fc0 | |||
| c02e45f1aa | |||
| a7217f138c | |||
| 3502f25048 | |||
| 93c026fe31 | |||
| e515977b96 | |||
| 045490a097 | |||
| b25903fb7f | |||
| acf4bd5152 | |||
| 1f5711e1a1 | |||
| ca2dd90313 | |||
| 21e07f3b65 | |||
| e8a06ddd34 | |||
| 34cc09904f | |||
| f6bba8b62f | |||
| d241ad60f8 | |||
| 5a3fcf9a8a | |||
| 1f8a47203f | |||
| 7240090274 | |||
| 2e6a47c2df | |||
| 7f5ecd7913 | |||
| 105b98b113 | |||
| 114e65ab41 | |||
| 0fc13a5cc3 | |||
| e651799e9e | |||
| fcd3e514de | |||
| 7ab41de3a2 | |||
| 58e023f277 | |||
| a98f2d5b86 | |||
| eca43231c0 | |||
| 6763077887 | |||
| f85ff8a2f8 | |||
| 1a5c3480e6 | |||
| 69a7fe7b92 | |||
| a5418d760f | |||
| 0deeb87c63 | |||
| d1d5f49c5a | |||
| 917e23ccc8 | |||
| 988922304f | |||
| ab2bd726c3 | |||
| 713fefb163 | |||
| 83140a1398 | |||
| cafa6dd930 | |||
| 82e1af1a7a | |||
| 30c3dc9205 | |||
| 9a3c6703e1 | |||
| e26468aa19 | |||
| fe14992696 | |||
| d0775b95c6 | |||
| 96121b5757 | |||
| 11c003c48d | |||
| fbe72c58ae | |||
| 816156e87f | |||
| 7bceab3cea | |||
| 83d7f56728 | |||
| 76deba2a6a | |||
| d9d048b9e3 | |||
| 930f417729 | |||
| 8e214d06c1 | |||
| 63e0348963 | |||
| b46a5f0247 | |||
| 79dfd90068 | |||
| f9d5c7c751 | |||
| 8958fb2d88 | |||
| 3c51f2ac36 | |||
| 170a0918f7 | |||
| e3da3b619c | |||
| 6e32513b79 | |||
| 520e1963ee | |||
| 843b9b55e2 | |||
| ccd305ff96 | |||
| 3bd0d1e48c | |||
| d9bfa8e675 | |||
| 27746147e2 | |||
| 3a0b642980 | |||
| 8c0241f087 | |||
| 958d016174 | |||
| 913d318ada | |||
| 8212920cb7 | |||
| 6414be7bd4 | |||
| ac62a82d08 | |||
| a670548a57 | |||
| c4a7463f9d | |||
| edf0ac5270 | |||
| 8ff6b76f37 | |||
| c9f9eb365c | |||
| 7a17c115d3 | |||
| 9a2a11055f | |||
| f21aecd91c | |||
| 4aef73c1d7 | |||
| 9df147b450 | |||
| b71b4b0fc2 | |||
| 1bd2510c52 | |||
| 28b81092f9 | |||
| 4b9a3abba6 | |||
| 0c76b6dcb1 | |||
| 090a85b41b | |||
| 992d573573 | |||
| 9e768e660b | |||
| 26b9ed362e | |||
| 976ae75fde | |||
| d63dd021ab | |||
| 697ba89314 | |||
| 373ad77008 | |||
| 7fae57f311 | |||
| 1f653969a9 |
@@ -0,0 +1,9 @@
{
  "mcpServers": {
    "agent-builder": {
      "command": "uv",
      "args": ["run", "--directory", "core", "-m", "framework.mcp.agent_builder_server"],
      "disabled": false
    }
  }
}
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-concepts
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-create
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-credentials
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-patterns
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-test
|
||||
@@ -0,0 +1,5 @@
|
||||
---
|
||||
description: hive-concepts
|
||||
---
|
||||
|
||||
use hive-concepts skill
|
||||
@@ -0,0 +1,5 @@
|
||||
---
|
||||
description: hive-create
|
||||
---
|
||||
|
||||
use hive-create skill
|
||||
@@ -0,0 +1,5 @@
|
||||
---
|
||||
description: hive-credentials
|
||||
---
|
||||
|
||||
use hive-credentials skill
|
||||
@@ -0,0 +1,5 @@
|
||||
---
|
||||
description: hive-patterns
|
||||
---
|
||||
|
||||
use hive-patterns skill
|
||||
@@ -0,0 +1,5 @@
|
||||
---
|
||||
description: hive-test
|
||||
---
|
||||
|
||||
use hive-test skill
|
||||
@@ -0,0 +1,5 @@
|
||||
---
|
||||
description: hive
|
||||
---
|
||||
|
||||
use hive skill
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-concepts
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-create
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-credentials
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-patterns
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/hive-test
|
||||
@@ -492,7 +492,7 @@ AskUserQuestion(questions=[{
- node_id (kebab-case)
- name
- description
- node_type: `"event_loop"` (recommended for all LLM work) or `"function"` (deterministic, no LLM)
- node_type: `"event_loop"` (the only valid type; use `client_facing: True` for HITL)
- input_keys (what data this node receives)
- output_keys (what data this node produces)
- tools (ONLY tools that exist from Step 1 — empty list if no tools needed)
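For reference, here is a minimal sketch of how these fields map onto a `NodeSpec` in Python. The field names mirror the deep_research_agent changes later in this comparison; the values are illustrative and not part of the diff.

```python
from framework.graph.node import NodeSpec

# Hedged sketch: placeholder values, field names follow the
# deep_research_agent nodes.py shown further down in this comparison.
summarize_node = NodeSpec(
    id="summarize-findings",            # node_id, kebab-case
    name="Summarize Findings",
    description="Condense research findings into a short summary",
    node_type="event_loop",             # the only valid node type
    client_facing=False,                # True would make this a HITL pause point
    input_keys=["findings"],            # data this node receives
    output_keys=["summary"],            # data this node produces
    tools=[],                           # only tools that already exist; empty if none
    system_prompt="Summarize the findings in 3-5 bullet points.",
)
```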
@@ -553,6 +553,26 @@ AskUserQuestion(questions=[{
- condition_expr (Python expression, only if conditional)
- priority (positive = forward, negative = feedback/loop-back)

**DETERMINE the graph lifecycle.** Not every agent needs a terminal node:

| Pattern | `terminal_nodes` | When to Use |
|---------|-------------------|-------------|
| **Linear (finish)** | `["last-node"]` | Agent completes a task and exits (batch processing, one-shot generation) |
| **Forever-alive (loop)** | `[]` (empty) | Agent stays alive for continuous interaction (research assistant, personal assistant, monitoring) |

**Forever-alive pattern:** The deep_research_agent example uses `terminal_nodes=[]`. Every leaf node has edges that loop back to earlier nodes, creating a perpetual session. The agent only stops when the user explicitly exits. This is the preferred pattern for interactive, multi-turn agents.

**Key design rules for forever-alive graphs:**
- Every node must have at least one outgoing edge (no dead ends)
- Client-facing nodes block for user input — these are the natural "pause points"
- The user controls when to stop, not the graph
- Sessions accumulate memory across loops — plan for conversation compaction
- Use `conversation_mode="continuous"` to preserve conversation history across node transitions
- `max_iterations` should be set high (e.g., 100) since the agent is designed to run indefinitely
- The agent will NOT enter a "completed" execution state — this is intentional, not a bug

**Ask the user** which lifecycle pattern fits their agent. Default to forever-alive for interactive agents, linear for batch/one-shot tasks.
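A minimal sketch of the two lifecycle configurations, using the module-level variables from the deep_research_agent `agent.py` changes in this comparison (node IDs are illustrative):

```python
# Hedged sketch: choose ONE of the two terminal_nodes assignments below.
entry_node = "intake"
entry_points = {"start": "intake"}
pause_nodes = []

# Linear (finish): execution ends when the "report" node completes.
terminal_nodes = ["report"]

# Forever-alive (loop): no terminal nodes; the report node loops back to
# research or intake, and the session ends only when the user exits.
terminal_nodes = []
```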

**RENDER the complete graph as ASCII art.** Make it large and clear — the user needs to see and understand the full workflow at a glance.

**IMPORTANT: Make the ASCII art BIG and READABLE.** Use a box-and-arrow style with generous spacing. Do NOT make it tiny or compressed. Example format:
@@ -832,8 +852,7 @@ cd /home/timothy/oss/hive && PYTHONPATH=exports uv run python -m AGENT_NAME vali

| Type | tools param | Use when |
| ------------ | ----------------------- | --------------------------------------- |
| `event_loop` | `'["tool1"]'` or `'[]'` | LLM-powered work with or without tools |
| `function` | N/A | Deterministic Python operations, no LLM |
| `event_loop` | `'["tool1"]'` or `'[]'` | All agent work (with or without tools, HITL via client_facing) |

---

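A hedged sketch of the corresponding `add_node` call through the agent-builder MCP tools, matching the examples later in this comparison. Note that `tools`, `input_keys`, and `output_keys` are JSON-encoded strings; the names and prompt are placeholders.

```python
# Illustrative only; mirrors the "searcher" example in the agent-builder docs below.
add_node(
    node_id="searcher",
    name="Search Sources",
    description="Find relevant sources",
    node_type="event_loop",
    input_keys='["queries"]',
    output_keys='["sources"]',
    system_prompt="Search for sources...",
    tools='["web_search"]',   # use '[]' for an event_loop node without tools
)
```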
@@ -912,6 +931,46 @@ result = await executor.execute(graph=graph, goal=goal, input_data=input_data)

---

## REFERENCE: Graph Lifecycle & Conversation Memory

### Terminal vs Forever-Alive Graphs

Agents have two lifecycle patterns:

**Linear (terminal) graphs** have `terminal_nodes=["last-node"]`. Execution ends when the terminal node completes. The session enters a "completed" state. Use for batch processing, one-shot generation, and fire-and-forget tasks.

**Forever-alive graphs** have `terminal_nodes=[]` (empty). Every node has at least one outgoing edge — the graph loops indefinitely. The session **never enters a "completed" state** — this is intentional. The agent stays alive until the user explicitly exits. Use for interactive assistants, research tools, and any agent where the user drives the conversation.

The deep_research_agent example demonstrates this: `report` loops back to either `research` (dig deeper) or `intake` (new topic). The agent is a persistent, interactive assistant.

### Continuous Conversation Mode

When `conversation_mode="continuous"` is set on the GraphSpec, the framework preserves a **single conversation thread** across all node transitions:

**What the framework does automatically:**
- **Inherits conversation**: Same message history carries forward to the next node
- **Composes layered system prompts**: Identity (agent-level) + Narrative (auto-generated state summary) + Focus (per-node instructions)
- **Inserts transition markers**: At each node boundary, a "State of the World" message showing completed phases, current memory, and available data files
- **Accumulates tools**: Once a tool becomes available, it stays available in subsequent nodes
- **Compacts opportunistically**: At phase transitions, old tool results are pruned to stay within token budget

**What this means for agent builders:**
- Nodes don't need to re-explain context — the conversation carries it forward
- Output keys from earlier nodes are available in memory for edge conditions and later nodes
- For forever-alive agents, conversation memory persists across the entire session lifetime
- Plan for compaction: very long sessions will have older tool results summarized automatically

**When to use continuous mode:**
- Interactive agents with client-facing nodes (always)
- Multi-phase workflows where context matters across phases
- Forever-alive agents that loop indefinitely

**When NOT to use continuous mode:**
- Embarrassingly parallel fan-out nodes (each branch should be independent)
- Stateless utility agents that process items independently
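As a concrete illustration, a hedged fragment of the graph-level settings involved, with values taken loosely from the deep_research_agent `_build_graph` changes further down in this comparison (exact `GraphSpec` fields may differ in your framework version):

```python
# Illustrative fragment, not a complete GraphSpec construction.
continuous_settings = dict(
    loop_config={
        "max_tool_calls_per_turn": 20,
        "max_history_tokens": 32000,
    },
    conversation_mode="continuous",   # one conversation thread across node transitions
    identity_prompt=(
        "You are a rigorous research agent. Every claim must trace back "
        "to a source you actually retrieved."
    ),
)
# graph = GraphSpec(..., **continuous_settings)  # passed alongside nodes and edges
```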

---

## REFERENCE: Framework Capabilities for Qualification

Use this reference during STEP 2 to give accurate, honest assessments.
@@ -944,11 +1003,11 @@ Use this reference during STEP 2 to give accurate, honest assessments.

| Use Case | Why It's Problematic | Alternative |
|----------|---------------------|-------------|
| Long-running daemons | Framework is request-response, not persistent | External scheduler + agent |
| Persistent background daemons (no user) | Forever-alive graphs need a user at client-facing nodes; no autonomous background polling without user | External scheduler triggering agent runs |
| Sub-second responses | LLM latency is inherent | Traditional code, no LLM |
| Processing millions of items | Context windows and rate limits | Batch processing + sampling |
| Real-time streaming data | No built-in pub/sub or streaming input | Custom MCP server + agent |
| Guaranteed determinism | LLM outputs vary | Function nodes for deterministic parts |
| Guaranteed determinism | LLM outputs vary | Traditional code for deterministic parts |
| Offline/air-gapped | Requires LLM API access | Local models (not currently supported) |
| Multi-user concurrency | Single-user session model | Separate agent instances per user |

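For the "external scheduler" alternatives in the table above, one hedged example is a cron entry that periodically triggers a one-shot (linear) agent run via the `hive run` command shown in the README. The path, agent name, and input are placeholders.

```bash
# Illustrative crontab entry: run a batch agent at the top of every hour.
0 * * * * cd /path/to/hive && hive run exports/your_agent_name --input '{"key": "value"}'
```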
@@ -979,3 +1038,6 @@ Use this reference during STEP 2 to give accurate, honest assessments.
11. **Adding framework gating for LLM behavior** - Fix prompts or use judges, not ad-hoc code
12. **Writing code before user approves the graph** - Always get approval on goal, nodes, and graph BEFORE writing any agent code
13. **Wrong mcp_servers.json format** - Use flat format (no `"mcpServers"` wrapper), `cwd` must be `"../../tools"`, and `command` must be `"uv"` with args `["run", "python", ...]`
14. **Assuming all agents need terminal nodes** - Interactive agents often work best with `terminal_nodes=[]` (forever-alive pattern). The agent never enters "completed" state — this is intentional. Only batch/one-shot agents need terminal nodes
15. **Creating dead-end nodes in forever-alive graphs** - Every node must have at least one outgoing edge. A node with no outgoing edges will cause execution to end unexpectedly, breaking the forever-alive loop
16. **Not using continuous conversation mode for interactive agents** - Multi-phase interactive agents should use `conversation_mode="continuous"` to preserve context across node transitions. Without it, each node starts with a blank conversation and loses all prior context
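To make item 13 concrete, a hedged sketch of the flat `mcp_servers.json` format it describes. The `"tools"` server name and the arguments after `"python"` are illustrative; only the flat layout, `command`, and `cwd` constraints come from the item above.

```json
{
  "tools": {
    "command": "uv",
    "args": ["run", "python", "mcp_server.py", "--stdio"],
    "cwd": "../../tools"
  }
}
```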
|
||||
@@ -1,12 +1,15 @@
|
||||
"""Agent graph construction for Deep Research Agent."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from framework.graph import EdgeSpec, EdgeCondition, Goal, SuccessCriterion, Constraint
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import ExecutionResult, GraphExecutor
|
||||
from framework.runtime.event_bus import EventBus
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.graph.executor import ExecutionResult
|
||||
from framework.graph.checkpoint_config import CheckpointConfig
|
||||
from framework.llm import LiteLLMProvider
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
|
||||
from .config import default_config, metadata
|
||||
from .nodes import (
|
||||
@@ -120,13 +123,31 @@ edges = [
|
||||
condition_expr="needs_more_research == False",
|
||||
priority=2,
|
||||
),
|
||||
# report -> research (user wants deeper research on current topic)
|
||||
EdgeSpec(
|
||||
id="report-to-research",
|
||||
source="report",
|
||||
target="research",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="str(next_action).lower() == 'more_research'",
|
||||
priority=2,
|
||||
),
|
||||
# report -> intake (user wants a new topic — default when not more_research)
|
||||
EdgeSpec(
|
||||
id="report-to-intake",
|
||||
source="report",
|
||||
target="intake",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="str(next_action).lower() != 'more_research'",
|
||||
priority=1,
|
||||
),
|
||||
]
|
||||
|
||||
# Graph configuration
|
||||
entry_node = "intake"
|
||||
entry_points = {"start": "intake"}
|
||||
pause_nodes = []
|
||||
terminal_nodes = ["report"]
|
||||
terminal_nodes = []
|
||||
|
||||
|
||||
class DeepResearchAgent:
|
||||
@@ -136,6 +157,12 @@ class DeepResearchAgent:
|
||||
Flow: intake -> research -> review -> report
|
||||
^ |
|
||||
+-- feedback loop (if user wants more)
|
||||
|
||||
Uses AgentRuntime for proper session management:
|
||||
- Session-scoped storage (sessions/{session_id}/)
|
||||
- Checkpointing for resume capability
|
||||
- Runtime logging
|
||||
- Data folder for save_data/load_data
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
@@ -147,10 +174,10 @@ class DeepResearchAgent:
|
||||
self.entry_points = entry_points
|
||||
self.pause_nodes = pause_nodes
|
||||
self.terminal_nodes = terminal_nodes
|
||||
self._executor: GraphExecutor | None = None
|
||||
self._graph: GraphSpec | None = None
|
||||
self._event_bus: EventBus | None = None
|
||||
self._agent_runtime: AgentRuntime | None = None
|
||||
self._tool_registry: ToolRegistry | None = None
|
||||
self._storage_path: Path | None = None
|
||||
|
||||
def _build_graph(self) -> GraphSpec:
|
||||
"""Build the GraphSpec."""
|
||||
@@ -171,16 +198,20 @@ class DeepResearchAgent:
|
||||
"max_tool_calls_per_turn": 20,
|
||||
"max_history_tokens": 32000,
|
||||
},
|
||||
conversation_mode="continuous",
|
||||
identity_prompt=(
|
||||
"You are a rigorous research agent. You search for information "
|
||||
"from diverse, authoritative sources, analyze findings critically, "
|
||||
"and produce well-cited reports. You never fabricate information — "
|
||||
"every claim must trace back to a source you actually retrieved."
|
||||
),
|
||||
)
|
||||
|
||||
def _setup(self, mock_mode=False) -> GraphExecutor:
|
||||
"""Set up the executor with all components."""
|
||||
from pathlib import Path
|
||||
def _setup(self, mock_mode=False) -> None:
|
||||
"""Set up the agent runtime with sessions, checkpoints, and logging."""
|
||||
self._storage_path = Path.home() / ".hive" / "agents" / "deep_research_agent"
|
||||
self._storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
storage_path = Path.home() / ".hive" / "agents" / "deep_research_agent"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._event_bus = EventBus()
|
||||
self._tool_registry = ToolRegistry()
|
||||
|
||||
mcp_config_path = Path(__file__).parent / "mcp_servers.json"
|
||||
@@ -199,47 +230,63 @@ class DeepResearchAgent:
|
||||
tools = list(self._tool_registry.get_tools().values())
|
||||
|
||||
self._graph = self._build_graph()
|
||||
runtime = Runtime(storage_path)
|
||||
|
||||
self._executor = GraphExecutor(
|
||||
runtime=runtime,
|
||||
checkpoint_config = CheckpointConfig(
|
||||
enabled=True,
|
||||
checkpoint_on_node_start=False,
|
||||
checkpoint_on_node_complete=True,
|
||||
checkpoint_max_age_days=7,
|
||||
async_checkpoint=True,
|
||||
)
|
||||
|
||||
entry_point_specs = [
|
||||
EntryPointSpec(
|
||||
id="default",
|
||||
name="Default",
|
||||
entry_node=self.entry_node,
|
||||
trigger_type="manual",
|
||||
isolation_level="shared",
|
||||
)
|
||||
]
|
||||
|
||||
self._agent_runtime = create_agent_runtime(
|
||||
graph=self._graph,
|
||||
goal=self.goal,
|
||||
storage_path=self._storage_path,
|
||||
entry_points=entry_point_specs,
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
event_bus=self._event_bus,
|
||||
storage_path=storage_path,
|
||||
loop_config=self._graph.loop_config,
|
||||
checkpoint_config=checkpoint_config,
|
||||
)
|
||||
|
||||
return self._executor
|
||||
|
||||
async def start(self, mock_mode=False) -> None:
|
||||
"""Set up the agent (initialize executor and tools)."""
|
||||
if self._executor is None:
|
||||
"""Set up and start the agent runtime."""
|
||||
if self._agent_runtime is None:
|
||||
self._setup(mock_mode=mock_mode)
|
||||
if not self._agent_runtime.is_running:
|
||||
await self._agent_runtime.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Clean up resources."""
|
||||
self._executor = None
|
||||
self._event_bus = None
|
||||
"""Stop the agent runtime and clean up."""
|
||||
if self._agent_runtime and self._agent_runtime.is_running:
|
||||
await self._agent_runtime.stop()
|
||||
self._agent_runtime = None
|
||||
|
||||
async def trigger_and_wait(
|
||||
self,
|
||||
entry_point: str,
|
||||
input_data: dict,
|
||||
entry_point: str = "default",
|
||||
input_data: dict | None = None,
|
||||
timeout: float | None = None,
|
||||
session_state: dict | None = None,
|
||||
) -> ExecutionResult | None:
|
||||
"""Execute the graph and wait for completion."""
|
||||
if self._executor is None:
|
||||
if self._agent_runtime is None:
|
||||
raise RuntimeError("Agent not started. Call start() first.")
|
||||
if self._graph is None:
|
||||
raise RuntimeError("Graph not built. Call start() first.")
|
||||
|
||||
return await self._executor.execute(
|
||||
graph=self._graph,
|
||||
goal=self.goal,
|
||||
input_data=input_data,
|
||||
return await self._agent_runtime.trigger_and_wait(
|
||||
entry_point_id=entry_point,
|
||||
input_data=input_data or {},
|
||||
session_state=session_state,
|
||||
)
|
||||
|
||||
@@ -250,7 +297,7 @@ class DeepResearchAgent:
|
||||
await self.start(mock_mode=mock_mode)
|
||||
try:
|
||||
result = await self.trigger_and_wait(
|
||||
"start", context, session_state=session_state
|
||||
"default", context, session_state=session_state
|
||||
)
|
||||
return result or ExecutionResult(success=False, error="Execution timeout")
|
||||
finally:
|
||||
|
||||
@@ -10,8 +10,13 @@ intake_node = NodeSpec(
|
||||
description="Discuss the research topic with the user, clarify scope, and confirm direction",
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
max_node_visits=0,
|
||||
input_keys=["topic"],
|
||||
output_keys=["research_brief"],
|
||||
success_criteria=(
|
||||
"The research brief is specific and actionable: it states the topic, "
|
||||
"the key questions to answer, the desired scope, and depth."
|
||||
),
|
||||
system_prompt="""\
|
||||
You are a research intake specialist. The user wants to research a topic.
|
||||
Have a brief conversation to clarify what they need.
|
||||
@@ -38,10 +43,14 @@ research_node = NodeSpec(
|
||||
name="Research",
|
||||
description="Search the web, fetch source content, and compile findings",
|
||||
node_type="event_loop",
|
||||
max_node_visits=3,
|
||||
max_node_visits=0,
|
||||
input_keys=["research_brief", "feedback"],
|
||||
output_keys=["findings", "sources", "gaps"],
|
||||
nullable_output_keys=["feedback"],
|
||||
success_criteria=(
|
||||
"Findings reference at least 3 distinct sources with URLs. "
|
||||
"Key claims are substantiated by fetched content, not generated."
|
||||
),
|
||||
system_prompt="""\
|
||||
You are a research agent. Given a research brief, find and analyze sources.
|
||||
|
||||
@@ -56,18 +65,26 @@ Work in phases:
|
||||
and any contradictions between sources.
|
||||
|
||||
Important:
|
||||
- Work in batches of 3-4 tool calls at a time to manage context
|
||||
- Work in batches of 3-4 tool calls at a time — never more than 10 per turn
|
||||
- After each batch, assess whether you have enough material
|
||||
- Prefer quality over quantity — 5 good sources beat 15 thin ones
|
||||
- Track which URL each finding comes from (you'll need citations later)
|
||||
- Call set_output for each key in a SEPARATE turn (not in the same turn as other tool calls)
|
||||
|
||||
When done, use set_output:
|
||||
When done, use set_output (one key at a time, separate turns):
|
||||
- set_output("findings", "Structured summary: key findings with source URLs for each claim. \
|
||||
Include themes, contradictions, and confidence levels.")
|
||||
- set_output("sources", [{"url": "...", "title": "...", "summary": "..."}])
|
||||
- set_output("gaps", "What aspects of the research brief are NOT well-covered yet, if any.")
|
||||
""",
|
||||
tools=["web_search", "web_scrape", "load_data", "save_data", "list_data_files"],
|
||||
tools=[
|
||||
"web_search",
|
||||
"web_scrape",
|
||||
"load_data",
|
||||
"save_data",
|
||||
"append_data",
|
||||
"list_data_files",
|
||||
],
|
||||
)
|
||||
|
||||
# Node 3: Review (client-facing)
|
||||
@@ -78,9 +95,13 @@ review_node = NodeSpec(
|
||||
description="Present findings to user and decide whether to research more or write the report",
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
max_node_visits=3,
|
||||
max_node_visits=0,
|
||||
input_keys=["findings", "sources", "gaps", "research_brief"],
|
||||
output_keys=["needs_more_research", "feedback"],
|
||||
success_criteria=(
|
||||
"The user has been presented with findings and has explicitly indicated "
|
||||
"whether they want more research or are ready for the report."
|
||||
),
|
||||
system_prompt="""\
|
||||
Present the research findings to the user clearly and concisely.
|
||||
|
||||
@@ -109,49 +130,70 @@ report_node = NodeSpec(
|
||||
description="Write a cited HTML report from the findings and present it to the user",
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
max_node_visits=0,
|
||||
input_keys=["findings", "sources", "research_brief"],
|
||||
output_keys=["delivery_status"],
|
||||
output_keys=["delivery_status", "next_action"],
|
||||
success_criteria=(
|
||||
"An HTML report has been saved, the file link has been presented to the user, "
|
||||
"and the user has indicated what they want to do next."
|
||||
),
|
||||
system_prompt="""\
|
||||
Write a comprehensive research report as an HTML file and present it to the user.
|
||||
Write a research report as an HTML file and present it to the user.
|
||||
|
||||
**STEP 1 — Write the HTML report (tool calls, NO text to user yet):**
|
||||
IMPORTANT: save_data requires TWO separate arguments: filename and data.
|
||||
Call it like: save_data(filename="report.html", data="<html>...</html>")
|
||||
Do NOT use _raw, do NOT nest arguments inside a JSON string.
|
||||
|
||||
1. Compose a complete, self-contained HTML document with embedded CSS styling.
|
||||
Use a clean, readable design: max-width container, pleasant typography,
|
||||
numbered citation links, a table of contents, and a references section.
|
||||
**STEP 1 — Write and save the HTML report (tool calls, NO text to user yet):**
|
||||
|
||||
Report structure inside the HTML:
|
||||
- Title & date
|
||||
- Executive Summary (2-3 paragraphs)
|
||||
- Table of Contents
|
||||
- Findings (organized by theme, with [n] citation links)
|
||||
- Analysis (synthesis, implications, areas of debate)
|
||||
- Conclusion (key takeaways, confidence assessment)
|
||||
- References (numbered list with clickable URLs)
|
||||
Build a clean HTML document. Keep the HTML concise — aim for clarity over length.
|
||||
Use minimal embedded CSS (a few lines of style, not a full framework).
|
||||
|
||||
Requirements:
|
||||
- Every factual claim must cite its source with [n] notation
|
||||
- Be objective — present multiple viewpoints where sources disagree
|
||||
- Distinguish well-supported conclusions from speculation
|
||||
- Answer the original research questions from the brief
|
||||
Report structure:
|
||||
- Title & date
|
||||
- Executive Summary (2-3 paragraphs)
|
||||
- Key Findings (organized by theme, with [n] citation links)
|
||||
- Analysis (synthesis, implications)
|
||||
- Conclusion (key takeaways)
|
||||
- References (numbered list with clickable URLs)
|
||||
|
||||
2. Save the HTML file:
|
||||
save_data(filename="report.html", data=<your_html>)
|
||||
Requirements:
|
||||
- Every factual claim must cite its source with [n] notation
|
||||
- Be objective — present multiple viewpoints where sources disagree
|
||||
- Answer the original research questions from the brief
|
||||
|
||||
3. Get the clickable link:
|
||||
serve_file_to_user(filename="report.html", label="Research Report")
|
||||
Save the HTML:
|
||||
save_data(filename="report.html", data="<html>...</html>")
|
||||
|
||||
Then get the clickable link:
|
||||
serve_file_to_user(filename="report.html", label="Research Report")
|
||||
|
||||
If save_data fails, simplify and shorten the HTML, then retry.
|
||||
|
||||
**STEP 2 — Present the link to the user (text only, NO tool calls):**
|
||||
|
||||
Tell the user the report is ready and include the file:// URI from
|
||||
serve_file_to_user so they can click it to open. Give a brief summary
|
||||
of what the report covers. Ask if they have questions.
|
||||
of what the report covers. Ask if they have questions or want to continue.
|
||||
|
||||
**STEP 3 — After the user responds:**
|
||||
- Answer follow-up questions from the research material
|
||||
- When the user is satisfied: set_output("delivery_status", "completed")
|
||||
- Answer any follow-up questions from the research material
|
||||
- When the user is ready to move on, ask what they'd like to do next:
|
||||
- Research a new topic?
|
||||
- Dig deeper into the current topic?
|
||||
- Then call set_output:
|
||||
- set_output("delivery_status", "completed")
|
||||
- set_output("next_action", "new_topic") — if they want a new topic
|
||||
- set_output("next_action", "more_research") — if they want deeper research
|
||||
""",
|
||||
tools=["save_data", "serve_file_to_user", "load_data", "list_data_files"],
|
||||
tools=[
|
||||
"save_data",
|
||||
"append_data",
|
||||
"edit_data",
|
||||
"serve_file_to_user",
|
||||
"load_data",
|
||||
"list_data_files",
|
||||
],
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -508,7 +508,7 @@ All credential specs are defined in `tools/src/aden_tools/credentials/`:
| `llm.py` | LLM Providers | `anthropic` | No |
| `search.py` | Search Tools | `brave_search`, `google_search`, `google_cse` | No |
| `email.py` | Email | `resend` | No |
| `integrations.py` | Integrations | `github`, `hubspot` | No / Yes |
| `integrations.py` | Integrations | `github`, `hubspot`, `google_calendar_oauth` | No / Yes |

**Note:** Additional LLM providers (Cerebras, Groq, OpenAI) are handled by LiteLLM via environment
variables (`CEREBRAS_API_KEY`, `GROQ_API_KEY`, `OPENAI_API_KEY`) but are not yet in CREDENTIAL_SPECS.
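A hedged example of supplying one of these provider keys through the environment (the key values are placeholders; the variable names are the ones listed in the note above):

```bash
# Placeholders only; set before running the agent so LiteLLM can pick them up.
export OPENAI_API_KEY="sk-placeholder"
export GROQ_API_KEY="gsk-placeholder"
```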

@@ -26,6 +26,17 @@ Use `/hive-debugger` when:

This skill works alongside agents running in TUI mode and provides supervisor-level insights into execution behavior.

### Forever-Alive Agent Awareness

Some agents use `terminal_nodes=[]` (the "forever-alive" pattern), meaning they loop indefinitely and never enter a "completed" execution state. For these agents:
- Sessions with status "in_progress" or "paused" are **normal**, not failures
- High step counts, long durations, and many node visits are expected behavior
- The agent stops only when the user explicitly exits — there is no graph-driven completion
- Debug focus should be on **quality of individual node visits and iterations**, not whether the session reached a terminal state
- Conversation memory accumulates across loops — watch for context overflow and stale data issues

**How to identify forever-alive agents:** Check `agent.py` or `agent.json` for `terminal_nodes=[]` (empty list). If empty, the agent is forever-alive.
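A hedged one-liner for that check (the path assumes the `exports/` layout used elsewhere in this repository; substitute your agent's directory):

```bash
# Prints the terminal_nodes assignment; an empty list means forever-alive.
grep -rn "terminal_nodes" exports/your_agent_name/agent.py exports/your_agent_name/agent.json 2>/dev/null
```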

---

## Prerequisites
@@ -142,6 +153,7 @@ Store the selected mode for the session.
- Check `attention_summary.categories` for issue types
- Note the `run_id` of problematic sessions
- Check `status` field: "degraded", "failure", "in_progress"
- **For forever-alive agents:** Sessions with status "in_progress" or "paused" are normal — these agents never reach "completed". Only flag sessions with `needs_attention: true` or actual error indicators (tool failures, retry loops, missing outputs). High step counts alone do not indicate a problem.

3. **Attention flag triggers to understand:**
From runtime_logger.py, runs are flagged when:
@@ -199,13 +211,20 @@ Which run would you like to investigate?
| **Tool Errors** | `tool_error_count > 0`, `attention_reasons` contains "tool_failures" | Tool calls failed (API errors, timeouts, auth issues) |
| **Retry Loops** | `retry_count > 3`, `verdict_counts.RETRY > 5` | Judge repeatedly rejecting outputs |
| **Guard Failures** | `guard_reject_count > 0` | Output validation failed (wrong types, missing keys) |
| **Stalled Execution** | `total_steps > 20`, `verdict_counts.CONTINUE > 10` | EventLoopNode not making progress |
| **Stalled Execution** | `total_steps > 20`, `verdict_counts.CONTINUE > 10` | EventLoopNode not making progress. **Caveat:** Forever-alive agents may legitimately have high step counts — check if agent is blocked at a client-facing node (normal) vs genuinely stuck in a loop |
| **High Latency** | `latency_ms > 60000`, `avg_step_latency > 5000` | Slow tool calls or LLM responses |
| **Client-Facing Issues** | `client_input_requested` but no `user_input_received` | Premature set_output before user input |
| **Edge Routing Errors** | `exit_status == "no_valid_edge"`, `attention_reasons` contains "routing_issue" | No edges match current state |
| **Memory/Context Issues** | `tokens_used > 100000`, `context_overflow_count > 0` | Conversation history too long |
| **Constraint Violations** | Compare output against goal constraints | Agent violated goal-level rules |

**Forever-Alive Agent Caveat:** If the agent uses `terminal_nodes=[]`, sessions will never reach "completed" status. This is by design. When debugging these agents, focus on:
- Whether individual node visits succeed (not whether the graph "finishes")
- Quality of each loop iteration — are outputs improving or degrading across loops?
- Whether client-facing nodes are correctly blocking for user input
- Memory accumulation issues: stale data from previous loops, context overflow across many iterations
- Conversation compaction behavior: is the conversation growing unbounded?

3. **Analyze each flagged node:**
- Node ID and name
- Exit status
@@ -1015,6 +1034,9 @@ Your agent should now work correctly!"
3. **Don't ignore edge conditions** - Missing edges cause routing failures
4. **Don't overlook judge configuration** - Mismatched expectations cause retry loops
5. **Don't forget nullable_output_keys** - Optional inputs need explicit marking
6. **Don't diagnose "in_progress" as a failure for forever-alive agents** - Agents with `terminal_nodes=[]` are designed to never enter "completed" state. This is intentional. Focus on quality of individual node visits, not session completion status
7. **Don't ignore conversation memory issues in long-running sessions** - In continuous conversation mode, history grows across node transitions and loop iterations. Watch for context overflow (tokens_used > 100K), stale data from previous loops affecting edge conditions, and compaction failures that cause the LLM to lose important context
8. **Don't confuse "waiting for user" with "stalled"** - Client-facing nodes in forever-alive agents block for user input by design. A session paused at a client-facing node is working correctly, not stalled

---

@@ -0,0 +1,7 @@
# Project-level Codex config for Hive.
# Keep this file minimal: MCP connectivity + skill discovery.

[mcp_servers.agent-builder]
command = "uv"
args = ["run", "--directory", "core", "-m", "framework.mcp.agent_builder_server"]
cwd = "."
+3
-1
@@ -74,4 +74,6 @@ exports/*

docs/github-issues/*
core/tests/*dumps/*
screenshots/*

screenshots/*

@@ -4,11 +4,6 @@
"command": "uv",
"args": ["run", "-m", "framework.mcp.agent_builder_server"],
"cwd": "core"
},
"tools": {
"command": "uv",
"args": ["run", "mcp_server.py", "--stdio"],
"cwd": "tools"
}
}
}

@@ -1,20 +0,0 @@
---
name: hive
description: Hive Agent Builder & Manager
mode: primary
tools:
  agent-builder: true
  tools: true
---

# Hive Agent
You are the Hive Agent Builder. Your goal is to help the user construct, configure, and deploy AI agents using the Hive framework.

## Capabilities
1. **Scaffold Agents:** Create new agent directories/configs.
2. **Manage Tools:** Add/remove tools via MCP.
3. **Debug:** Analyze agent workflows.

## Context
- You are an expert in the Hive framework architecture.
- Always use the `agent-builder` MCP server for filesystem operations.
+1
-2
@@ -99,8 +99,7 @@ docs(readme): update installation instructions
2. Update documentation if needed
3. Add tests for new functionality
4. Ensure `make check` and `make test` pass
5. Update the CHANGELOG.md if applicable
6. Request review from maintainers
5. Request review from maintainers

### PR Title Format

@@ -22,7 +22,6 @@
<img src="https://img.shields.io/badge/MCP-102_Tools-00ADD8?style=flat-square" alt="MCP" />
</p>


<p align="center">
<img src="https://img.shields.io/badge/AI_Agents-Self--Improving-brightgreen?style=flat-square" alt="AI Agents" />
<img src="https://img.shields.io/badge/Multi--Agent-Systems-blue?style=flat-square" alt="Multi-Agent" />
@@ -82,12 +81,17 @@ Use Hive when you need:
### Prerequisites

- Python 3.11+ for agent development
- Claude Code or Cursor for utilizing agent skills
- Claude Code, Codex CLI, or Cursor for utilizing agent skills

> **Note for Windows Users:** It is strongly recommended to use **WSL (Windows Subsystem for Linux)** or **Git Bash** to run this framework. Some core automation scripts may not execute correctly in standard Command Prompt or PowerShell.

### Installation

> **Note**
> Hive uses a `uv` workspace layout and is not installed with `pip install`.
> Running `pip install -e .` from the repository root will create a placeholder package and Hive will not function correctly.
> Please use the quickstart script below to set up the environment.

```bash
# Clone the repository
git clone https://github.com/adenhq/hive.git
@@ -121,6 +125,18 @@ hive tui
hive run exports/your_agent_name --input '{"key": "value"}'
```
## Coding Agent Support
### Codex CLI
Hive includes native support for [OpenAI Codex CLI](https://github.com/openai/codex) (v0.101.0+).

1. **Config:** `.codex/config.toml` with `agent-builder` MCP server (tracked in git)
2. **Skills:** `.agents/skills/` symlinks to Hive skills (tracked in git)
3. **Launch:** Run `codex` in the repo root, then type `use hive`

Example:
```
codex> use hive
```
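For reference, the tracked `.codex/config.toml` added in this comparison looks like this:

```toml
# Project-level Codex config for Hive.
# Keep this file minimal: MCP connectivity + skill discovery.

[mcp_servers.agent-builder]
command = "uv"
args = ["run", "--directory", "core", "-m", "framework.mcp.agent_builder_server"]
cwd = "."
```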

### Opencode
Hive includes native support for [Opencode](https://github.com/opencode-ai/opencode).

@@ -133,6 +149,16 @@ The agent has access to all Hive skills and can scaffold agents, add tools, and

**[📖 Complete Setup Guide](docs/environment-setup.md)** - Detailed instructions for agent development

### Antigravity IDE Support

Skills and MCP servers are also available in [Antigravity IDE](https://antigravity.google/) (Google's AI-powered IDE). **Easiest:** open a terminal in the hive repo folder and run (use `./` — the script is inside the repo):

```bash
./scripts/setup-antigravity-mcp.sh
```

**Important:** Always restart/refresh Antigravity IDE after running the setup script—MCP servers only load on startup. After restart, **agent-builder** and **tools** MCP servers should connect. Skills are under `.agent/skills/` (symlinks to `.claude/skills/`). See [docs/antigravity-setup.md](docs/antigravity-setup.md) for manual setup and troubleshooting.

## Features

- **[Goal-Driven Development](docs/key_concepts/goals_outcome.md)** - Define objectives in natural language; the coding agent generates the agent graph and connection code to achieve them
@@ -312,6 +338,7 @@ subgraph Expansion
j2["Cursor"]
j3["Opencode"]
j4["Antigravity"]
j5["Codex CLI"]
end
subgraph plat["Platform"]
k1["JavaScript/TypeScript SDK"]

@@ -82,7 +82,7 @@ Register an MCP server as a tool source for your agent.
"example_tool"
],
"total_mcp_servers": 1,
"note": "MCP server 'tools' registered with 6 tools. These tools can now be used in llm_tool_use nodes."
"note": "MCP server 'tools' registered with 6 tools. These tools can now be used in event_loop nodes."
}
```

@@ -149,7 +149,7 @@ List tools available from registered MCP servers.
]
},
"total_tools": 6,
"note": "Use these tool names in the 'tools' parameter when adding llm_tool_use nodes"
"note": "Use these tool names in the 'tools' parameter when adding event_loop nodes"
}
```

@@ -246,7 +246,7 @@ Here's a complete workflow for building an agent with MCP tools:
"node_id": "web-searcher",
"name": "Web Search",
"description": "Search the web for information",
"node_type": "llm_tool_use",
"node_type": "event_loop",
"input_keys": "[\"query\"]",
"output_keys": "[\"search_results\"]",
"system_prompt": "Search for {query} using the web_search tool",

@@ -119,7 +119,7 @@ builder = WorkflowBuilder()
builder.add_node(
    node_id="researcher",
    name="Web Researcher",
    node_type="llm_tool_use",
    node_type="event_loop",
    system_prompt="Research the topic using web_search",
    tools=["web_search"],  # Tool from tools MCP server
    input_keys=["topic"],
@@ -137,7 +137,7 @@ Tools from MCP servers can be referenced in your agent.json just like built-in t
{
  "id": "searcher",
  "name": "Web Searcher",
  "node_type": "llm_tool_use",
  "node_type": "event_loop",
  "system_prompt": "Search for information about {topic}",
  "tools": ["web_search", "web_scrape"],
  "input_keys": ["topic"],

+17
-70
@@ -103,31 +103,20 @@ Add a processing node to the agent graph.
- `node_id` (string, required): Unique node identifier
- `name` (string, required): Human-readable name
- `description` (string, required): What this node does
- `node_type` (string, required): One of: `llm_generate`, `llm_tool_use`, `router`, `function`
- `node_type` (string, required): Must be `event_loop` (the only valid type)
- `input_keys` (string, required): JSON array of input variable names
- `output_keys` (string, required): JSON array of output variable names
- `system_prompt` (string, optional): System prompt for LLM nodes
- `tools` (string, optional): JSON array of tool names for tool_use nodes
- `routes` (string, optional): JSON object of route mappings for router nodes
- `system_prompt` (string, optional): System prompt for the LLM
- `tools` (string, optional): JSON array of tool names
- `client_facing` (boolean, optional): Set to true for human-in-the-loop interaction

**Node Types:**
**Node Type:**

1. **llm_generate**: Uses LLM to generate output from inputs
   - Requires: `system_prompt`
   - Tools: Not used

2. **llm_tool_use**: Uses LLM with tools to accomplish tasks
   - Requires: `system_prompt`, `tools`
   - Tools: Array of tool names (e.g., `["web_search", "web_fetch"]`)

3. **router**: LLM-powered routing to different paths
   - Requires: `system_prompt`, `routes`
   - Routes: Object mapping route names to target node IDs
   - Example: `{"pass": "success_node", "fail": "retry_node"}`

4. **function**: Executes a pre-defined function
   - System prompt describes the function behavior
   - No LLM calls, pure computation
**event_loop**: LLM-powered node with self-correction loop
- Requires: `system_prompt`
- Optional: `tools` (array of tool names, e.g., `["web_search", "web_fetch"]`)
- Optional: `client_facing` (set to true for HITL / user interaction)
- Supports: iterative refinement, judge-based evaluation, tool use, streaming

**Example:**
```json
@@ -135,7 +124,7 @@ Add a processing node to the agent graph.
"node_id": "search_sources",
"name": "Search Sources",
"description": "Searches for relevant sources on the topic",
"node_type": "llm_tool_use",
"node_type": "event_loop",
"input_keys": "[\"topic\", \"search_queries\"]",
"output_keys": "[\"sources\", \"source_count\"]",
"system_prompt": "Search for sources using the provided queries...",
@@ -198,7 +187,7 @@ Export the validated graph as an agent specification.

**What it does:**
1. Validates the graph
2. Auto-generates missing edges from router routes
2. Validates edge connectivity
3. Writes files to disk:
   - `exports/{agent-name}/agent.json` - Full agent specification
   - `exports/{agent-name}/README.md` - Auto-generated documentation
@@ -252,47 +241,6 @@ Test the complete agent graph with sample inputs.

---

### Evaluation Rules

#### `add_evaluation_rule`
Add a rule for the HybridJudge to evaluate node outputs.

**Parameters:**
- `rule_id` (string, required): Unique rule identifier
- `description` (string, required): What this rule checks
- `condition` (string, required): Python expression to evaluate
- `action` (string, required): Action to take: `accept`, `retry`, `escalate`
- `priority` (integer, optional): Rule priority (default: 0)
- `feedback_template` (string, optional): Feedback message template

**Condition Examples:**
- `'result.get("success") == True'` - Check for success flag
- `'result.get("error_type") == "timeout"'` - Check error type
- `'len(result.get("data", [])) > 0'` - Check for non-empty data

**Example:**
```json
{
  "rule_id": "timeout_retry",
  "description": "Retry on timeout errors",
  "condition": "result.get('error_type') == 'timeout'",
  "action": "retry",
  "priority": 10,
  "feedback_template": "Timeout occurred, retrying..."
}
```

#### `list_evaluation_rules`
List all configured evaluation rules.

#### `remove_evaluation_rule`
Remove an evaluation rule.

**Parameters:**
- `rule_id` (string, required): Rule to remove

---

## Example Workflow

Here's a complete workflow for building a research agent:
@@ -320,7 +268,7 @@ add_node(
    node_id="planner",
    name="Research Planner",
    description="Creates research strategy",
    node_type="llm_generate",
    node_type="event_loop",
    input_keys='["topic"]',
    output_keys='["strategy", "queries"]',
    system_prompt="Analyze topic and create research plan..."
@@ -330,7 +278,7 @@ add_node(
    node_id="searcher",
    name="Search Sources",
    description="Find relevant sources",
    node_type="llm_tool_use",
    node_type="event_loop",
    input_keys='["queries"]',
    output_keys='["sources"]',
    system_prompt="Search for sources...",
@@ -359,10 +307,9 @@ The exported agent will be saved to `exports/research-agent/`.

1. **Start with the goal**: Define clear success criteria before building nodes
2. **Test nodes individually**: Use `test_node` to verify each node works
3. **Use router nodes for branching**: Don't create edges manually for routers - define routes and they'll be auto-generated
4. **Add evaluation rules**: Help the judge evaluate outputs deterministically
5. **Validate early, validate often**: Run `validate_graph` after adding nodes/edges
6. **Check exports**: Review the generated README.md to verify your agent structure
3. **Use conditional edges for branching**: Define condition_expr on edges for decision points
4. **Validate early, validate often**: Run `validate_graph` after adding nodes/edges
5. **Check exports**: Review the generated README.md to verify your agent structure
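To illustrate best practice 3, a hedged sketch of a conditional edge at the framework level, using the `EdgeSpec` fields that appear in the deep_research_agent changes in this comparison (the node IDs and expression are placeholders):

```python
from framework.graph import EdgeSpec, EdgeCondition

# Illustrative only: route from "review" back to "research" when more work is needed.
more_research_edge = EdgeSpec(
    id="review-to-research",
    source="review",
    target="research",
    condition=EdgeCondition.CONDITIONAL,
    condition_expr="needs_more_research == True",   # evaluated against node outputs
    priority=2,
)
```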

---

+1
-1
@@ -73,7 +73,7 @@ To use the agent builder with Claude Desktop or other MCP clients, add this to y
The MCP server provides tools for:
- Creating agent building sessions
- Defining goals with success criteria
- Adding nodes (llm_generate, llm_tool_use, router, function)
- Adding nodes (event_loop only)
- Connecting nodes with edges
- Validating and exporting agent graphs
- Testing nodes and full agent graphs

@@ -68,7 +68,7 @@ from framework.graph.event_loop_node import ( # noqa: E402
|
||||
)
|
||||
from framework.graph.executor import GraphExecutor # noqa: E402
|
||||
from framework.graph.goal import Goal # noqa: E402
|
||||
from framework.graph.node import NodeSpec # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.runner.tool_registry import ToolRegistry # noqa: E402
|
||||
from framework.runtime.core import Runtime # noqa: E402
|
||||
@@ -654,7 +654,7 @@ NODE_SPECS = {
|
||||
id="sender",
|
||||
name="Sender",
|
||||
description="Send approved campaign emails",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
input_keys=["approved_emails"],
|
||||
output_keys=["send_results"],
|
||||
),
|
||||
@@ -823,11 +823,20 @@ def _send_email_via_resend(
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
|
||||
class SenderNode(NodeProtocol):
|
||||
"""Node wrapper for send_emails function."""
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
approved = ctx.input_data.get("approved_emails", "")
|
||||
result_str = send_emails(approved_emails=approved)
|
||||
ctx.memory.write("send_results", result_str)
|
||||
return NodeResult(success=True, output={"send_results": result_str})
|
||||
|
||||
|
||||
def send_emails(approved_emails: str = "") -> str:
|
||||
"""Send approved campaign emails via Resend, or log if unconfigured.
|
||||
|
||||
Called by FunctionNode which unpacks input_keys as kwargs.
|
||||
Returns a JSON string (FunctionNode wraps it in NodeResult).
|
||||
Returns a JSON string.
|
||||
"""
|
||||
approved = approved_emails
|
||||
if not approved:
|
||||
@@ -1780,7 +1789,7 @@ async def _run_pipeline(websocket, initial_message: str):
|
||||
)
|
||||
for nid, impl in nodes.items():
|
||||
executor.register_node(nid, impl)
|
||||
executor.register_function("sender", send_emails)
|
||||
executor.register_node("sender", SenderNode())
|
||||
|
||||
# --- Event forwarding: bus → WebSocket ---
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ Minimal Manual Agent Example
|
||||
This example demonstrates how to build and run an agent programmatically
|
||||
without using the Claude Code CLI or external LLM APIs.
|
||||
|
||||
It uses 'function' nodes to define logic in pure Python, making it perfect
|
||||
for understanding the core runtime loop:
|
||||
It uses custom NodeProtocol implementations to define logic in pure Python,
|
||||
making it perfect for understanding the core runtime loop:
|
||||
Setup -> Graph definition -> Execution -> Result
|
||||
|
||||
Run with:
|
||||
@@ -16,22 +16,33 @@ import asyncio
|
||||
|
||||
from framework.graph import EdgeCondition, EdgeSpec, Goal, GraphSpec, NodeSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
|
||||
# 1. Define Node Logic (Pure Python Functions)
|
||||
def greet(name: str) -> str:
|
||||
# 1. Define Node Logic (Custom NodeProtocol implementations)
|
||||
class GreeterNode(NodeProtocol):
|
||||
"""Generate a simple greeting."""
|
||||
return f"Hello, {name}!"
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
name = ctx.input_data.get("name", "World")
|
||||
greeting = f"Hello, {name}!"
|
||||
ctx.memory.write("greeting", greeting)
|
||||
return NodeResult(success=True, output={"greeting": greeting})
|
||||
|
||||
|
||||
def uppercase(greeting: str) -> str:
|
||||
class UppercaserNode(NodeProtocol):
|
||||
"""Convert text to uppercase."""
|
||||
return greeting.upper()
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
greeting = ctx.input_data.get("greeting") or ctx.memory.read("greeting") or ""
|
||||
result = greeting.upper()
|
||||
ctx.memory.write("final_greeting", result)
|
||||
return NodeResult(success=True, output={"final_greeting": result})
|
||||
|
||||
|
||||
async def main():
|
||||
print("🚀 Setting up Manual Agent...")
|
||||
print("Setting up Manual Agent...")
|
||||
|
||||
# 2. Define the Goal
|
||||
# Every agent needs a goal with success criteria
|
||||
@@ -55,8 +66,7 @@ async def main():
|
||||
id="greeter",
|
||||
name="Greeter",
|
||||
description="Generates a simple greeting",
|
||||
node_type="function",
|
||||
function="greet", # Matches the registered function name
|
||||
node_type="event_loop",
|
||||
input_keys=["name"],
|
||||
output_keys=["greeting"],
|
||||
)
|
||||
@@ -65,8 +75,7 @@ async def main():
|
||||
id="uppercaser",
|
||||
name="Uppercaser",
|
||||
description="Converts greeting to uppercase",
|
||||
node_type="function",
|
||||
function="uppercase",
|
||||
node_type="event_loop",
|
||||
input_keys=["greeting"],
|
||||
output_keys=["final_greeting"],
|
||||
)
|
||||
@@ -98,23 +107,23 @@ async def main():
|
||||
runtime = Runtime(storage_path=Path("./agent_logs"))
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
|
||||
# 7. Register Function Implementations
|
||||
# Connect string names in NodeSpecs to actual Python functions
|
||||
executor.register_function("greeter", greet)
|
||||
executor.register_function("uppercaser", uppercase)
|
||||
# 7. Register Node Implementations
|
||||
# Connect node IDs in the graph to actual Python implementations
|
||||
executor.register_node("greeter", GreeterNode())
|
||||
executor.register_node("uppercaser", UppercaserNode())
|
||||
|
||||
# 8. Execute Agent
|
||||
print("▶ Executing agent with input: name='Alice'...")
|
||||
print("Executing agent with input: name='Alice'...")
|
||||
|
||||
result = await executor.execute(graph=graph, goal=goal, input_data={"name": "Alice"})
|
||||
|
||||
# 9. Verify Results
|
||||
if result.success:
|
||||
print("\n✅ Success!")
|
||||
print("\nSuccess!")
|
||||
print(f"Path taken: {' -> '.join(result.path)}")
|
||||
print(f"Final output: {result.output.get('final_greeting')}")
|
||||
else:
|
||||
print(f"\n❌ Failed: {result.error}")
|
||||
print(f"\nFailed: {result.error}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -122,7 +122,7 @@ async def example_4_custom_agent_with_mcp_tools():
|
||||
node_id="web-searcher",
|
||||
name="Web Search",
|
||||
description="Search the web for information",
|
||||
node_type="llm_tool_use",
|
||||
node_type="event_loop",
|
||||
system_prompt="Search for {query} and return the top results. Use the web_search tool.",
|
||||
tools=["web_search"], # This tool comes from tools MCP server
|
||||
input_keys=["query"],
|
||||
@@ -133,7 +133,7 @@ async def example_4_custom_agent_with_mcp_tools():
|
||||
node_id="summarizer",
|
||||
name="Summarize Results",
|
||||
description="Summarize the search results",
|
||||
node_type="llm_generate",
|
||||
node_type="event_loop",
|
||||
system_prompt="Summarize the following search results in 2-3 sentences: {search_results}",
|
||||
input_keys=["search_results"],
|
||||
output_keys=["summary"],
|
||||
|
||||
@@ -245,20 +245,14 @@ class GraphBuilder:
|
||||
warnings.append(f"Node '{node.id}' should have a description")
|
||||
|
||||
# Type-specific validation
|
||||
if node.node_type == "llm_tool_use":
|
||||
if not node.tools:
|
||||
errors.append(f"LLM tool node '{node.id}' must specify tools")
|
||||
if not node.system_prompt:
|
||||
warnings.append(f"LLM node '{node.id}' should have a system_prompt")
|
||||
if node.node_type == "event_loop":
|
||||
if node.tools and not node.system_prompt:
|
||||
warnings.append(f"Event loop node '{node.id}' should have a system_prompt")
|
||||
|
||||
if node.node_type == "router":
|
||||
if not node.routes:
|
||||
errors.append(f"Router node '{node.id}' must specify routes")
|
||||
|
||||
if node.node_type == "function":
|
||||
if not node.function:
|
||||
errors.append(f"Function node '{node.id}' must specify function name")
|
||||
|
||||
# Check input/output keys
|
||||
if not node.input_keys:
|
||||
suggestions.append(f"Consider specifying input_keys for '{node.id}'")
|
||||
@@ -400,9 +394,13 @@ class GraphBuilder:
|
||||
if not terminal_candidates and self.session.nodes:
|
||||
warnings.append("No terminal nodes found (all nodes have outgoing edges)")
|
||||
|
||||
# Check reachability
|
||||
# Check reachability from ALL entry candidates (not just the first one).
|
||||
# Agents with async entry points have multiple nodes with no incoming
|
||||
# edges (e.g., a primary entry node and an event-driven entry node).
|
||||
if entry_candidates and self.session.nodes:
|
||||
reachable = self._compute_reachable(entry_candidates[0])
|
||||
reachable = set()
|
||||
for candidate in entry_candidates:
|
||||
reachable |= self._compute_reachable(candidate)
|
||||
unreachable = [n.id for n in self.session.nodes if n.id not in reachable]
|
||||
if unreachable:
|
||||
errors.append(f"Unreachable nodes: {unreachable}")
|
||||
|
||||
@@ -59,6 +59,13 @@ from .provider import (
|
||||
CredentialProvider,
|
||||
StaticProvider,
|
||||
)
|
||||
from .setup import (
|
||||
CredentialSetupSession,
|
||||
MissingCredential,
|
||||
SetupResult,
|
||||
detect_missing_credentials_from_nodes,
|
||||
run_credential_setup_cli,
|
||||
)
|
||||
from .storage import (
|
||||
CompositeStorage,
|
||||
CredentialStorage,
|
||||
@@ -68,6 +75,7 @@ from .storage import (
|
||||
)
|
||||
from .store import CredentialStore
|
||||
from .template import TemplateResolver
|
||||
from .validation import ensure_credential_key_env, validate_agent_credentials
|
||||
|
||||
# Aden sync components (lazy import to avoid httpx dependency when not needed)
|
||||
# Usage: from core.framework.credentials.aden import AdenSyncProvider
|
||||
@@ -111,6 +119,15 @@ __all__ = [
|
||||
"CredentialRefreshError",
|
||||
"CredentialValidationError",
|
||||
"CredentialDecryptionError",
|
||||
# Validation
|
||||
"ensure_credential_key_env",
|
||||
"validate_agent_credentials",
|
||||
# Interactive setup
|
||||
"CredentialSetupSession",
|
||||
"MissingCredential",
|
||||
"SetupResult",
|
||||
"detect_missing_credentials_from_nodes",
|
||||
"run_credential_setup_cli",
|
||||
# Aden sync (optional - requires httpx)
|
||||
"AdenSyncProvider",
|
||||
"AdenCredentialClient",
|
||||
|
||||
@@ -0,0 +1,745 @@
|
||||
"""
|
||||
Interactive credential setup for CLI applications.
|
||||
|
||||
Provides a modular, reusable credential setup flow that can be triggered
|
||||
when validate_agent_credentials() fails. Works with both TUI and headless CLIs.
|
||||
|
||||
Usage:
|
||||
from framework.credentials.setup import CredentialSetupSession
|
||||
|
||||
# From agent path
|
||||
session = CredentialSetupSession.from_agent_path("exports/my-agent")
|
||||
result = session.run_interactive()
|
||||
|
||||
# From nodes directly
|
||||
session = CredentialSetupSession.from_nodes(nodes)
|
||||
result = session.run_interactive()
|
||||
|
||||
# With custom I/O (for integration with other UIs)
|
||||
session = CredentialSetupSession(
|
||||
missing=missing_creds,
|
||||
input_fn=my_input,
|
||||
print_fn=my_print,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import getpass
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph import NodeSpec
|
||||
|
||||
|
||||
# ANSI colors for terminal output
|
||||
class Colors:
|
||||
RED = "\033[0;31m"
|
||||
GREEN = "\033[0;32m"
|
||||
YELLOW = "\033[1;33m"
|
||||
BLUE = "\033[0;34m"
|
||||
CYAN = "\033[0;36m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
NC = "\033[0m" # No Color
|
||||
|
||||
@classmethod
|
||||
def disable(cls):
|
||||
"""Disable colors (for non-TTY output)."""
|
||||
cls.RED = cls.GREEN = cls.YELLOW = cls.BLUE = ""
|
||||
cls.CYAN = cls.BOLD = cls.DIM = cls.NC = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MissingCredential:
|
||||
"""A credential that needs to be configured."""
|
||||
|
||||
credential_name: str
|
||||
"""Internal credential name (e.g., 'brave_search')"""
|
||||
|
||||
env_var: str
|
||||
"""Environment variable name (e.g., 'BRAVE_SEARCH_API_KEY')"""
|
||||
|
||||
description: str
|
||||
"""Human-readable description"""
|
||||
|
||||
help_url: str
|
||||
"""URL where user can obtain credential"""
|
||||
|
||||
api_key_instructions: str
|
||||
"""Step-by-step instructions for getting API key"""
|
||||
|
||||
tools: list[str] = field(default_factory=list)
|
||||
"""Tools that require this credential"""
|
||||
|
||||
node_types: list[str] = field(default_factory=list)
|
||||
"""Node types that require this credential"""
|
||||
|
||||
aden_supported: bool = False
|
||||
"""Whether Aden OAuth flow is supported"""
|
||||
|
||||
direct_api_key_supported: bool = True
|
||||
"""Whether direct API key entry is supported"""
|
||||
|
||||
credential_id: str = ""
|
||||
"""Credential store ID"""
|
||||
|
||||
credential_key: str = "api_key"
|
||||
"""Key name within the credential"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetupResult:
|
||||
"""Result of credential setup session."""
|
||||
|
||||
success: bool
|
||||
"""Whether all required credentials were configured"""
|
||||
|
||||
configured: list[str] = field(default_factory=list)
|
||||
"""Credentials that were successfully set up"""
|
||||
|
||||
skipped: list[str] = field(default_factory=list)
|
||||
"""Credentials user chose to skip"""
|
||||
|
||||
errors: list[str] = field(default_factory=list)
|
||||
"""Any errors encountered"""
|
||||
|
||||
|
||||
class CredentialSetupSession:
|
||||
"""
|
||||
Interactive credential setup session.
|
||||
|
||||
Can be used by any CLI (runner, coding agent, etc.) to guide users
|
||||
through credential configuration when validation fails.
|
||||
|
||||
Example:
|
||||
from framework.credentials.setup import CredentialSetupSession
|
||||
from framework.credentials.models import CredentialError
|
||||
|
||||
try:
|
||||
validate_agent_credentials(nodes)
|
||||
except CredentialError:
|
||||
session = CredentialSetupSession.from_nodes(nodes)
|
||||
result = session.run_interactive()
|
||||
if result.success:
|
||||
# Retry - credentials are now configured
|
||||
validate_agent_credentials(nodes)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
missing: list[MissingCredential],
|
||||
input_fn: Callable[[str], str] | None = None,
|
||||
print_fn: Callable[[str], None] | None = None,
|
||||
password_fn: Callable[[str], str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the setup session.
|
||||
|
||||
Args:
|
||||
missing: List of credentials that need setup
|
||||
input_fn: Custom input function (default: built-in input)
|
||||
print_fn: Custom print function (default: built-in print)
|
||||
password_fn: Custom password input function (default: getpass.getpass)
|
||||
"""
|
||||
self.missing = missing
|
||||
self.input_fn = input_fn or input
|
||||
self.print_fn = print_fn or print
|
||||
self.password_fn = password_fn or getpass.getpass
|
||||
|
||||
# Disable colors if not a TTY
|
||||
if not sys.stdout.isatty():
|
||||
Colors.disable()
|
||||
|
||||
@classmethod
|
||||
def from_nodes(cls, nodes: list[NodeSpec]) -> CredentialSetupSession:
|
||||
"""Create a setup session by detecting missing credentials from nodes."""
|
||||
missing = detect_missing_credentials_from_nodes(nodes)
|
||||
return cls(missing)
|
||||
|
||||
@classmethod
|
||||
def from_agent_path(cls, agent_path: str | Path) -> CredentialSetupSession:
|
||||
"""Create a setup session for an agent by path."""
|
||||
agent_path = Path(agent_path)
|
||||
|
||||
# Load agent to get nodes
|
||||
agent_json = agent_path / "agent.json"
|
||||
agent_py = agent_path / "agent.py"
|
||||
|
||||
nodes = []
|
||||
if agent_py.exists():
|
||||
# Python-based agent
|
||||
nodes = _load_nodes_from_python_agent(agent_path)
|
||||
elif agent_json.exists():
|
||||
# JSON-based agent
|
||||
nodes = _load_nodes_from_json_agent(agent_json)
|
||||
|
||||
missing = detect_missing_credentials_from_nodes(nodes)
|
||||
return cls(missing)
|
||||
|
||||
def run_interactive(self) -> SetupResult:
|
||||
"""Run the interactive setup flow."""
|
||||
configured: list[str] = []
|
||||
skipped: list[str] = []
|
||||
errors: list[str] = []
|
||||
|
||||
if not self.missing:
|
||||
self._print(f"\n{Colors.GREEN}✓ All credentials are already configured!{Colors.NC}\n")
|
||||
return SetupResult(success=True)
|
||||
|
||||
self._print_header()
|
||||
|
||||
# Ensure HIVE_CREDENTIAL_KEY is set before storing anything
|
||||
if not self._ensure_credential_key():
|
||||
return SetupResult(
|
||||
success=False,
|
||||
errors=["Failed to initialize credential store encryption key"],
|
||||
)
|
||||
|
||||
for cred in self.missing:
|
||||
try:
|
||||
result = self._setup_single_credential(cred)
|
||||
if result:
|
||||
configured.append(cred.credential_name)
|
||||
else:
|
||||
skipped.append(cred.credential_name)
|
||||
except KeyboardInterrupt:
|
||||
self._print(f"\n{Colors.YELLOW}Setup interrupted.{Colors.NC}")
|
||||
skipped.append(cred.credential_name)
|
||||
break
|
||||
except Exception as e:
|
||||
errors.append(f"{cred.credential_name}: {e}")
|
||||
|
||||
self._print_summary(configured, skipped, errors)
|
||||
|
||||
return SetupResult(
|
||||
success=len(errors) == 0 and len(skipped) == 0,
|
||||
configured=configured,
|
||||
skipped=skipped,
|
||||
errors=errors,
|
||||
)
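Because SetupResult is plain data, callers can branch on it directly after run_interactive(); a small illustration using the fields defined above:

result = session.run_interactive()
if not result.success:
    print("Configured:", ", ".join(result.configured) or "none")
    print("Skipped:   ", ", ".join(result.skipped) or "none")
    print("Errors:    ", "; ".join(result.errors) or "none")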
|
||||
|
||||
def _print(self, msg: str) -> None:
|
||||
"""Print a message."""
|
||||
self.print_fn(msg)
|
||||
|
||||
def _input(self, prompt: str) -> str:
|
||||
"""Get input from user."""
|
||||
return self.input_fn(prompt)
|
||||
|
||||
def _print_header(self) -> None:
|
||||
"""Print the setup header."""
|
||||
self._print("")
|
||||
self._print(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
|
||||
self._print(f"{Colors.BOLD} CREDENTIAL SETUP{Colors.NC}")
|
||||
self._print(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
|
||||
self._print("")
|
||||
self._print(f" {len(self.missing)} credential(s) need to be configured:")
|
||||
for cred in self.missing:
|
||||
affected = cred.tools or cred.node_types
|
||||
self._print(f" • {cred.env_var} ({', '.join(affected)})")
|
||||
self._print("")
|
||||
|
||||
def _ensure_credential_key(self) -> bool:
|
||||
"""Ensure HIVE_CREDENTIAL_KEY is available for encrypted storage."""
|
||||
if os.environ.get("HIVE_CREDENTIAL_KEY"):
|
||||
return True
|
||||
|
||||
# Try to load from shell config
|
||||
try:
|
||||
from aden_tools.credentials.shell_config import check_env_var_in_shell_config
|
||||
|
||||
found, value = check_env_var_in_shell_config("HIVE_CREDENTIAL_KEY")
|
||||
if found and value:
|
||||
os.environ["HIVE_CREDENTIAL_KEY"] = value
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Generate a new key
|
||||
self._print(f"{Colors.YELLOW}Initializing credential store...{Colors.NC}")
|
||||
try:
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
generated_key = Fernet.generate_key().decode()
|
||||
os.environ["HIVE_CREDENTIAL_KEY"] = generated_key
|
||||
|
||||
# Save to shell config
|
||||
self._save_key_to_shell_config(generated_key)
|
||||
return True
|
||||
except Exception as e:
|
||||
self._print(f"{Colors.RED}Failed to initialize credential store: {e}{Colors.NC}")
|
||||
return False
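The generated key is a standard Fernet key, so the same thing can be done by hand if you prefer to manage HIVE_CREDENTIAL_KEY yourself; cryptography's Fernet.generate_key is the only call involved.

# One-off key generation; paste the printed line into ~/.zshrc or ~/.bashrc.
from cryptography.fernet import Fernet

key = Fernet.generate_key().decode()
print(f'export HIVE_CREDENTIAL_KEY="{key}"')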
|
||||
|
||||
def _save_key_to_shell_config(self, key: str) -> None:
|
||||
"""Save HIVE_CREDENTIAL_KEY to shell config."""
|
||||
try:
|
||||
from aden_tools.credentials.shell_config import (
|
||||
add_env_var_to_shell_config,
|
||||
)
|
||||
|
||||
success, config_path = add_env_var_to_shell_config(
|
||||
"HIVE_CREDENTIAL_KEY",
|
||||
key,
|
||||
comment="Encryption key for Hive credential store",
|
||||
)
|
||||
if success:
|
||||
self._print(f"{Colors.GREEN}✓ Encryption key saved to {config_path}{Colors.NC}")
|
||||
except Exception:
|
||||
# Fallback: just tell the user
|
||||
self._print("\n")
|
||||
self._print(
|
||||
f"{Colors.YELLOW}Add this to your shell config (~/.zshrc or ~/.bashrc):{Colors.NC}"
|
||||
)
|
||||
self._print(f' export HIVE_CREDENTIAL_KEY="{key}"')
|
||||
|
||||
def _setup_single_credential(self, cred: MissingCredential) -> bool:
|
||||
"""Set up a single credential. Returns True if configured."""
|
||||
self._print(f"\n{Colors.CYAN}{'─' * 60}{Colors.NC}")
|
||||
self._print(f"{Colors.BOLD}Setting up: {cred.credential_name}{Colors.NC}")
|
||||
affected = cred.tools or cred.node_types
|
||||
self._print(f"{Colors.DIM}Required for: {', '.join(affected)}{Colors.NC}")
|
||||
if cred.description:
|
||||
self._print(f"{Colors.DIM}{cred.description}{Colors.NC}")
|
||||
self._print(f"{Colors.CYAN}{'─' * 60}{Colors.NC}")
|
||||
|
||||
# Show auth options
|
||||
options = self._get_auth_options(cred)
|
||||
choice = self._prompt_choice(options)
|
||||
|
||||
if choice == "skip":
|
||||
return False
|
||||
elif choice == "aden":
|
||||
return self._setup_via_aden(cred)
|
||||
elif choice == "direct":
|
||||
return self._setup_direct_api_key(cred)
|
||||
|
||||
return False
|
||||
|
||||
def _get_auth_options(self, cred: MissingCredential) -> list[tuple[str, str, str]]:
|
||||
"""Get available auth options as (key, label, description) tuples."""
|
||||
options = []
|
||||
|
||||
if cred.direct_api_key_supported:
|
||||
options.append(
|
||||
(
|
||||
"direct",
|
||||
"Enter API key directly",
|
||||
"Paste your API key from the provider's dashboard",
|
||||
)
|
||||
)
|
||||
|
||||
if cred.aden_supported:
|
||||
options.append(
|
||||
(
|
||||
"aden",
|
||||
"Use Aden Platform (OAuth)",
|
||||
"Secure OAuth2 flow via hive.adenhq.com",
|
||||
)
|
||||
)
|
||||
|
||||
options.append(
|
||||
(
|
||||
"skip",
|
||||
"Skip for now",
|
||||
"Configure this credential later",
|
||||
)
|
||||
)
|
||||
|
||||
return options
|
||||
|
||||
def _prompt_choice(self, options: list[tuple[str, str, str]]) -> str:
|
||||
"""Prompt user to choose from options."""
|
||||
self._print("")
|
||||
for i, (key, label, desc) in enumerate(options, 1):
|
||||
if key == "skip":
|
||||
self._print(f" {Colors.DIM}{i}) {label}{Colors.NC}")
|
||||
else:
|
||||
self._print(f" {Colors.CYAN}{i}){Colors.NC} {label}")
|
||||
self._print(f" {Colors.DIM}{desc}{Colors.NC}")
|
||||
self._print("")
|
||||
|
||||
while True:
|
||||
try:
|
||||
choice_str = self._input(f"Select option (1-{len(options)}): ").strip()
|
||||
if not choice_str:
|
||||
continue
|
||||
choice_num = int(choice_str)
|
||||
if 1 <= choice_num <= len(options):
|
||||
return options[choice_num - 1][0]
|
||||
except ValueError:
|
||||
pass
|
||||
self._print(f"{Colors.RED}Invalid choice. Enter 1-{len(options)}{Colors.NC}")
|
||||
|
||||
def _setup_direct_api_key(self, cred: MissingCredential) -> bool:
|
||||
"""Guide user through direct API key setup."""
|
||||
# Show instructions
|
||||
if cred.api_key_instructions:
|
||||
self._print(f"\n{Colors.BOLD}Setup Instructions:{Colors.NC}")
|
||||
self._print(cred.api_key_instructions)
|
||||
|
||||
if cred.help_url:
|
||||
self._print(f"\n{Colors.CYAN}Get your API key at:{Colors.NC} {cred.help_url}")
|
||||
|
||||
# Collect key (use password input to hide the value)
|
||||
self._print("")
|
||||
try:
|
||||
api_key = self.password_fn(f"Paste your {cred.env_var}: ").strip()
|
||||
except Exception:
|
||||
# Fallback to regular input if password input fails
|
||||
api_key = self._input(f"Paste your {cred.env_var}: ").strip()
|
||||
|
||||
if not api_key:
|
||||
self._print(f"{Colors.YELLOW}No value entered. Skipping.{Colors.NC}")
|
||||
return False
|
||||
|
||||
# Health check
|
||||
health_result = self._run_health_check(cred, api_key)
|
||||
if health_result is not None:
|
||||
if health_result["valid"]:
|
||||
self._print(f"{Colors.GREEN}✓ {health_result['message']}{Colors.NC}")
|
||||
else:
|
||||
self._print(f"{Colors.YELLOW}⚠ {health_result['message']}{Colors.NC}")
|
||||
confirm = self._input("Continue anyway? [y/N]: ").strip().lower()
|
||||
if confirm != "y":
|
||||
return False
|
||||
|
||||
# Store credential
|
||||
self._store_credential(cred, api_key)
|
||||
return True
|
||||
|
||||
def _setup_via_aden(self, cred: MissingCredential) -> bool:
|
||||
"""Guide user through Aden OAuth flow."""
|
||||
self._print(f"\n{Colors.BOLD}Aden Platform Setup{Colors.NC}")
|
||||
self._print("This will sync credentials from your Aden account.")
|
||||
self._print("")
|
||||
|
||||
# Check for ADEN_API_KEY
|
||||
aden_key = os.environ.get("ADEN_API_KEY")
|
||||
if not aden_key:
|
||||
self._print("You need an Aden API key to use this method.")
|
||||
self._print(f"{Colors.CYAN}Get one at:{Colors.NC} https://hive.adenhq.com")
|
||||
self._print("")
|
||||
|
||||
try:
|
||||
aden_key = self.password_fn("Paste your ADEN_API_KEY: ").strip()
|
||||
except Exception:
|
||||
aden_key = self._input("Paste your ADEN_API_KEY: ").strip()
|
||||
|
||||
if not aden_key:
|
||||
self._print(f"{Colors.YELLOW}No key entered. Skipping.{Colors.NC}")
|
||||
return False
|
||||
|
||||
os.environ["ADEN_API_KEY"] = aden_key
|
||||
|
||||
# Save to shell config
|
||||
try:
|
||||
from aden_tools.credentials.shell_config import add_env_var_to_shell_config
|
||||
|
||||
add_env_var_to_shell_config(
|
||||
"ADEN_API_KEY",
|
||||
aden_key,
|
||||
comment="Aden Platform API key",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Sync from Aden
|
||||
try:
|
||||
from framework.credentials import CredentialStore
|
||||
|
||||
store = CredentialStore.with_aden_sync(
|
||||
base_url="https://api.adenhq.com",
|
||||
auto_sync=True,
|
||||
)
|
||||
|
||||
# Check if the credential was synced
|
||||
cred_id = cred.credential_id or cred.credential_name
|
||||
if store.is_available(cred_id):
|
||||
self._print(f"{Colors.GREEN}✓ {cred.credential_name} synced from Aden{Colors.NC}")
|
||||
# Export to current session
|
||||
try:
|
||||
value = store.get_key(cred_id, cred.credential_key)
|
||||
if value:
|
||||
os.environ[cred.env_var] = value
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
else:
|
||||
self._print(
|
||||
f"{Colors.YELLOW}⚠ {cred.credential_name} not found in Aden account.{Colors.NC}"
|
||||
)
|
||||
self._print("Please connect this integration on https://hive.adenhq.com first.")
|
||||
return False
|
||||
except Exception as e:
|
||||
self._print(f"{Colors.RED}Failed to sync from Aden: {e}{Colors.NC}")
|
||||
return False
|
||||
|
||||
def _run_health_check(self, cred: MissingCredential, value: str) -> dict[str, Any] | None:
|
||||
"""Run health check on credential value."""
|
||||
try:
|
||||
from aden_tools.credentials import check_credential_health
|
||||
|
||||
result = check_credential_health(cred.credential_name, value)
|
||||
return {
|
||||
"valid": result.valid,
|
||||
"message": result.message,
|
||||
"details": result.details,
|
||||
}
|
||||
except Exception:
|
||||
# No health checker available
|
||||
return None
|
||||
|
||||
def _store_credential(self, cred: MissingCredential, value: str) -> None:
|
||||
"""Store credential in encrypted store and export to env."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from framework.credentials import CredentialKey, CredentialObject, CredentialStore
|
||||
|
||||
try:
|
||||
store = CredentialStore.with_encrypted_storage()
|
||||
cred_id = cred.credential_id or cred.credential_name
|
||||
key_name = cred.credential_key or "api_key"
|
||||
|
||||
cred_obj = CredentialObject(
|
||||
id=cred_id,
|
||||
name=cred.description or cred.credential_name,
|
||||
keys={key_name: CredentialKey(name=key_name, value=SecretStr(value))},
|
||||
)
|
||||
store.save_credential(cred_obj)
|
||||
self._print(f"{Colors.GREEN}✓ Stored in ~/.hive/credentials/{Colors.NC}")
|
||||
except Exception as e:
|
||||
self._print(f"{Colors.YELLOW}⚠ Could not store in credential store: {e}{Colors.NC}")
|
||||
|
||||
# Export to current session
|
||||
os.environ[cred.env_var] = value
|
||||
self._print(f"{Colors.GREEN}✓ Exported to current session{Colors.NC}")
|
||||
|
||||
def _print_summary(self, configured: list[str], skipped: list[str], errors: list[str]) -> None:
|
||||
"""Print final summary."""
|
||||
self._print("")
|
||||
self._print(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
|
||||
self._print(f"{Colors.BOLD} SETUP COMPLETE{Colors.NC}")
|
||||
self._print(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
|
||||
|
||||
if configured:
|
||||
self._print(f"\n{Colors.GREEN}✓ Configured:{Colors.NC}")
|
||||
for name in configured:
|
||||
self._print(f" • {name}")
|
||||
|
||||
if skipped:
|
||||
self._print(f"\n{Colors.YELLOW}⏭ Skipped:{Colors.NC}")
|
||||
for name in skipped:
|
||||
self._print(f" • {name}")
|
||||
|
||||
if errors:
|
||||
self._print(f"\n{Colors.RED}✗ Errors:{Colors.NC}")
|
||||
for err in errors:
|
||||
self._print(f" • {err}")
|
||||
|
||||
if not skipped and not errors:
|
||||
self._print(f"\n{Colors.GREEN}All credentials configured successfully!{Colors.NC}")
|
||||
elif skipped:
|
||||
self._print(f"\n{Colors.YELLOW}Note: Skipped credentials must be configured ")
|
||||
self._print(f"before running the agent.{Colors.NC}")
|
||||
|
||||
self._print("")
|
||||
|
||||
|
||||
def detect_missing_credentials_from_nodes(nodes: list) -> list[MissingCredential]:
|
||||
"""
|
||||
Detect missing credentials for a list of nodes.
|
||||
|
||||
Args:
|
||||
nodes: List of NodeSpec objects
|
||||
|
||||
Returns:
|
||||
List of MissingCredential objects for credentials that need setup
|
||||
"""
|
||||
try:
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
from framework.credentials import CredentialStore
|
||||
from framework.credentials.storage import (
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
except ImportError:
|
||||
return []
|
||||
|
||||
# Collect required tools and node types
|
||||
required_tools: set[str] = set()
|
||||
node_types: set[str] = set()
|
||||
|
||||
for node in nodes:
|
||||
if hasattr(node, "tools") and node.tools:
|
||||
required_tools.update(node.tools)
|
||||
if hasattr(node, "node_type"):
|
||||
node_types.add(node.node_type)
|
||||
|
||||
# Build credential store to check availability
|
||||
env_mapping = {
|
||||
(spec.credential_id or name): spec.env_var for name, spec in CREDENTIAL_SPECS.items()
|
||||
}
|
||||
storages: list = [EnvVarStorage(env_mapping=env_mapping)]
|
||||
if os.environ.get("HIVE_CREDENTIAL_KEY"):
|
||||
storages.insert(0, EncryptedFileStorage())
|
||||
if len(storages) == 1:
|
||||
storage = storages[0]
|
||||
else:
|
||||
storage = CompositeStorage(primary=storages[0], fallbacks=storages[1:])
|
||||
store = CredentialStore(storage=storage)
|
||||
|
||||
# Build reverse mappings
|
||||
tool_to_cred: dict[str, str] = {}
|
||||
node_type_to_cred: dict[str, str] = {}
|
||||
for cred_name, spec in CREDENTIAL_SPECS.items():
|
||||
for tool_name in spec.tools:
|
||||
tool_to_cred[tool_name] = cred_name
|
||||
for nt in spec.node_types:
|
||||
node_type_to_cred[nt] = cred_name
|
||||
|
||||
missing: list[MissingCredential] = []
|
||||
checked: set[str] = set()
|
||||
|
||||
# Check tool credentials
|
||||
for tool_name in sorted(required_tools):
|
||||
cred_name = tool_to_cred.get(tool_name)
|
||||
if cred_name is None or cred_name in checked:
|
||||
continue
|
||||
checked.add(cred_name)
|
||||
|
||||
spec = CREDENTIAL_SPECS[cred_name]
|
||||
cred_id = spec.credential_id or cred_name
|
||||
if spec.required and not store.is_available(cred_id):
|
||||
affected_tools = sorted(t for t in required_tools if t in spec.tools)
|
||||
missing.append(
|
||||
MissingCredential(
|
||||
credential_name=cred_name,
|
||||
env_var=spec.env_var,
|
||||
description=spec.description,
|
||||
help_url=spec.help_url,
|
||||
api_key_instructions=spec.api_key_instructions,
|
||||
tools=affected_tools,
|
||||
aden_supported=spec.aden_supported,
|
||||
direct_api_key_supported=spec.direct_api_key_supported,
|
||||
credential_id=spec.credential_id,
|
||||
credential_key=spec.credential_key,
|
||||
)
|
||||
)
|
||||
|
||||
# Check node type credentials
|
||||
for nt in sorted(node_types):
|
||||
cred_name = node_type_to_cred.get(nt)
|
||||
if cred_name is None or cred_name in checked:
|
||||
continue
|
||||
checked.add(cred_name)
|
||||
|
||||
spec = CREDENTIAL_SPECS[cred_name]
|
||||
cred_id = spec.credential_id or cred_name
|
||||
if spec.required and not store.is_available(cred_id):
|
||||
affected_types = sorted(t for t in node_types if t in spec.node_types)
|
||||
missing.append(
|
||||
MissingCredential(
|
||||
credential_name=cred_name,
|
||||
env_var=spec.env_var,
|
||||
description=spec.description,
|
||||
help_url=spec.help_url,
|
||||
api_key_instructions=spec.api_key_instructions,
|
||||
node_types=affected_types,
|
||||
aden_supported=spec.aden_supported,
|
||||
direct_api_key_supported=spec.direct_api_key_supported,
|
||||
credential_id=spec.credential_id,
|
||||
credential_key=spec.credential_key,
|
||||
)
|
||||
)
|
||||
|
||||
return missing
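The detector can also be used on its own, without the interactive flow, for example to print a report (this assumes nodes were loaded via one of the helpers below):

missing = detect_missing_credentials_from_nodes(nodes)
for cred in missing:
    needed_by = ", ".join(cred.tools or cred.node_types)
    print(f"{cred.env_var} is missing (needed by: {needed_by}) -> {cred.help_url}")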
|
||||
|
||||
|
||||
def _load_nodes_from_python_agent(agent_path: Path) -> list:
|
||||
"""Load nodes from a Python-based agent."""
|
||||
import importlib.util
|
||||
|
||||
agent_py = agent_path / "agent.py"
|
||||
if not agent_py.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
# Add agent path and its parent to sys.path so imports work
|
||||
paths_to_add = [str(agent_path), str(agent_path.parent)]
|
||||
for p in paths_to_add:
|
||||
if p not in sys.path:
|
||||
sys.path.insert(0, p)
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
f"{agent_path.name}.agent",
|
||||
agent_py,
|
||||
submodule_search_locations=[str(agent_path)],
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return getattr(module, "nodes", [])
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _load_nodes_from_json_agent(agent_json: Path) -> list:
|
||||
"""Load nodes from a JSON-based agent."""
|
||||
try:
|
||||
with open(agent_json) as f:
|
||||
data = json.load(f)
|
||||
|
||||
from framework.graph import NodeSpec
|
||||
|
||||
nodes_data = data.get("graph", {}).get("nodes", [])
|
||||
nodes = []
|
||||
for node_data in nodes_data:
|
||||
nodes.append(
|
||||
NodeSpec(
|
||||
id=node_data.get("id", ""),
|
||||
name=node_data.get("name", ""),
|
||||
description=node_data.get("description", ""),
|
||||
node_type=node_data.get("node_type", ""),
|
||||
tools=node_data.get("tools", []),
|
||||
input_keys=node_data.get("input_keys", []),
|
||||
output_keys=node_data.get("output_keys", []),
|
||||
)
|
||||
)
|
||||
return nodes
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def run_credential_setup_cli(agent_path: str | Path | None = None) -> int:
|
||||
"""
|
||||
Standalone CLI entry point for credential setup.
|
||||
|
||||
Can be called from:
|
||||
- `hive setup-credentials <agent>`
|
||||
- After CredentialError in runner CLI
|
||||
- From coding agent CLI
|
||||
|
||||
Args:
|
||||
agent_path: Optional path to agent directory
|
||||
|
||||
Returns:
|
||||
Exit code (0 = success, 1 = failure/skipped)
|
||||
"""
|
||||
if agent_path:
|
||||
session = CredentialSetupSession.from_agent_path(agent_path)
|
||||
else:
|
||||
# No agent specified - detect from current context or show error
|
||||
print("Usage: hive setup-credentials <agent_path>")
|
||||
return 1
|
||||
|
||||
result = session.run_interactive()
|
||||
return 0 if result.success else 1
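A minimal command-line wrapper around this entry point might look like the following; the argparse wiring is illustrative and not part of this diff.

import argparse
import sys

def main() -> None:
    parser = argparse.ArgumentParser(description="Interactive credential setup for an agent")
    parser.add_argument("agent_path", help="Path to the exported agent directory")
    args = parser.parse_args()
    sys.exit(run_credential_setup_cli(args.agent_path))

if __name__ == "__main__":
    main()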
|
||||
@@ -0,0 +1,139 @@
|
||||
"""Credential validation utilities.
|
||||
|
||||
Provides reusable credential validation for agents, whether run through
|
||||
the AgentRunner or directly via GraphExecutor.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def ensure_credential_key_env() -> None:
|
||||
"""Load HIVE_CREDENTIAL_KEY and ADEN_API_KEY from shell config if not in environment.
|
||||
|
||||
The setup-credentials skill writes these to ~/.zshrc or ~/.bashrc.
|
||||
If the user hasn't sourced their config in the current shell, this reads
|
||||
them directly so the runner (and any MCP subprocesses it spawns) can:
|
||||
- Unlock the encrypted credential store (HIVE_CREDENTIAL_KEY)
|
||||
- Enable Aden OAuth sync for Google/HubSpot/etc. (ADEN_API_KEY)
|
||||
"""
|
||||
try:
|
||||
from aden_tools.credentials.shell_config import check_env_var_in_shell_config
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
for var_name in ("HIVE_CREDENTIAL_KEY", "ADEN_API_KEY"):
|
||||
if os.environ.get(var_name):
|
||||
continue
|
||||
found, value = check_env_var_in_shell_config(var_name)
|
||||
if found and value:
|
||||
os.environ[var_name] = value
|
||||
logger.debug("Loaded %s from shell config", var_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CredentialCheck:
|
||||
"""Result of checking a single credential."""
|
||||
|
||||
env_var: str
|
||||
source: str
|
||||
used_by: str
|
||||
available: bool
|
||||
help_url: str = ""
|
||||
|
||||
|
||||
def validate_agent_credentials(nodes: list, quiet: bool = False) -> None:
|
||||
"""Check that required credentials are available before running an agent.
|
||||
|
||||
Uses CredentialStoreAdapter.default() which includes Aden sync support,
|
||||
correctly resolving OAuth credentials stored under hashed IDs.
|
||||
|
||||
Prints a summary of all credentials and their sources (encrypted store, env var).
|
||||
Raises CredentialError with actionable guidance if any are missing.
|
||||
|
||||
Args:
|
||||
nodes: List of NodeSpec objects from the agent graph.
|
||||
quiet: If True, suppress the credential summary output.
|
||||
"""
|
||||
# Collect required tools and node types
|
||||
required_tools = {tool for node in nodes if node.tools for tool in node.tools}
|
||||
node_types = {node.node_type for node in nodes}
|
||||
|
||||
try:
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
except ImportError:
|
||||
return # aden_tools not installed, skip check
|
||||
|
||||
from framework.credentials.storage import CompositeStorage, EncryptedFileStorage, EnvVarStorage
|
||||
from framework.credentials.store import CredentialStore
|
||||
|
||||
# Build credential store
|
||||
env_mapping = {
|
||||
(spec.credential_id or name): spec.env_var for name, spec in CREDENTIAL_SPECS.items()
|
||||
}
|
||||
storages: list = [EnvVarStorage(env_mapping=env_mapping)]
|
||||
if os.environ.get("HIVE_CREDENTIAL_KEY"):
|
||||
storages.insert(0, EncryptedFileStorage())
|
||||
if len(storages) == 1:
|
||||
storage = storages[0]
|
||||
else:
|
||||
storage = CompositeStorage(primary=storages[0], fallbacks=storages[1:])
|
||||
store = CredentialStore(storage=storage)
|
||||
|
||||
# Build reverse mappings
|
||||
tool_to_cred: dict[str, str] = {}
|
||||
node_type_to_cred: dict[str, str] = {}
|
||||
for cred_name, spec in CREDENTIAL_SPECS.items():
|
||||
for tool_name in spec.tools:
|
||||
tool_to_cred[tool_name] = cred_name
|
||||
for nt in spec.node_types:
|
||||
node_type_to_cred[nt] = cred_name
|
||||
|
||||
missing: list[str] = []
|
||||
checked: set[str] = set()
|
||||
|
||||
# Check tool credentials
|
||||
for tool_name in sorted(required_tools):
|
||||
cred_name = tool_to_cred.get(tool_name)
|
||||
if cred_name is None or cred_name in checked:
|
||||
continue
|
||||
checked.add(cred_name)
|
||||
spec = CREDENTIAL_SPECS[cred_name]
|
||||
cred_id = spec.credential_id or cred_name
|
||||
if spec.required and not store.is_available(cred_id):
|
||||
affected = sorted(t for t in required_tools if t in spec.tools)
|
||||
entry = f" {spec.env_var} for {', '.join(affected)}"
|
||||
if spec.help_url:
|
||||
entry += f"\n Get it at: {spec.help_url}"
|
||||
missing.append(entry)
|
||||
|
||||
# Check node type credentials (e.g., ANTHROPIC_API_KEY for LLM nodes)
|
||||
for nt in sorted(node_types):
|
||||
cred_name = node_type_to_cred.get(nt)
|
||||
if cred_name is None or cred_name in checked:
|
||||
continue
|
||||
checked.add(cred_name)
|
||||
spec = CREDENTIAL_SPECS[cred_name]
|
||||
cred_id = spec.credential_id or cred_name
|
||||
if spec.required and not store.is_available(cred_id):
|
||||
affected_types = sorted(t for t in node_types if t in spec.node_types)
|
||||
entry = f" {spec.env_var} for {', '.join(affected_types)} nodes"
|
||||
if spec.help_url:
|
||||
entry += f"\n Get it at: {spec.help_url}"
|
||||
missing.append(entry)
|
||||
|
||||
if missing:
|
||||
from framework.credentials.models import CredentialError
|
||||
|
||||
lines = ["Missing required credentials:\n"]
|
||||
lines.extend(missing)
|
||||
lines.append(
|
||||
"\nTo fix: run /hive-credentials in Claude Code."
|
||||
"\nIf you've already set up credentials, restart your terminal to load them."
|
||||
)
|
||||
raise CredentialError("\n".join(lines))
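Putting the pieces in this module together, a runner would typically call ensure_credential_key_env() first and fall back to the interactive setup when validation raises; a sketch, mirroring the setup module's own docstring:

from framework.credentials import (
    CredentialSetupSession,
    ensure_credential_key_env,
    validate_agent_credentials,
)
from framework.credentials.models import CredentialError

ensure_credential_key_env()  # pick up HIVE_CREDENTIAL_KEY / ADEN_API_KEY from shell config
try:
    validate_agent_credentials(nodes)
except CredentialError:
    result = CredentialSetupSession.from_nodes(nodes).run_interactive()
    if not result.success:
        raise
    validate_agent_credentials(nodes)  # retry with the newly configured credentials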
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Graph structures: Goals, Nodes, Edges, and Flexible Execution."""
|
||||
"""Graph structures: Goals, Nodes, Edges, and Execution."""
|
||||
|
||||
from framework.graph.client_io import (
|
||||
ActiveNodeClientIO,
|
||||
@@ -6,7 +6,6 @@ from framework.graph.client_io import (
|
||||
InertNodeClientIO,
|
||||
NodeClientIO,
|
||||
)
|
||||
from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec
|
||||
from framework.graph.context_handoff import ContextHandoff, HandoffContext
|
||||
from framework.graph.conversation import ConversationStore, Message, NodeConversation
|
||||
from framework.graph.edge import DEFAULT_MAX_TOKENS, EdgeCondition, EdgeSpec, GraphSpec
|
||||
@@ -18,31 +17,9 @@ from framework.graph.event_loop_node import (
|
||||
OutputAccumulator,
|
||||
)
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.flexible_executor import ExecutorConfig, FlexibleGraphExecutor
|
||||
from framework.graph.goal import Constraint, Goal, GoalStatus, SuccessCriterion
|
||||
from framework.graph.judge import HybridJudge, create_default_judge
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
|
||||
|
||||
# Flexible execution (Worker-Judge pattern)
|
||||
from framework.graph.plan import (
|
||||
ActionSpec,
|
||||
ActionType,
|
||||
# HITL (Human-in-the-loop)
|
||||
ApprovalDecision,
|
||||
ApprovalRequest,
|
||||
ApprovalResult,
|
||||
EvaluationRule,
|
||||
ExecutionStatus,
|
||||
Judgment,
|
||||
JudgmentAction,
|
||||
Plan,
|
||||
PlanExecutionResult,
|
||||
PlanStep,
|
||||
StepStatus,
|
||||
load_export,
|
||||
)
|
||||
from framework.graph.worker_node import StepExecutionResult, WorkerNode
|
||||
|
||||
__all__ = [
|
||||
# Goal
|
||||
"Goal",
|
||||
@@ -59,35 +36,8 @@ __all__ = [
|
||||
"EdgeCondition",
|
||||
"GraphSpec",
|
||||
"DEFAULT_MAX_TOKENS",
|
||||
# Executor (fixed graph)
|
||||
# Executor
|
||||
"GraphExecutor",
|
||||
# Plan (flexible execution)
|
||||
"Plan",
|
||||
"PlanStep",
|
||||
"ActionSpec",
|
||||
"ActionType",
|
||||
"StepStatus",
|
||||
"Judgment",
|
||||
"JudgmentAction",
|
||||
"EvaluationRule",
|
||||
"PlanExecutionResult",
|
||||
"ExecutionStatus",
|
||||
"load_export",
|
||||
# HITL (Human-in-the-loop)
|
||||
"ApprovalDecision",
|
||||
"ApprovalRequest",
|
||||
"ApprovalResult",
|
||||
# Worker-Judge
|
||||
"HybridJudge",
|
||||
"create_default_judge",
|
||||
"WorkerNode",
|
||||
"StepExecutionResult",
|
||||
"FlexibleGraphExecutor",
|
||||
"ExecutorConfig",
|
||||
# Code Sandbox
|
||||
"CodeSandbox",
|
||||
"safe_exec",
|
||||
"safe_eval",
|
||||
# Conversation
|
||||
"NodeConversation",
|
||||
"ConversationStore",
|
||||
|
||||
@@ -1,413 +0,0 @@
|
||||
"""
|
||||
Code Sandbox for Safe Execution of Dynamic Code.
|
||||
|
||||
Provides a restricted execution environment for code generated by
|
||||
the external planner. This is critical for open-ended planning where
|
||||
the planner can create arbitrary code actions.
|
||||
|
||||
Security measures:
|
||||
1. Restricted builtins (no file I/O, no imports of dangerous modules)
|
||||
2. Timeout enforcement
|
||||
3. Memory limits (via resource module on Unix)
|
||||
4. Namespace isolation
|
||||
"""
|
||||
|
||||
import ast
|
||||
import signal
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
# Safe builtins whitelist
|
||||
SAFE_BUILTINS = {
|
||||
# Basic types
|
||||
"True": True,
|
||||
"False": False,
|
||||
"None": None,
|
||||
# Type constructors
|
||||
"bool": bool,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"str": str,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"set": set,
|
||||
"tuple": tuple,
|
||||
"frozenset": frozenset,
|
||||
# Basic functions
|
||||
"abs": abs,
|
||||
"all": all,
|
||||
"any": any,
|
||||
"bin": bin,
|
||||
"chr": chr,
|
||||
"divmod": divmod,
|
||||
"enumerate": enumerate,
|
||||
"filter": filter,
|
||||
"format": format,
|
||||
"hex": hex,
|
||||
"isinstance": isinstance,
|
||||
"issubclass": issubclass,
|
||||
"iter": iter,
|
||||
"len": len,
|
||||
"map": map,
|
||||
"max": max,
|
||||
"min": min,
|
||||
"next": next,
|
||||
"oct": oct,
|
||||
"ord": ord,
|
||||
"pow": pow,
|
||||
"range": range,
|
||||
"repr": repr,
|
||||
"reversed": reversed,
|
||||
"round": round,
|
||||
"slice": slice,
|
||||
"sorted": sorted,
|
||||
"sum": sum,
|
||||
"zip": zip,
|
||||
}
|
||||
|
||||
# Modules that can be imported
|
||||
ALLOWED_MODULES = {
|
||||
"math",
|
||||
"json",
|
||||
"re",
|
||||
"datetime",
|
||||
"collections",
|
||||
"itertools",
|
||||
"functools",
|
||||
"operator",
|
||||
"string",
|
||||
"random",
|
||||
"statistics",
|
||||
"decimal",
|
||||
"fractions",
|
||||
}
|
||||
|
||||
# Dangerous AST nodes to block
|
||||
BLOCKED_AST_NODES = {
|
||||
ast.Import,
|
||||
ast.ImportFrom,
|
||||
ast.Global,
|
||||
ast.Nonlocal,
|
||||
}
|
||||
|
||||
|
||||
class CodeSandboxError(Exception):
|
||||
"""Error during sandboxed code execution."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutError(CodeSandboxError):
|
||||
"""Code execution timed out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SecurityError(CodeSandboxError):
|
||||
"""Code contains potentially dangerous operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SandboxResult:
|
||||
"""Result of sandboxed code execution."""
|
||||
|
||||
success: bool
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
stdout: str = ""
|
||||
variables: dict[str, Any] = field(default_factory=dict)
|
||||
execution_time_ms: int = 0
|
||||
|
||||
|
||||
class RestrictedImporter:
|
||||
"""Custom importer that only allows whitelisted modules."""
|
||||
|
||||
def __init__(self, allowed_modules: set[str]):
|
||||
self.allowed_modules = allowed_modules
|
||||
self._cache: dict[str, Any] = {}
|
||||
|
||||
def __call__(self, name: str, *args, **kwargs):
|
||||
if name not in self.allowed_modules:
|
||||
raise SecurityError(f"Import of module '{name}' is not allowed")
|
||||
|
||||
if name not in self._cache:
|
||||
import importlib
|
||||
|
||||
self._cache[name] = importlib.import_module(name)
|
||||
|
||||
return self._cache[name]
|
||||
|
||||
|
||||
class CodeValidator:
|
||||
"""Validates code for safety before execution."""
|
||||
|
||||
def __init__(self, blocked_nodes: set[type] | None = None):
|
||||
self.blocked_nodes = blocked_nodes or BLOCKED_AST_NODES
|
||||
|
||||
def validate(self, code: str) -> list[str]:
|
||||
"""
|
||||
Validate code and return list of issues.
|
||||
|
||||
Returns empty list if code is safe.
|
||||
"""
|
||||
issues = []
|
||||
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
return [f"Syntax error: {e}"]
|
||||
|
||||
for node in ast.walk(tree):
|
||||
# Check for blocked node types
|
||||
if type(node) in self.blocked_nodes:
|
||||
lineno = getattr(node, "lineno", "?")
|
||||
issues.append(f"Blocked operation: {type(node).__name__} at line {lineno}")
|
||||
|
||||
# Check for dangerous attribute access
|
||||
if isinstance(node, ast.Attribute):
|
||||
if node.attr.startswith("_"):
|
||||
issues.append(
|
||||
f"Access to private attribute '{node.attr}' at line {node.lineno}"
|
||||
)
|
||||
|
||||
# Check for exec/eval calls
|
||||
if isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name):
|
||||
if node.func.id in ("exec", "eval", "compile", "__import__"):
|
||||
issues.append(
|
||||
f"Blocked function call: {node.func.id} at line {node.lineno}"
|
||||
)
|
||||
|
||||
return issues
|
||||
|
||||
|
||||
class CodeSandbox:
|
||||
"""
|
||||
Sandboxed environment for executing dynamic code.
|
||||
|
||||
Usage:
|
||||
sandbox = CodeSandbox(timeout_seconds=5)
|
||||
result = sandbox.execute(
|
||||
code="x = 1 + 2\\nresult = x * 3",
|
||||
inputs={"multiplier": 2},
|
||||
)
|
||||
if result.success:
|
||||
print(result.variables["result"]) # 6
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout_seconds: int = 10,
|
||||
allowed_modules: set[str] | None = None,
|
||||
safe_builtins: dict[str, Any] | None = None,
|
||||
):
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.allowed_modules = allowed_modules or ALLOWED_MODULES
|
||||
self.safe_builtins = safe_builtins or SAFE_BUILTINS
|
||||
self.validator = CodeValidator()
|
||||
self.importer = RestrictedImporter(self.allowed_modules)
|
||||
|
||||
@contextmanager
|
||||
def _timeout_context(self, seconds: int):
|
||||
"""Context manager for timeout enforcement."""
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError(f"Code execution timed out after {seconds} seconds")
|
||||
|
||||
# Only works on Unix-like systems
|
||||
if hasattr(signal, "SIGALRM"):
|
||||
old_handler = signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(seconds)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
else:
|
||||
# Windows: no timeout support, just execute
|
||||
yield
|
||||
|
||||
def _create_namespace(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Create isolated namespace for code execution."""
|
||||
namespace = {
|
||||
"__builtins__": dict(self.safe_builtins),
|
||||
"__import__": self.importer,
|
||||
}
|
||||
|
||||
# Add input variables
|
||||
namespace.update(inputs)
|
||||
|
||||
return namespace
|
||||
|
||||
def execute(
|
||||
self,
|
||||
code: str,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
extract_vars: list[str] | None = None,
|
||||
) -> SandboxResult:
|
||||
"""
|
||||
Execute code in sandbox.
|
||||
|
||||
Args:
|
||||
code: Python code to execute
|
||||
inputs: Variables to inject into namespace
|
||||
extract_vars: Variable names to extract from namespace after execution
|
||||
|
||||
Returns:
|
||||
SandboxResult with execution outcome
|
||||
"""
|
||||
import time
|
||||
|
||||
inputs = inputs or {}
|
||||
extract_vars = extract_vars or []
|
||||
|
||||
# Validate code first
|
||||
issues = self.validator.validate(code)
|
||||
if issues:
|
||||
return SandboxResult(
|
||||
success=False,
|
||||
error=f"Code validation failed: {'; '.join(issues)}",
|
||||
)
|
||||
|
||||
# Create isolated namespace
|
||||
namespace = self._create_namespace(inputs)
|
||||
|
||||
# Capture stdout
|
||||
import io
|
||||
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = captured_stdout = io.StringIO()
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
with self._timeout_context(self.timeout_seconds):
|
||||
# Compile and execute
|
||||
compiled = compile(code, "<sandbox>", "exec")
|
||||
exec(compiled, namespace)
|
||||
|
||||
execution_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Extract requested variables
|
||||
extracted = {}
|
||||
for var in extract_vars:
|
||||
if var in namespace:
|
||||
extracted[var] = namespace[var]
|
||||
|
||||
# Also extract any new variables (not in inputs or builtins)
|
||||
for key, value in namespace.items():
|
||||
if key not in inputs and key not in self.safe_builtins and not key.startswith("_"):
|
||||
extracted[key] = value
|
||||
|
||||
return SandboxResult(
|
||||
success=True,
|
||||
result=namespace.get("result"), # Convention: 'result' is the return value
|
||||
stdout=captured_stdout.getvalue(),
|
||||
variables=extracted,
|
||||
execution_time_ms=execution_time_ms,
|
||||
)
|
||||
|
||||
except TimeoutError as e:
|
||||
return SandboxResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
execution_time_ms=self.timeout_seconds * 1000,
|
||||
)
|
||||
|
||||
except SecurityError as e:
|
||||
return SandboxResult(
|
||||
success=False,
|
||||
error=f"Security violation: {e}",
|
||||
execution_time_ms=int((time.time() - start_time) * 1000),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return SandboxResult(
|
||||
success=False,
|
||||
error=f"{type(e).__name__}: {e}",
|
||||
stdout=captured_stdout.getvalue(),
|
||||
execution_time_ms=int((time.time() - start_time) * 1000),
|
||||
)
|
||||
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
|
||||
def execute_expression(
|
||||
self,
|
||||
expression: str,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> SandboxResult:
|
||||
"""
|
||||
Execute a single expression and return its value.
|
||||
|
||||
Simpler than execute() - just evaluates one expression.
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
|
||||
# Validate
|
||||
try:
|
||||
ast.parse(expression, mode="eval")
|
||||
except SyntaxError as e:
|
||||
return SandboxResult(success=False, error=f"Syntax error: {e}")
|
||||
|
||||
namespace = self._create_namespace(inputs)
|
||||
|
||||
try:
|
||||
with self._timeout_context(self.timeout_seconds):
|
||||
result = eval(expression, namespace)
|
||||
|
||||
return SandboxResult(success=True, result=result)
|
||||
|
||||
except Exception as e:
|
||||
return SandboxResult(
|
||||
success=False,
|
||||
error=f"{type(e).__name__}: {e}",
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance with default settings
|
||||
default_sandbox = CodeSandbox()
|
||||
|
||||
|
||||
def safe_exec(
|
||||
code: str,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
timeout_seconds: int = 10,
|
||||
) -> SandboxResult:
|
||||
"""
|
||||
Convenience function for safe code execution.
|
||||
|
||||
Args:
|
||||
code: Python code to execute
|
||||
inputs: Variables to inject
|
||||
timeout_seconds: Max execution time
|
||||
|
||||
Returns:
|
||||
SandboxResult
|
||||
"""
|
||||
sandbox = CodeSandbox(timeout_seconds=timeout_seconds)
|
||||
return sandbox.execute(code, inputs)
|
||||
|
||||
|
||||
def safe_eval(
|
||||
expression: str,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
timeout_seconds: int = 5,
|
||||
) -> SandboxResult:
|
||||
"""
|
||||
Convenience function for safe expression evaluation.
|
||||
|
||||
Args:
|
||||
expression: Python expression to evaluate
|
||||
inputs: Variables to inject
|
||||
timeout_seconds: Max execution time
|
||||
|
||||
Returns:
|
||||
SandboxResult
|
||||
"""
|
||||
sandbox = CodeSandbox(timeout_seconds=timeout_seconds)
|
||||
return sandbox.execute_expression(expression, inputs)
|
||||
@@ -27,6 +27,9 @@ class Message:
|
||||
tool_use_id: str | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
is_error: bool = False
|
||||
# Phase-aware compaction metadata (continuous mode)
|
||||
phase_id: str | None = None
|
||||
is_transition_marker: bool = False
|
||||
|
||||
def to_llm_dict(self) -> dict[str, Any]:
|
||||
"""Convert to OpenAI-format message dict."""
|
||||
@@ -60,6 +63,10 @@ class Message:
|
||||
d["tool_calls"] = self.tool_calls
|
||||
if self.is_error:
|
||||
d["is_error"] = self.is_error
|
||||
if self.phase_id is not None:
|
||||
d["phase_id"] = self.phase_id
|
||||
if self.is_transition_marker:
|
||||
d["is_transition_marker"] = self.is_transition_marker
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
@@ -72,6 +79,8 @@ class Message:
|
||||
tool_use_id=data.get("tool_use_id"),
|
||||
tool_calls=data.get("tool_calls"),
|
||||
is_error=data.get("is_error", False),
|
||||
phase_id=data.get("phase_id"),
|
||||
is_transition_marker=data.get("is_transition_marker", False),
|
||||
)
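The new phase metadata flows through serialization; a tiny illustration using the fields added in this hunk (values are made up):

msg = Message(
    seq=7,
    role="user",
    content="=== Entering phase: research ===",
    phase_id="research",
    is_transition_marker=True,
)
d = msg.to_llm_dict()
# d now includes "phase_id": "research" and "is_transition_marker": True
# alongside the usual role/content fields.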
|
||||
|
||||
|
||||
@@ -188,6 +197,7 @@ class NodeConversation:
|
||||
self._next_seq: int = 0
|
||||
self._meta_persisted: bool = False
|
||||
self._last_api_input_tokens: int | None = None
|
||||
self._current_phase: str | None = None
|
||||
|
||||
# --- Properties --------------------------------------------------------
|
||||
|
||||
@@ -195,6 +205,33 @@ class NodeConversation:
|
||||
def system_prompt(self) -> str:
|
||||
return self._system_prompt
|
||||
|
||||
def update_system_prompt(self, new_prompt: str) -> None:
|
||||
"""Update the system prompt.
|
||||
|
||||
Used in continuous conversation mode at phase transitions to swap
|
||||
Layer 3 (focus) while preserving the conversation history.
|
||||
"""
|
||||
self._system_prompt = new_prompt
|
||||
|
||||
def set_current_phase(self, phase_id: str) -> None:
|
||||
"""Set the current phase ID. Subsequent messages will be stamped with it."""
|
||||
self._current_phase = phase_id
|
||||
|
||||
async def switch_store(self, new_store: ConversationStore) -> None:
|
||||
"""Switch to a new persistence store at a phase transition.
|
||||
|
||||
Subsequent messages are written to *new_store*. Meta (system
|
||||
prompt, config) is re-persisted on the next write so the new
|
||||
store's ``meta.json`` reflects the updated prompt.
|
||||
"""
|
||||
self._store = new_store
|
||||
self._meta_persisted = False
|
||||
await new_store.write_cursor({"next_seq": self._next_seq})
|
||||
|
||||
@property
|
||||
def current_phase(self) -> str | None:
|
||||
return self._current_phase
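At a phase transition, a caller (the continuous-mode executor, presumably) is expected to combine these hooks roughly as follows; the ordering and the prompt/store values here are illustrative:

# Hypothetical phase-transition sequence in continuous conversation mode.
conversation.set_current_phase("analysis")
conversation.update_system_prompt(analysis_prompt)   # swap Layer 3 focus, keep history
await conversation.switch_store(analysis_store)      # route subsequent messages to the new store
await conversation.add_user_message(
    "=== Transitioning to phase: analysis ===",
    is_transition_marker=True,
)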
|
||||
|
||||
@property
|
||||
def messages(self) -> list[Message]:
|
||||
"""Return a defensive copy of the message list."""
|
||||
@@ -216,8 +253,19 @@ class NodeConversation:
|
||||
|
||||
# --- Add messages ------------------------------------------------------
|
||||
|
||||
async def add_user_message(self, content: str) -> Message:
|
||||
msg = Message(seq=self._next_seq, role="user", content=content)
|
||||
async def add_user_message(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
is_transition_marker: bool = False,
|
||||
) -> Message:
|
||||
msg = Message(
|
||||
seq=self._next_seq,
|
||||
role="user",
|
||||
content=content,
|
||||
phase_id=self._current_phase,
|
||||
is_transition_marker=is_transition_marker,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
await self._persist(msg)
|
||||
@@ -233,6 +281,7 @@ class NodeConversation:
|
||||
role="assistant",
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
phase_id=self._current_phase,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
@@ -251,6 +300,7 @@ class NodeConversation:
|
||||
content=content,
|
||||
tool_use_id=tool_use_id,
|
||||
is_error=is_error,
|
||||
phase_id=self._current_phase,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
@@ -380,6 +430,11 @@ class NodeConversation:
|
||||
spillover filename reference (if any). Message structure (role,
|
||||
seq, tool_use_id) stays valid for the LLM API.
|
||||
|
||||
Phase-aware behavior (continuous mode): when messages have ``phase_id``
|
||||
metadata, all messages in the current phase are protected regardless of
|
||||
token budget. Transition markers are never pruned. Older phases' tool
|
||||
results are pruned more aggressively.
|
||||
|
||||
Error tool results are never pruned — they prevent re-calling
|
||||
failing tools.
|
||||
|
||||
@@ -388,13 +443,18 @@ class NodeConversation:
|
||||
if not self._messages:
|
||||
return 0
|
||||
|
||||
# Phase 1: Walk backward, classify tool results as protected vs pruneable
|
||||
# Walk backward, classify tool results as protected vs pruneable
|
||||
protected_tokens = 0
|
||||
pruneable: list[int] = [] # indices into self._messages
|
||||
pruneable_tokens = 0
|
||||
|
||||
for i in range(len(self._messages) - 1, -1, -1):
|
||||
msg = self._messages[i]
|
||||
|
||||
# Transition markers are never pruned (any role)
|
||||
if msg.is_transition_marker:
|
||||
continue
|
||||
|
||||
if msg.role != "tool":
|
||||
continue
|
||||
if msg.is_error:
|
||||
@@ -402,6 +462,10 @@ class NodeConversation:
|
||||
if msg.content.startswith("[Pruned tool result"):
|
||||
continue # already pruned
|
||||
|
||||
# Phase-aware: protect current phase messages
|
||||
if self._current_phase and msg.phase_id == self._current_phase:
|
||||
continue
|
||||
|
||||
est = len(msg.content) // 4
|
||||
if protected_tokens < protect_tokens:
|
||||
protected_tokens += est
|
||||
@@ -409,11 +473,11 @@ class NodeConversation:
|
||||
pruneable.append(i)
|
||||
pruneable_tokens += est
|
||||
|
||||
# Phase 2: Only prune if enough to be worthwhile
|
||||
# Only prune if enough to be worthwhile
if pruneable_tokens < min_prune_tokens:
return 0

# Phase 3: Replace content with compact placeholder
# Replace content with compact placeholder
count = 0
for i in pruneable:
msg = self._messages[i]

@@ -436,6 +500,8 @@ class NodeConversation:
tool_use_id=msg.tool_use_id,
tool_calls=msg.tool_calls,
is_error=msg.is_error,
phase_id=msg.phase_id,
is_transition_marker=msg.is_transition_marker,
)
count += 1

@@ -446,22 +512,38 @@ class NodeConversation:
self._last_api_input_tokens = None
return count

async def compact(self, summary: str, keep_recent: int = 2) -> None:
async def compact(
self,
summary: str,
keep_recent: int = 2,
phase_graduated: bool = False,
) -> None:
"""Replace old messages with a summary, optionally keeping recent ones.

Args:
summary: Caller-provided summary text.
keep_recent: Number of recent messages to preserve (default 2).
Clamped to [0, len(messages) - 1].
phase_graduated: When True and messages have phase_id metadata,
split at phase boundaries instead of using keep_recent.
Keeps current + previous phase intact; compacts older phases.
"""
if not self._messages:
return

# Clamp: must discard at least 1 message
keep_recent = max(0, min(keep_recent, len(self._messages) - 1))

total = len(self._messages)
split = total - keep_recent if keep_recent > 0 else total

# Phase-graduated: find the split point based on phase boundaries.
# Keeps current phase + previous phase intact, compacts older phases.
if phase_graduated and self._current_phase:
split = self._find_phase_graduated_split()
else:
split = None

if split is None:
# Fallback: use keep_recent (non-phase or single-phase conversation)
keep_recent = max(0, min(keep_recent, total - 1))
split = total - keep_recent if keep_recent > 0 else total

# Advance split past orphaned tool results at the boundary.
# Tool-role messages reference a tool_use from the preceding

@@ -470,6 +552,10 @@ class NodeConversation:
while split < total and self._messages[split].role == "tool":
split += 1

# Nothing to compact
if split == 0:
return

old_messages = list(self._messages[:split])
recent_messages = list(self._messages[split:])

@@ -504,6 +590,33 @@ class NodeConversation:
self._messages = [summary_msg] + recent_messages
self._last_api_input_tokens = None # reset; next LLM call will recalibrate

def _find_phase_graduated_split(self) -> int | None:
"""Find split point that preserves current + previous phase.

Returns the index of the first message in the protected set,
or None if phase graduation doesn't apply (< 3 phases).
"""
# Collect distinct phases in order of first appearance
phases_seen: list[str] = []
for msg in self._messages:
if msg.phase_id and msg.phase_id not in phases_seen:
phases_seen.append(msg.phase_id)

# Need at least 3 phases for graduation to be meaningful
# (current + previous are protected, older get compacted)
if len(phases_seen) < 3:
return None

# Protect: current phase + previous phase
protected_phases = {phases_seen[-1], phases_seen[-2]}

# Find split: first message belonging to a protected phase
for i, msg in enumerate(self._messages):
if msg.phase_id in protected_phases:
return i

return None

async def clear(self) -> None:
"""Remove all messages, keep system prompt, preserve ``_next_seq``."""
if self._store:
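
The phase-graduated split above can be exercised on its own. The sketch below mirrors the selection logic from _find_phase_graduated_split on plain (role, phase_id) pairs; the tuple representation and the sample phase names are illustrative, not part of the diff.

    # Minimal sketch of the phase-graduated split on (role, phase_id) pairs.
    def find_phase_graduated_split(messages: list[tuple[str, str | None]]) -> int | None:
        phases_seen: list[str] = []
        for _role, phase_id in messages:
            if phase_id and phase_id not in phases_seen:
                phases_seen.append(phase_id)
        if len(phases_seen) < 3:
            return None  # graduation needs current + previous + at least one older phase
        protected = {phases_seen[-1], phases_seen[-2]}
        for i, (_role, phase_id) in enumerate(messages):
            if phase_id in protected:
                return i
        return None

    # Three phases seen -> everything before the first "plan" message is compactable.
    msgs = [("user", "research"), ("assistant", "research"),
            ("user", "plan"), ("assistant", "plan"),
            ("user", "write")]
    assert find_phase_graduated_split(msgs) == 2
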
@@ -0,0 +1,177 @@
"""Level 2 Conversation-Aware Judge.

When a node has `success_criteria` set, the implicit judge upgrades:
after Level 0 passes (all output keys set), a fast LLM call evaluates
whether the conversation actually meets the criteria.

This prevents nodes from "checking boxes" (setting output keys) without
doing quality work. The LLM reads the recent conversation and assesses
whether the phase's goal was genuinely accomplished.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any

from framework.graph.conversation import NodeConversation
from framework.llm.provider import LLMProvider

logger = logging.getLogger(__name__)


@dataclass
class PhaseVerdict:
"""Result of Level 2 conversation-aware evaluation."""

action: str # "ACCEPT" or "RETRY"
confidence: float = 0.8
feedback: str = ""


async def evaluate_phase_completion(
llm: LLMProvider,
conversation: NodeConversation,
phase_name: str,
phase_description: str,
success_criteria: str,
accumulator_state: dict[str, Any],
max_history_tokens: int = 8_196,
) -> PhaseVerdict:
"""Level 2 judge: read the conversation and evaluate quality.

Only called after Level 0 passes (all output keys set).

Args:
llm: LLM provider for evaluation
conversation: The current conversation to evaluate
phase_name: Name of the current phase/node
phase_description: Description of the phase
success_criteria: Natural-language criteria for phase completion
accumulator_state: Current output key values
max_history_tokens: Main conversation token budget (judge gets 20%)

Returns:
PhaseVerdict with action and optional feedback
"""
# Build a compact view of the recent conversation
recent_messages = _extract_recent_context(conversation, max_messages=10)
outputs_summary = _format_outputs(accumulator_state)

system_prompt = (
"You are a quality judge evaluating whether a phase of work is complete. "
"Be concise. Evaluate based on the success criteria, not on style."
)

user_prompt = f"""Evaluate this phase:

PHASE: {phase_name}
DESCRIPTION: {phase_description}

SUCCESS CRITERIA:
{success_criteria}

OUTPUTS SET:
{outputs_summary}

RECENT CONVERSATION:
{recent_messages}

Has this phase accomplished its goal based on the success criteria?

Respond in exactly this format:
ACTION: ACCEPT or RETRY
CONFIDENCE: 0.X
FEEDBACK: (reason if RETRY, empty if ACCEPT)"""

try:
response = await llm.acomplete(
messages=[{"role": "user", "content": user_prompt}],
system=system_prompt,
max_tokens=max(1024, max_history_tokens // 5),
max_retries=1,
)
if not response.content or not response.content.strip():
logger.debug("Level 2 judge: empty response, accepting by default")
return PhaseVerdict(action="ACCEPT", confidence=0.5, feedback="")
return _parse_verdict(response.content)
except Exception as e:
logger.warning(f"Level 2 judge failed, accepting by default: {e}")
# On failure, don't block — Level 0 already passed
return PhaseVerdict(action="ACCEPT", confidence=0.5, feedback="")


def _extract_recent_context(conversation: NodeConversation, max_messages: int = 10) -> str:
"""Extract recent conversation messages for evaluation."""
messages = conversation.messages
recent = messages[-max_messages:] if len(messages) > max_messages else messages

parts = []
for msg in recent:
role = msg.role.upper()
content = msg.content or ""
# Truncate long tool results
if msg.role == "tool" and len(content) > 200:
content = content[:200] + "..."
if content.strip():
parts.append(f"[{role}]: {content.strip()}")

return "\n".join(parts) if parts else "(no messages)"


def _format_outputs(accumulator_state: dict[str, Any]) -> str:
"""Format output key values for evaluation.

Lists and dicts get structural formatting so the judge can assess
quantity and structure, not just a truncated stringification.
"""
if not accumulator_state:
return "(none)"
parts = []
for key, value in accumulator_state.items():
if isinstance(value, list):
# Show count + brief per-item preview so the judge can
# verify quantity without the full serialization.
items_preview = []
for i, item in enumerate(value[:8]):
item_str = str(item)
if len(item_str) > 150:
item_str = item_str[:150] + "..."
items_preview.append(f" [{i}]: {item_str}")
val_str = f"list ({len(value)} items):\n" + "\n".join(items_preview)
if len(value) > 8:
val_str += f"\n ... and {len(value) - 8} more"
elif isinstance(value, dict):
val_str = str(value)
if len(val_str) > 400:
val_str = val_str[:400] + "..."
else:
val_str = str(value)
if len(val_str) > 300:
val_str = val_str[:300] + "..."
parts.append(f" {key}: {val_str}")
return "\n".join(parts)


def _parse_verdict(response: str) -> PhaseVerdict:
"""Parse LLM response into PhaseVerdict."""
action = "ACCEPT"
confidence = 0.8
feedback = ""

for line in response.strip().split("\n"):
line = line.strip()
if line.startswith("ACTION:"):
action_str = line.split(":", 1)[1].strip().upper()
if action_str in ("ACCEPT", "RETRY"):
action = action_str
elif line.startswith("CONFIDENCE:"):
try:
confidence = float(line.split(":", 1)[1].strip())
except ValueError:
pass
elif line.startswith("FEEDBACK:"):
feedback = line.split(":", 1)[1].strip()

return PhaseVerdict(action=action, confidence=confidence, feedback=feedback)
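
For reference, the three-line verdict contract the judge prompts for can be checked in isolation. The loop below mirrors _parse_verdict on a hypothetical model reply; defaults fall back to ACCEPT exactly as in the function above.

    # Hypothetical round-trip of the verdict format expected by _parse_verdict.
    sample = """ACTION: RETRY
    CONFIDENCE: 0.6
    FEEDBACK: criteria ask for three sources, only one was gathered"""

    action, confidence, feedback = "ACCEPT", 0.8, ""  # defaults mirror the parser
    for line in sample.strip().split("\n"):
        line = line.strip()
        if line.startswith("ACTION:"):
            value = line.split(":", 1)[1].strip().upper()
            if value in ("ACCEPT", "RETRY"):
                action = value
        elif line.startswith("CONFIDENCE:"):
            try:
                confidence = float(line.split(":", 1)[1].strip())
            except ValueError:
                pass  # keep the default on malformed numbers
        elif line.startswith("FEEDBACK:"):
            feedback = line.split(":", 1)[1].strip()

    assert (action, confidence) == ("RETRY", 0.6)
    assert feedback.startswith("criteria")
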
@@ -21,6 +21,9 @@ allowing the LLM to evaluate whether proceeding along an edge makes sense
given the current goal, context, and execution state.
"""

import json
import logging
import re
from enum import StrEnum
from typing import Any

@@ -28,6 +31,8 @@ from pydantic import BaseModel, Field, model_validator

from framework.graph.safe_eval import safe_eval

logger = logging.getLogger(__name__)

DEFAULT_MAX_TOKENS = 8192


@@ -99,7 +104,7 @@ class EdgeSpec(BaseModel):

model_config = {"extra": "allow"}

def should_traverse(
async def should_traverse(
self,
source_success: bool,
source_output: dict[str, Any],

@@ -140,7 +145,7 @@ class EdgeSpec(BaseModel):
if llm is None or goal is None:
# Fallback to ON_SUCCESS if LLM not available
return source_success
return self._llm_decide(
return await self._llm_decide(
llm=llm,
goal=goal,
source_success=source_success,

@@ -158,9 +163,6 @@ class EdgeSpec(BaseModel):
memory: dict[str, Any],
) -> bool:
"""Evaluate a conditional expression."""
import logging

logger = logging.getLogger(__name__)

if not self.condition_expr:
return True

@@ -201,7 +203,7 @@ class EdgeSpec(BaseModel):
logger.warning(f" Available context keys: {list(context.keys())}")
return False

def _llm_decide(
async def _llm_decide(
self,
llm: Any,
goal: Any,

@@ -217,8 +219,6 @@ class EdgeSpec(BaseModel):
The LLM evaluates whether proceeding to the target node
is the best next step toward achieving the goal.
"""
import json

# Build context for LLM
prompt = f"""You are evaluating whether to proceed along an edge in an agent workflow.

@@ -247,15 +247,13 @@ Respond with ONLY a JSON object:
{{"proceed": true/false, "reasoning": "brief explanation"}}"""

try:
response = llm.complete(
response = await llm.acomplete(
messages=[{"role": "user", "content": prompt}],
system="You are a routing agent. Respond with JSON only.",
max_tokens=150,
)

# Parse response
import re

json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL)
if json_match:
data = json.loads(json_match.group())

@@ -263,9 +261,6 @@ Respond with ONLY a JSON object:
reasoning = data.get("reasoning", "")

# Log the decision (using basic print for now)
import logging

logger = logging.getLogger(__name__)
logger.info(f" 🤔 LLM routing decision: {'PROCEED' if proceed else 'SKIP'}")
logger.info(f" Reason: {reasoning}")

@@ -273,9 +268,6 @@ Respond with ONLY a JSON object:

except Exception as e:
# Fallback: proceed on success
import logging

logger = logging.getLogger(__name__)
logger.warning(f" ⚠ LLM routing failed, defaulting to on_success: {e}")
return source_success

@@ -443,6 +435,25 @@ class GraphSpec(BaseModel):
description="EventLoopNode configuration (max_iterations, max_tool_calls_per_turn, etc.)",
)

# Conversation mode
conversation_mode: str = Field(
default="continuous",
description=(
"How conversations flow between event_loop nodes. "
"'continuous' (default): one conversation threads through all "
"event_loop nodes with cumulative tools and layered prompt composition. "
"'isolated': each node gets a fresh conversation."
),
)
identity_prompt: str | None = Field(
default=None,
description=(
"Agent-level identity prompt (Layer 1 of the onion model). "
"In continuous mode, this is the static identity that persists "
"unchanged across all node transitions. In isolated mode, ignored."
),
)

# Metadata
description: str = ""
created_by: str = "" # "human" or "builder_agent"
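
The async routing path above expects a single JSON object in the model reply and falls back to the source node's success flag when nothing usable comes back. A condensed sketch of that extraction, using the same regex shown in the diff (the helper name is invented):

    import json
    import re

    def parse_routing_reply(text: str, default: bool) -> tuple[bool, str]:
        """Extract {"proceed": ..., "reasoning": ...} from a model reply (sketch)."""
        match = re.search(r"\{[^{}]*\}", text, re.DOTALL)
        if not match:
            return default, "no JSON object in reply"
        try:
            data = json.loads(match.group())
        except json.JSONDecodeError:
            return default, "unparseable JSON in reply"
        return bool(data.get("proceed", default)), str(data.get("reasoning", ""))

    proceed, why = parse_routing_reply(
        'Sure: {"proceed": false, "reasoning": "goal already satisfied"}', default=True
    )
    assert proceed is False and why == "goal already satisfied"
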
File diff suppressed because it is too large
@@ -11,7 +11,6 @@ The executor:

import asyncio
import logging
import warnings
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path

@@ -21,13 +20,10 @@ from framework.graph.checkpoint_config import CheckpointConfig
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.goal import Goal
from framework.graph.node import (
FunctionNode,
LLMNode,
NodeContext,
NodeProtocol,
NodeResult,
NodeSpec,
RouterNode,
SharedMemory,
)
from framework.graph.output_cleaner import CleansingConfig, OutputCleaner

@@ -186,6 +182,55 @@ class GraphExecutor:
# Pause/resume control
self._pause_requested = asyncio.Event()

def _write_progress(
self,
current_node: str,
path: list[str],
memory: Any,
node_visit_counts: dict[str, int],
) -> None:
"""Update state.json with live progress at node transitions.

Reads the existing state.json (written by ExecutionStream at session
start) and patches the progress fields in-place. This keeps
state.json as the single source of truth — readers always see
current progress, not stale initial values.

The write is synchronous and best-effort: never blocks execution.
"""
if not self._storage_path:
return
try:
import json as _json
from datetime import datetime

state_path = self._storage_path / "state.json"
if state_path.exists():
state_data = _json.loads(state_path.read_text(encoding="utf-8"))
else:
state_data = {}

# Patch progress fields
progress = state_data.setdefault("progress", {})
progress["current_node"] = current_node
progress["path"] = list(path)
progress["node_visit_counts"] = dict(node_visit_counts)
progress["steps_executed"] = len(path)

# Update timestamp
timestamps = state_data.setdefault("timestamps", {})
timestamps["updated_at"] = datetime.now().isoformat()

# Persist full memory so state.json is sufficient for resume
# even if the process dies before the final write.
memory_snapshot = memory.read_all()
state_data["memory"] = memory_snapshot
state_data["memory_keys"] = list(memory_snapshot.keys())

state_path.write_text(_json.dumps(state_data, indent=2), encoding="utf-8")
except Exception:
pass # Best-effort — never block execution
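
For concreteness, a snapshot written by _write_progress would have roughly this shape; the field names come from the code above, while every value below is invented for illustration.

    # Illustrative state.json contents after two node transitions (all values invented).
    state_json = {
        "progress": {
            "current_node": "draft_report",
            "path": ["collect_sources", "draft_report"],
            "node_visit_counts": {"collect_sources": 1, "draft_report": 1},
            "steps_executed": 2,
        },
        "timestamps": {"updated_at": "2025-01-07T12:34:56.789012"},
        "memory": {"sources": ["https://example.org/a"], "outline": "1. Intro ..."},
        "memory_keys": ["sources", "outline"],
    }
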
def _validate_tools(self, graph: GraphSpec) -> list[str]:
|
||||
"""
|
||||
Validate that all tools declared by nodes are available.
|
||||
@@ -257,6 +302,13 @@ class GraphExecutor:
|
||||
# Initialize execution state
|
||||
memory = SharedMemory()
|
||||
|
||||
# Continuous conversation mode state
|
||||
is_continuous = getattr(graph, "conversation_mode", "isolated") == "continuous"
|
||||
continuous_conversation = None # NodeConversation threaded across nodes
|
||||
cumulative_tools: list = [] # Tools accumulate, never removed
|
||||
cumulative_tool_names: set[str] = set()
|
||||
cumulative_output_keys: list[str] = [] # Output keys from all visited nodes
|
||||
|
||||
# Initialize checkpoint store if checkpointing is enabled
|
||||
checkpoint_store: CheckpointStore | None = None
|
||||
if checkpoint_config and checkpoint_config.enabled and self._storage_path:
|
||||
@@ -273,16 +325,26 @@ class GraphExecutor:
|
||||
f"{type(memory_data).__name__}, expected dict"
|
||||
)
|
||||
else:
|
||||
# Restore memory from previous session
|
||||
# Restore memory from previous session.
|
||||
# Skip validation — this data was already validated when
|
||||
# originally written, and research text triggers false
|
||||
# positives on the code-indicator heuristic.
|
||||
for key, value in memory_data.items():
|
||||
memory.write(key, value)
|
||||
memory.write(key, value, validate=False)
|
||||
self.logger.info(f"📥 Restored session state with {len(memory_data)} memory keys")
|
||||
|
||||
# Write new input data to memory (each key individually)
|
||||
if input_data:
|
||||
# Write new input data to memory (each key individually).
|
||||
# Skip when resuming from a paused session — restored memory already
|
||||
# contains all state including the original input, and re-writing
|
||||
# input_data would overwrite intermediate results with stale values.
|
||||
_is_resuming = bool(session_state and session_state.get("paused_at"))
|
||||
if input_data and not _is_resuming:
|
||||
for key, value in input_data.items():
|
||||
memory.write(key, value)
|
||||
|
||||
# Detect event-triggered execution (timer/webhook) — no interactive user.
|
||||
_event_triggered = bool(input_data and isinstance(input_data.get("event"), dict))
|
||||
|
||||
path: list[str] = []
|
||||
total_tokens = 0
|
||||
total_latency = 0
|
||||
@@ -396,9 +458,76 @@ class GraphExecutor:
|
||||
|
||||
steps = 0
|
||||
|
||||
# Fresh shared-session execution: clear stale cursor so the entry
|
||||
# node doesn't restore a filled OutputAccumulator from the previous
|
||||
# webhook run (which would cause the judge to accept immediately).
|
||||
# The conversation history is preserved (continuous memory).
|
||||
_is_fresh_shared = bool(
|
||||
session_state
|
||||
and session_state.get("resume_session_id")
|
||||
and not session_state.get("paused_at")
|
||||
and not session_state.get("resume_from_checkpoint")
|
||||
)
|
||||
if _is_fresh_shared and is_continuous and self._storage_path:
|
||||
try:
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
entry_conv_path = self._storage_path / "conversations" / current_node_id
|
||||
if entry_conv_path.exists():
|
||||
_store = FileConversationStore(base_path=entry_conv_path)
|
||||
|
||||
# Read cursor to find next seq for the transition marker.
|
||||
_cursor = await _store.read_cursor() or {}
|
||||
_next_seq = _cursor.get("next_seq", 0)
|
||||
if _next_seq == 0:
|
||||
# Fallback: scan part files for max seq
|
||||
_parts = await _store.read_parts()
|
||||
if _parts:
|
||||
_next_seq = max(p.get("seq", 0) for p in _parts) + 1
|
||||
|
||||
# Reset cursor — clears stale accumulator outputs and
|
||||
# iteration counter so the node starts fresh work while
|
||||
# the conversation thread carries forward.
|
||||
await _store.write_cursor({})
|
||||
|
||||
# Append a transition marker so the LLM knows a new
|
||||
# event arrived and previous results are outdated.
|
||||
await _store.write_part(
|
||||
_next_seq,
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"--- NEW EVENT TRIGGER ---\n"
|
||||
"A new event has been received. "
|
||||
"Process this as a fresh request — "
|
||||
"previous outputs are no longer valid."
|
||||
),
|
||||
"seq": _next_seq,
|
||||
"is_transition_marker": True,
|
||||
},
|
||||
)
|
||||
self.logger.info(
|
||||
"🔄 Cleared stale cursor and added transition marker "
|
||||
"for shared-session entry node '%s'",
|
||||
current_node_id,
|
||||
)
|
||||
except Exception:
|
||||
self.logger.debug(
|
||||
"Could not prepare conversation store for shared-session entry node '%s'",
|
||||
current_node_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if session_state and current_node_id != graph.entry_node:
|
||||
self.logger.info(f"🔄 Resuming from: {current_node_id}")
|
||||
|
||||
# Emit resume event
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_execution_resumed(
|
||||
stream_id=self._stream_id,
|
||||
node_id=current_node_id,
|
||||
)
|
||||
|
||||
# Start run
|
||||
_run_id = self.runtime.start_run(
|
||||
goal_id=goal.id,
|
||||
@@ -435,6 +564,14 @@ class GraphExecutor:
|
||||
if self._pause_requested.is_set():
|
||||
self.logger.info("⏸ Pause detected - stopping at node boundary")
|
||||
|
||||
# Emit pause event
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_execution_paused(
|
||||
stream_id=self._stream_id,
|
||||
node_id=current_node_id,
|
||||
reason="User requested pause (Ctrl+Z)",
|
||||
)
|
||||
|
||||
# Create session state for pause
|
||||
saved_memory = memory.read_all()
|
||||
pause_session_state: dict[str, Any] = {
|
||||
@@ -489,7 +626,7 @@ class GraphExecutor:
|
||||
)
|
||||
# Skip execution — follow outgoing edges using current memory
|
||||
skip_result = NodeResult(success=True, output=memory.read_all())
|
||||
next_node = self._follow_edges(
|
||||
next_node = await self._follow_edges(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
current_node_id=current_node_id,
|
||||
@@ -530,6 +667,17 @@ class GraphExecutor:
|
||||
self.logger.info(f" Inputs: {node_spec.input_keys}")
|
||||
self.logger.info(f" Outputs: {node_spec.output_keys}")
|
||||
|
||||
# Continuous mode: accumulate tools and output keys from this node
if is_continuous and node_spec.tools:
for t in self.tools:
if t.name in node_spec.tools and t.name not in cumulative_tool_names:
cumulative_tools.append(t)
cumulative_tool_names.add(t.name)
if is_continuous and node_spec.output_keys:
for k in node_spec.output_keys:
if k not in cumulative_output_keys:
cumulative_output_keys.append(k)
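
The accumulation above is order-preserving and deduplicated by tool name, so tools granted by earlier nodes stay callable for the rest of the run. A tiny standalone equivalent, with tool objects reduced to bare names for illustration:

    # Order-preserving, name-deduplicated accumulation across visited nodes.
    def accumulate(cumulative: list[str], seen: set[str], node_tools: list[str]) -> None:
        for name in node_tools:
            if name not in seen:
                cumulative.append(name)
                seen.add(name)

    tools: list[str] = []
    seen: set[str] = set()
    accumulate(tools, seen, ["web_search", "read_file"])
    accumulate(tools, seen, ["read_file", "write_file"])
    assert tools == ["web_search", "read_file", "write_file"]
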
# Build context for node
|
||||
ctx = self._build_context(
|
||||
node_spec=node_spec,
|
||||
@@ -537,6 +685,11 @@ class GraphExecutor:
|
||||
goal=goal,
|
||||
input_data=input_data or {},
|
||||
max_tokens=graph.max_tokens,
|
||||
continuous_mode=is_continuous,
|
||||
inherited_conversation=continuous_conversation if is_continuous else None,
|
||||
override_tools=cumulative_tools if is_continuous else None,
|
||||
cumulative_output_keys=cumulative_output_keys if is_continuous else None,
|
||||
event_triggered=_event_triggered,
|
||||
)
|
||||
|
||||
# Log actual input data being read
|
||||
@@ -680,9 +833,13 @@ class GraphExecutor:
|
||||
# [CORRECTED] Use node_spec.max_retries instead of hardcoded 3
|
||||
max_retries = getattr(node_spec, "max_retries", 3)
|
||||
|
||||
# Event loop nodes handle retry internally via judge —
|
||||
# executor retry is catastrophic (retry multiplication)
|
||||
if node_spec.node_type == "event_loop" and max_retries > 0:
|
||||
# EventLoopNode instances handle retry internally via judge —
|
||||
# executor retry would cause catastrophic retry multiplication.
|
||||
# Only override for actual EventLoopNode instances, not custom
|
||||
# NodeProtocol implementations that happen to use node_type="event_loop"
|
||||
from framework.graph.event_loop_node import EventLoopNode
|
||||
|
||||
if isinstance(node_impl, EventLoopNode) and max_retries > 0:
|
||||
self.logger.warning(
|
||||
f"EventLoopNode '{node_spec.id}' has max_retries={max_retries}. "
|
||||
"Overriding to 0 — event loop nodes handle retry internally via judge."
|
||||
@@ -704,6 +861,17 @@ class GraphExecutor:
|
||||
self.logger.info(
|
||||
f" ↻ Retrying ({node_retry_counts[current_node_id]}/{max_retries})..."
|
||||
)
|
||||
|
||||
# Emit retry event
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_node_retry(
|
||||
stream_id=self._stream_id,
|
||||
node_id=current_node_id,
|
||||
retry_count=retry_count,
|
||||
max_retries=max_retries,
|
||||
error=result.error or "",
|
||||
)
|
||||
|
||||
_is_retry = True
|
||||
continue
|
||||
else:
|
||||
@@ -713,7 +881,7 @@ class GraphExecutor:
|
||||
)
|
||||
|
||||
# Check if there's an ON_FAILURE edge to follow
|
||||
next_node = self._follow_edges(
|
||||
next_node = await self._follow_edges(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
current_node_id=current_node_id,
|
||||
@@ -763,6 +931,7 @@ class GraphExecutor:
|
||||
"memory": saved_memory,
|
||||
"execution_path": list(path),
|
||||
"node_visit_counts": dict(node_visit_counts),
|
||||
"resume_from": current_node_id,
|
||||
}
|
||||
|
||||
return ExecutionResult(
|
||||
@@ -789,11 +958,22 @@ class GraphExecutor:
|
||||
# This must happen BEFORE determining next node, since pause nodes may have no edges
|
||||
if node_spec.id in graph.pause_nodes:
|
||||
self.logger.info("💾 Saving session state after pause node")
|
||||
|
||||
# Emit pause event
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_execution_paused(
|
||||
stream_id=self._stream_id,
|
||||
node_id=node_spec.id,
|
||||
reason="HITL pause node",
|
||||
)
|
||||
|
||||
saved_memory = memory.read_all()
|
||||
session_state_out = {
|
||||
"paused_at": node_spec.id,
|
||||
"resume_from": f"{node_spec.id}_resume", # Resume key
|
||||
"memory": saved_memory,
|
||||
"execution_path": list(path),
|
||||
"node_visit_counts": dict(node_visit_counts),
|
||||
"next_node": None, # Will resume from entry point
|
||||
}
|
||||
|
||||
@@ -842,10 +1022,21 @@ class GraphExecutor:
|
||||
if result.next_node:
|
||||
# Router explicitly set next node
|
||||
self.logger.info(f" → Router directing to: {result.next_node}")
|
||||
|
||||
# Emit edge traversed event for router-directed edge
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_edge_traversed(
|
||||
stream_id=self._stream_id,
|
||||
source_node=current_node_id,
|
||||
target_node=result.next_node,
|
||||
edge_condition="router",
|
||||
)
|
||||
|
||||
current_node_id = result.next_node
|
||||
self._write_progress(current_node_id, path, memory, node_visit_counts)
|
||||
else:
|
||||
# Get all traversable edges for fan-out detection
|
||||
traversable_edges = self._get_all_traversable_edges(
|
||||
traversable_edges = await self._get_all_traversable_edges(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
current_node_id=current_node_id,
|
||||
@@ -864,6 +1055,18 @@ class GraphExecutor:
|
||||
targets = [e.target for e in traversable_edges]
|
||||
fan_in_node = self._find_convergence_node(graph, targets)
|
||||
|
||||
# Emit edge traversed events for fan-out branches
|
||||
if self._event_bus:
|
||||
for edge in traversable_edges:
|
||||
await self._event_bus.emit_edge_traversed(
|
||||
stream_id=self._stream_id,
|
||||
source_node=current_node_id,
|
||||
target_node=edge.target,
|
||||
edge_condition=edge.condition.value
|
||||
if hasattr(edge.condition, "value")
|
||||
else str(edge.condition),
|
||||
)
|
||||
|
||||
# Execute branches in parallel
|
||||
(
|
||||
_branch_results,
|
||||
@@ -886,13 +1089,14 @@ class GraphExecutor:
|
||||
if fan_in_node:
|
||||
self.logger.info(f" ⑃ Fan-in: converging at {fan_in_node}")
|
||||
current_node_id = fan_in_node
|
||||
self._write_progress(current_node_id, path, memory, node_visit_counts)
|
||||
else:
|
||||
# No convergence point - branches are terminal
|
||||
self.logger.info(" → Parallel branches completed (no convergence)")
|
||||
break
|
||||
else:
|
||||
# Sequential: follow single edge (existing logic via _follow_edges)
|
||||
next_node = self._follow_edges(
|
||||
next_node = await self._follow_edges(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
current_node_id=current_node_id,
|
||||
@@ -906,6 +1110,14 @@ class GraphExecutor:
|
||||
next_spec = graph.get_node(next_node)
|
||||
self.logger.info(f" → Next: {next_spec.name if next_spec else next_node}")
|
||||
|
||||
# Emit edge traversed event for sequential edge
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_edge_traversed(
|
||||
stream_id=self._stream_id,
|
||||
source_node=current_node_id,
|
||||
target_node=next_node,
|
||||
)
|
||||
|
||||
# CHECKPOINT: node_complete (after determining next node)
|
||||
if (
|
||||
checkpoint_store
|
||||
@@ -940,6 +1152,84 @@ class GraphExecutor:
|
||||
|
||||
current_node_id = next_node
|
||||
|
||||
# Write progress snapshot at node transition
self._write_progress(current_node_id, path, memory, node_visit_counts)

# Continuous mode: thread conversation forward with transition marker
if is_continuous and result.conversation is not None:
continuous_conversation = result.conversation

# Look up the next node spec for the transition marker
next_spec = graph.get_node(current_node_id)
if next_spec and next_spec.node_type == "event_loop":
from framework.graph.prompt_composer import (
build_narrative,
build_transition_marker,
compose_system_prompt,
)

# Build Layer 2 (narrative) from current state
narrative = build_narrative(memory, path, graph)

# Compose new system prompt (Layer 1 + 2 + 3)
new_system = compose_system_prompt(
identity_prompt=getattr(graph, "identity_prompt", None),
focus_prompt=next_spec.system_prompt,
narrative=narrative,
)
continuous_conversation.update_system_prompt(new_system)
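
The layered ("onion") composition assumed here stacks a static agent identity, a narrative of progress so far, and the next node's focus prompt. compose_system_prompt itself is not part of this diff, so the sketch below shows only one plausible layering; the real helper may format the layers differently.

    # Hypothetical layering; the real compose_system_prompt may differ in format.
    def compose_system_prompt(identity_prompt: str | None,
                              focus_prompt: str | None,
                              narrative: str) -> str:
        layers = [
            identity_prompt or "",                    # Layer 1: static identity
            f"PROGRESS SO FAR:\n{narrative}",         # Layer 2: rebuilt at each transition
            f"CURRENT FOCUS:\n{focus_prompt or ''}",  # Layer 3: next node's prompt
        ]
        return "\n\n".join(part for part in layers if part.strip())
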
# Switch conversation store to the next node's directory
# so the transition marker and all subsequent messages are
# persisted there instead of the first node's directory.
if self._storage_path:
from framework.storage.conversation_store import (
FileConversationStore,
)

next_store_path = self._storage_path / "conversations" / next_spec.id
next_store = FileConversationStore(base_path=next_store_path)
await continuous_conversation.switch_store(next_store)

# Insert transition marker into conversation
data_dir = str(self._storage_path / "data") if self._storage_path else None
marker = build_transition_marker(
previous_node=node_spec,
next_node=next_spec,
memory=memory,
cumulative_tool_names=sorted(cumulative_tool_names),
data_dir=data_dir,
)
await continuous_conversation.add_user_message(
marker,
is_transition_marker=True,
)
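
build_transition_marker is also defined outside this diff. Judging only from the arguments passed here, the marker the next node sees as a user message is along these lines; the wording below is illustrative, not the actual template.

    # Illustrative marker text only; the real template lives in prompt_composer.
    marker = (
        "--- PHASE TRANSITION ---\n"
        "Finished: collect_sources\n"
        "Now starting: draft_report\n"
        "Tools available so far: read_file, web_search, write_file\n"
        "Working data directory: ./storage/data\n"
        "Carry earlier findings forward; do not redo completed work."
    )
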
# Set current phase for phase-aware compaction
continuous_conversation.set_current_phase(next_spec.id)

# Opportunistic compaction at transition:
# 1. Prune old tool results (free, no LLM call)
# 2. If still over 80%, do a phase-graduated compact
if continuous_conversation.usage_ratio() > 0.5:
await continuous_conversation.prune_old_tool_results(
protect_tokens=2000,
)
if continuous_conversation.needs_compaction():
self.logger.info(
" Phase-boundary compaction (%.0f%% usage)",
continuous_conversation.usage_ratio() * 100,
)
summary = (
f"Summary of earlier phases (before {next_spec.name}). "
"See transition markers for phase details."
)
await continuous_conversation.compact(
summary,
keep_recent=4,
phase_graduated=True,
)
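
Taken together, the transition-time budget management is: prune cheap context first, then fall back to a phase-graduated compact only if the conversation still reports pressure. A condensed sketch of that policy, using the NodeConversation methods referenced above (the wrapper function itself is hypothetical):

    # Sketch of the opportunistic compaction policy at a phase boundary.
    async def compact_at_transition(conv, next_phase_name: str) -> None:
        if conv.usage_ratio() <= 0.5:
            return  # plenty of headroom, nothing to do
        await conv.prune_old_tool_results(protect_tokens=2000)  # free, no LLM call
        if conv.needs_compaction():  # still over the compaction threshold
            summary = (
                f"Summary of earlier phases (before {next_phase_name}). "
                "See transition markers for phase details."
            )
            await conv.compact(summary, keep_recent=4, phase_graduated=True)
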
# Update input_data for next node
|
||||
input_data = result.output
|
||||
|
||||
@@ -993,6 +1283,11 @@ class GraphExecutor:
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality=exec_quality,
|
||||
node_visit_counts=dict(node_visit_counts),
|
||||
session_state={
|
||||
"memory": output, # output IS memory.read_all()
|
||||
"execution_path": list(path),
|
||||
"node_visit_counts": dict(node_visit_counts),
|
||||
},
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@@ -1082,6 +1377,7 @@ class GraphExecutor:
|
||||
"memory": saved_memory,
|
||||
"execution_path": list(path),
|
||||
"node_visit_counts": dict(node_visit_counts),
|
||||
"resume_from": current_node_id,
|
||||
}
|
||||
|
||||
# Mark latest checkpoint for resume on failure
|
||||
@@ -1134,12 +1430,21 @@ class GraphExecutor:
|
||||
goal: Goal,
|
||||
input_data: dict[str, Any],
|
||||
max_tokens: int = 4096,
|
||||
continuous_mode: bool = False,
|
||||
inherited_conversation: Any = None,
|
||||
override_tools: list | None = None,
|
||||
cumulative_output_keys: list[str] | None = None,
|
||||
event_triggered: bool = False,
|
||||
) -> NodeContext:
|
||||
"""Build execution context for a node."""
|
||||
# Filter tools to those available to this node
|
||||
available_tools = []
|
||||
if node_spec.tools:
|
||||
available_tools = [t for t in self.tools if t.name in node_spec.tools]
|
||||
if override_tools is not None:
|
||||
# Continuous mode: use cumulative tool set
|
||||
available_tools = list(override_tools)
|
||||
else:
|
||||
available_tools = []
|
||||
if node_spec.tools:
|
||||
available_tools = [t for t in self.tools if t.name in node_spec.tools]
|
||||
|
||||
# Create scoped memory view
|
||||
scoped_memory = memory.with_permissions(
|
||||
@@ -1160,18 +1465,23 @@ class GraphExecutor:
|
||||
max_tokens=max_tokens,
|
||||
runtime_logger=self.runtime_logger,
|
||||
pause_event=self._pause_requested, # Pass pause event for granular control
|
||||
continuous_mode=continuous_mode,
|
||||
inherited_conversation=inherited_conversation,
|
||||
cumulative_output_keys=cumulative_output_keys or [],
|
||||
event_triggered=event_triggered,
|
||||
)
|
||||
|
||||
# Valid node types - no ambiguous "llm" type allowed
|
||||
VALID_NODE_TYPES = {
|
||||
"llm_tool_use",
|
||||
"llm_generate",
|
||||
"router",
|
||||
"function",
|
||||
"human_input",
|
||||
"event_loop",
|
||||
}
|
||||
DEPRECATED_NODE_TYPES = {"llm_tool_use": "event_loop", "llm_generate": "event_loop"}
|
||||
# Node types removed in v0.5 — provide migration guidance
|
||||
REMOVED_NODE_TYPES = {
|
||||
"function": "event_loop",
|
||||
"llm_tool_use": "event_loop",
|
||||
"llm_generate": "event_loop",
|
||||
"router": "event_loop", # Unused theoretical infrastructure
|
||||
"human_input": "event_loop", # Use client_facing=True instead
|
||||
}
|
||||
|
||||
def _get_node_implementation(
|
||||
self, node_spec: NodeSpec, cleanup_llm_model: str | None = None
|
||||
@@ -1181,62 +1491,23 @@ class GraphExecutor:
|
||||
if node_spec.id in self.node_registry:
|
||||
return self.node_registry[node_spec.id]
|
||||
|
||||
# Reject removed node types with migration guidance
|
||||
if node_spec.node_type in self.REMOVED_NODE_TYPES:
|
||||
replacement = self.REMOVED_NODE_TYPES[node_spec.node_type]
|
||||
raise RuntimeError(
|
||||
f"Node type '{node_spec.node_type}' was removed in v0.5. "
|
||||
f"Migrate node '{node_spec.id}' to '{replacement}'. "
|
||||
f"See https://github.com/adenhq/hive/issues/4753 for migration guide."
|
||||
)
|
||||
|
||||
# Validate node type
|
||||
if node_spec.node_type not in self.VALID_NODE_TYPES:
|
||||
raise RuntimeError(
|
||||
f"Invalid node type '{node_spec.node_type}' for node '{node_spec.id}'. "
|
||||
f"Must be one of: {sorted(self.VALID_NODE_TYPES)}. "
|
||||
f"Use 'llm_tool_use' for nodes that call tools, 'llm_generate' for text generation."
|
||||
)
|
||||
|
||||
# Warn on deprecated node types
|
||||
if node_spec.node_type in self.DEPRECATED_NODE_TYPES:
|
||||
replacement = self.DEPRECATED_NODE_TYPES[node_spec.node_type]
|
||||
warnings.warn(
|
||||
f"Node type '{node_spec.node_type}' is deprecated. "
|
||||
f"Use '{replacement}' instead. "
|
||||
f"Node: '{node_spec.id}'",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Create based on type
|
||||
if node_spec.node_type == "llm_tool_use":
|
||||
if not node_spec.tools:
|
||||
raise RuntimeError(
|
||||
f"Node '{node_spec.id}' is type 'llm_tool_use' but declares no tools. "
|
||||
"Either add tools to the node or change type to 'llm_generate'."
|
||||
)
|
||||
return LLMNode(
|
||||
tool_executor=self.tool_executor,
|
||||
require_tools=True,
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
)
|
||||
|
||||
if node_spec.node_type == "llm_generate":
|
||||
return LLMNode(
|
||||
tool_executor=None,
|
||||
require_tools=False,
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
)
|
||||
|
||||
if node_spec.node_type == "router":
|
||||
return RouterNode()
|
||||
|
||||
if node_spec.node_type == "function":
|
||||
# Function nodes need explicit registration
|
||||
raise RuntimeError(
|
||||
f"Function node '{node_spec.id}' not registered. Register with node_registry."
|
||||
)
|
||||
|
||||
if node_spec.node_type == "human_input":
|
||||
# Human input nodes are handled specially by HITL mechanism
|
||||
return LLMNode(
|
||||
tool_executor=None,
|
||||
require_tools=False,
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
f"Must be one of: {sorted(self.VALID_NODE_TYPES)}."
|
||||
)
|
||||
|
||||
# Create based on type (only event_loop is valid)
|
||||
if node_spec.node_type == "event_loop":
|
||||
# Auto-create EventLoopNode with sensible defaults.
|
||||
# Custom configs can still be pre-registered via node_registry.
|
||||
@@ -1284,7 +1555,7 @@ class GraphExecutor:
|
||||
# Should never reach here due to validation above
|
||||
raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")
|
||||
|
||||
def _follow_edges(
|
||||
async def _follow_edges(
|
||||
self,
|
||||
graph: GraphSpec,
|
||||
goal: Goal,
|
||||
@@ -1299,7 +1570,7 @@ class GraphExecutor:
|
||||
for edge in edges:
|
||||
target_node_spec = graph.get_node(edge.target)
|
||||
|
||||
if edge.should_traverse(
|
||||
if await edge.should_traverse(
|
||||
source_success=result.success,
|
||||
source_output=result.output,
|
||||
memory=memory.read_all(),
|
||||
@@ -1325,7 +1596,7 @@ class GraphExecutor:
|
||||
self.logger.warning(f"⚠ Output validation failed: {validation.errors}")
|
||||
|
||||
# Clean the output
|
||||
cleaned_output = self.output_cleaner.clean_output(
|
||||
cleaned_output = await self.output_cleaner.clean_output(
|
||||
output=output_to_validate,
|
||||
source_node_id=current_node_id,
|
||||
target_node_spec=target_node_spec,
|
||||
@@ -1363,7 +1634,7 @@ class GraphExecutor:
|
||||
|
||||
return None
|
||||
|
||||
def _get_all_traversable_edges(
|
||||
async def _get_all_traversable_edges(
|
||||
self,
|
||||
graph: GraphSpec,
|
||||
goal: Goal,
|
||||
@@ -1383,7 +1654,7 @@ class GraphExecutor:
|
||||
|
||||
for edge in edges:
|
||||
target_node_spec = graph.get_node(edge.target)
|
||||
if edge.should_traverse(
|
||||
if await edge.should_traverse(
|
||||
source_success=result.success,
|
||||
source_output=result.output,
|
||||
memory=memory.read_all(),
|
||||
@@ -1496,14 +1767,19 @@ class GraphExecutor:
|
||||
branch.error = f"Node {branch.node_id} not found in graph"
|
||||
return branch, RuntimeError(branch.error)
|
||||
|
||||
# Get node implementation to check its type
|
||||
branch_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
|
||||
|
||||
effective_max_retries = node_spec.max_retries
|
||||
if node_spec.node_type == "event_loop":
|
||||
if effective_max_retries > 1:
|
||||
self.logger.warning(
|
||||
f"EventLoopNode '{node_spec.id}' has "
|
||||
f"max_retries={effective_max_retries}. Overriding "
|
||||
"to 1 — event loop nodes handle retry internally."
|
||||
)
|
||||
# Only override for actual EventLoopNode instances, not custom NodeProtocol impls
|
||||
from framework.graph.event_loop_node import EventLoopNode
|
||||
|
||||
if isinstance(branch_impl, EventLoopNode) and effective_max_retries > 1:
|
||||
self.logger.warning(
|
||||
f"EventLoopNode '{node_spec.id}' has "
|
||||
f"max_retries={effective_max_retries}. Overriding "
|
||||
"to 1 — event loop nodes handle retry internally."
|
||||
)
|
||||
effective_max_retries = 1
|
||||
|
||||
branch.status = "running"
|
||||
@@ -1525,7 +1801,7 @@ class GraphExecutor:
|
||||
f"⚠ Output validation failed for branch "
|
||||
f"{branch.node_id}: {validation.errors}"
|
||||
)
|
||||
cleaned_output = self.output_cleaner.clean_output(
|
||||
cleaned_output = await self.output_cleaner.clean_output(
|
||||
output=mem_snapshot,
|
||||
source_node_id=source_node_spec.id if source_node_spec else "unknown",
|
||||
target_node_spec=node_spec,
|
||||
@@ -1669,10 +1945,6 @@ class GraphExecutor:
|
||||
"""Register a custom node implementation."""
|
||||
self.node_registry[node_id] = implementation
|
||||
|
||||
def register_function(self, node_id: str, func: Callable) -> None:
|
||||
"""Register a function as a node."""
|
||||
self.node_registry[node_id] = FunctionNode(func)
|
||||
|
||||
def request_pause(self) -> None:
|
||||
"""
|
||||
Request graceful pause of the current execution.
|
||||
|
||||
@@ -1,552 +0,0 @@
|
||||
"""
|
||||
Flexible Graph Executor with Worker-Judge Loop.
|
||||
|
||||
Executes plans created by external planner (Claude Code, etc.)
|
||||
using a Worker-Judge loop:
|
||||
|
||||
1. External planner creates Plan
|
||||
2. FlexibleGraphExecutor receives Plan
|
||||
3. Worker executes each step
|
||||
4. Judge evaluates each result
|
||||
5. If Judge says "replan" → return to external planner with feedback
|
||||
6. If Judge says "escalate" → request human intervention
|
||||
7. If all steps complete → return success
|
||||
|
||||
This keeps planning external while execution/evaluation is internal.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.code_sandbox import CodeSandbox
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.judge import HybridJudge, create_default_judge
|
||||
from framework.graph.plan import (
|
||||
ApprovalDecision,
|
||||
ApprovalRequest,
|
||||
ApprovalResult,
|
||||
ExecutionStatus,
|
||||
Judgment,
|
||||
JudgmentAction,
|
||||
Plan,
|
||||
PlanExecutionResult,
|
||||
PlanStep,
|
||||
StepStatus,
|
||||
)
|
||||
from framework.graph.worker_node import StepExecutionResult, WorkerNode
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
# Type alias for approval callback
|
||||
ApprovalCallback = Callable[[ApprovalRequest], ApprovalResult]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Configuration for FlexibleGraphExecutor."""
|
||||
|
||||
max_retries_per_step: int = 3
|
||||
max_total_steps: int = 100
|
||||
timeout_seconds: int = 300
|
||||
enable_parallel_execution: bool = False # Future: parallel step execution
|
||||
|
||||
|
||||
class FlexibleGraphExecutor:
|
||||
"""
|
||||
Executes plans with Worker-Judge loop.
|
||||
|
||||
Plans come from external source (Claude Code, etc.).
|
||||
Returns feedback for replanning if needed.
|
||||
|
||||
Usage:
|
||||
executor = FlexibleGraphExecutor(
|
||||
runtime=runtime,
|
||||
llm=llm_provider,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
result = await executor.execute_plan(plan, goal, context)
|
||||
|
||||
if result.status == ExecutionStatus.NEEDS_REPLAN:
|
||||
# External planner should create new plan using result.feedback
|
||||
new_plan = external_planner.replan(result.feedback_context)
|
||||
result = await executor.execute_plan(new_plan, goal, result.feedback_context)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runtime: Runtime,
|
||||
llm: LLMProvider | None = None,
|
||||
tools: dict[str, Tool] | None = None,
|
||||
tool_executor: Callable | None = None,
|
||||
functions: dict[str, Callable] | None = None,
|
||||
judge: HybridJudge | None = None,
|
||||
config: ExecutorConfig | None = None,
|
||||
approval_callback: ApprovalCallback | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the FlexibleGraphExecutor.
|
||||
|
||||
Args:
|
||||
runtime: Runtime for decision logging
|
||||
llm: LLM provider for Worker and Judge
|
||||
tools: Available tools
|
||||
tool_executor: Function to execute tools
|
||||
functions: Registered functions
|
||||
judge: Custom judge (defaults to HybridJudge with default rules)
|
||||
config: Executor configuration
|
||||
approval_callback: Callback for human-in-the-loop approval.
|
||||
If None, steps requiring approval will pause execution.
|
||||
"""
|
||||
self.runtime = runtime
|
||||
self.llm = llm
|
||||
self.tools = tools or {}
|
||||
self.tool_executor = tool_executor
|
||||
self.functions = functions or {}
|
||||
self.config = config or ExecutorConfig()
|
||||
self.approval_callback = approval_callback
|
||||
|
||||
# Create judge
|
||||
self.judge = judge or create_default_judge(llm)
|
||||
|
||||
# Create worker
|
||||
self.worker = WorkerNode(
|
||||
runtime=runtime,
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
functions=functions,
|
||||
sandbox=CodeSandbox(),
|
||||
)
|
||||
|
||||
async def execute_plan(
|
||||
self,
|
||||
plan: Plan,
|
||||
goal: Goal,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> PlanExecutionResult:
|
||||
"""
|
||||
Execute a plan created by external planner.
|
||||
|
||||
Args:
|
||||
plan: The plan to execute
|
||||
goal: The goal context
|
||||
context: Initial context (e.g., from previous execution)
|
||||
|
||||
Returns:
|
||||
PlanExecutionResult with status and feedback
|
||||
"""
|
||||
context = context or {}
|
||||
context.update(plan.context) # Merge plan's accumulated context
|
||||
|
||||
# Start run
|
||||
_run_id = self.runtime.start_run(
|
||||
goal_id=goal.id,
|
||||
goal_description=goal.description,
|
||||
input_data={"plan_id": plan.id, "revision": plan.revision},
|
||||
)
|
||||
|
||||
steps_executed = 0
|
||||
total_tokens = 0
|
||||
total_latency = 0
|
||||
|
||||
try:
|
||||
while steps_executed < self.config.max_total_steps:
|
||||
# Get next ready steps
|
||||
ready_steps = plan.get_ready_steps()
|
||||
|
||||
if not ready_steps:
|
||||
# Check if we're done or stuck
|
||||
if plan.is_complete():
|
||||
break
|
||||
else:
|
||||
# No ready steps but not complete - something's wrong
|
||||
return self._create_result(
|
||||
status=ExecutionStatus.NEEDS_REPLAN,
|
||||
plan=plan,
|
||||
context=context,
|
||||
feedback=(
|
||||
"No executable steps available but plan not complete. "
|
||||
"Check dependencies."
|
||||
),
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
)
|
||||
|
||||
# Execute next step (for now, sequential; could be parallel)
|
||||
step = ready_steps[0]
|
||||
# Debug: show ready steps
|
||||
# ready_ids = [s.id for s in ready_steps]
|
||||
# print(f" [DEBUG] Ready steps: {ready_ids}, executing: {step.id}")
|
||||
|
||||
# APPROVAL CHECK - before execution
|
||||
if step.requires_approval:
|
||||
approval_result = await self._request_approval(step, context)
|
||||
|
||||
if approval_result is None:
|
||||
# No callback, pause execution
|
||||
step.status = StepStatus.AWAITING_APPROVAL
|
||||
return self._create_result(
|
||||
status=ExecutionStatus.AWAITING_APPROVAL,
|
||||
plan=plan,
|
||||
context=context,
|
||||
feedback=f"Step '{step.id}' requires approval: {step.description}",
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
)
|
||||
|
||||
if approval_result.decision == ApprovalDecision.REJECT:
|
||||
step.status = StepStatus.REJECTED
|
||||
step.error = approval_result.reason or "Rejected by human"
|
||||
# Skip this step and continue with dependents marked as skipped
|
||||
self._skip_dependent_steps(plan, step.id)
|
||||
continue
|
||||
|
||||
if approval_result.decision == ApprovalDecision.ABORT:
|
||||
return self._create_result(
|
||||
status=ExecutionStatus.ABORTED,
|
||||
plan=plan,
|
||||
context=context,
|
||||
feedback=approval_result.reason or "Aborted by human",
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
)
|
||||
|
||||
if approval_result.decision == ApprovalDecision.MODIFY:
|
||||
# Apply modifications to step
|
||||
if approval_result.modifications:
|
||||
self._apply_modifications(step, approval_result.modifications)
|
||||
|
||||
# APPROVE - continue to execution
|
||||
|
||||
step.status = StepStatus.IN_PROGRESS
|
||||
step.started_at = datetime.now()
|
||||
step.attempts += 1
|
||||
|
||||
# WORK
|
||||
work_result = await self.worker.execute(step, context)
|
||||
steps_executed += 1
|
||||
total_tokens += work_result.tokens_used
|
||||
total_latency += work_result.latency_ms
|
||||
|
||||
# JUDGE
|
||||
judgment = await self.judge.evaluate(
|
||||
step=step,
|
||||
result=work_result.__dict__,
|
||||
goal=goal,
|
||||
context=context,
|
||||
)
|
||||
|
||||
# Handle judgment
|
||||
result = await self._handle_judgment(
|
||||
step=step,
|
||||
work_result=work_result,
|
||||
judgment=judgment,
|
||||
plan=plan,
|
||||
goal=goal,
|
||||
context=context,
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
# Judgment resulted in early return (replan/escalate)
|
||||
self.runtime.end_run(
|
||||
success=False,
|
||||
narrative=f"Execution stopped: {result.status.value}",
|
||||
)
|
||||
return result
|
||||
|
||||
# All steps completed successfully
|
||||
self.runtime.end_run(
|
||||
success=True,
|
||||
output_data=context,
|
||||
narrative=f"Plan completed: {steps_executed} steps executed",
|
||||
)
|
||||
|
||||
return self._create_result(
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
plan=plan,
|
||||
context=context,
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.runtime.report_problem(
|
||||
severity="critical",
|
||||
description=str(e),
|
||||
)
|
||||
self.runtime.end_run(
|
||||
success=False,
|
||||
narrative=f"Execution failed: {e}",
|
||||
)
|
||||
|
||||
return PlanExecutionResult(
|
||||
status=ExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
feedback=f"Execution error: {e}",
|
||||
feedback_context=plan.to_feedback_context(),
|
||||
completed_steps=[s.id for s in plan.get_completed_steps()],
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency_ms=total_latency,
|
||||
)
|
||||
|
||||
async def _handle_judgment(
|
||||
self,
|
||||
step: PlanStep,
|
||||
work_result: StepExecutionResult,
|
||||
judgment: Judgment,
|
||||
plan: Plan,
|
||||
goal: Goal,
|
||||
context: dict[str, Any],
|
||||
steps_executed: int,
|
||||
total_tokens: int,
|
||||
total_latency: int,
|
||||
) -> PlanExecutionResult | None:
|
||||
"""
|
||||
Handle judgment and return result if execution should stop.
|
||||
|
||||
Returns None to continue execution, or PlanExecutionResult to stop.
|
||||
"""
|
||||
if judgment.action == JudgmentAction.ACCEPT:
|
||||
# Step succeeded - update state and continue
|
||||
step.status = StepStatus.COMPLETED
|
||||
step.completed_at = datetime.now()
|
||||
step.result = work_result.outputs
|
||||
|
||||
# Map outputs to expected output keys
|
||||
# If output has generic "result" key but step expects specific keys, map it
|
||||
outputs_to_store = work_result.outputs.copy()
|
||||
if step.expected_outputs and "result" in outputs_to_store:
|
||||
result_value = outputs_to_store["result"]
|
||||
# For each expected output key that's not in outputs, map from "result"
|
||||
for expected_key in step.expected_outputs:
|
||||
if expected_key not in outputs_to_store:
|
||||
outputs_to_store[expected_key] = result_value
|
||||
|
||||
# Update context with mapped outputs
|
||||
context.update(outputs_to_store)
|
||||
|
||||
# Store in plan context for replanning feedback
|
||||
plan.context[step.id] = outputs_to_store
|
||||
|
||||
return None # Continue execution
|
||||
|
||||
elif judgment.action == JudgmentAction.RETRY:
|
||||
# Retry step if under limit
|
||||
if step.attempts < step.max_retries:
|
||||
step.status = StepStatus.PENDING
|
||||
step.error = judgment.feedback
|
||||
|
||||
# Record retry decision
|
||||
self.runtime.decide(
|
||||
intent=f"Retry step {step.id}",
|
||||
options=[{"id": "retry", "description": "Retry with feedback"}],
|
||||
chosen="retry",
|
||||
reasoning=judgment.reasoning,
|
||||
context={"attempt": step.attempts, "feedback": judgment.feedback},
|
||||
)
|
||||
|
||||
return None # Continue (step will be retried)
|
||||
else:
|
||||
# Max retries exceeded - escalate to replan
|
||||
step.status = StepStatus.FAILED
|
||||
step.error = f"Max retries ({step.max_retries}) exceeded: {judgment.feedback}"
|
||||
|
||||
return self._create_result(
|
||||
status=ExecutionStatus.NEEDS_REPLAN,
|
||||
plan=plan,
|
||||
context=context,
|
||||
feedback=(
|
||||
f"Step '{step.id}' failed after {step.attempts} attempts: "
|
||||
f"{judgment.feedback}"
|
||||
),
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
)
|
||||
|
||||
elif judgment.action == JudgmentAction.REPLAN:
|
||||
# Return to external planner
|
||||
step.status = StepStatus.FAILED
|
||||
step.error = judgment.feedback
|
||||
|
||||
return self._create_result(
|
||||
status=ExecutionStatus.NEEDS_REPLAN,
|
||||
plan=plan,
|
||||
context=context,
|
||||
feedback=judgment.feedback or f"Step '{step.id}' requires replanning",
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
)
|
||||
|
||||
elif judgment.action == JudgmentAction.ESCALATE:
|
||||
# Request human intervention
|
||||
return self._create_result(
|
||||
status=ExecutionStatus.NEEDS_ESCALATION,
|
||||
plan=plan,
|
||||
context=context,
|
||||
feedback=judgment.feedback or f"Step '{step.id}' requires human intervention",
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
)
|
||||
|
||||
return None # Unknown action - continue
|
||||
|
||||
def _create_result(
|
||||
self,
|
||||
status: ExecutionStatus,
|
||||
plan: Plan,
|
||||
context: dict[str, Any],
|
||||
feedback: str | None = None,
|
||||
steps_executed: int = 0,
|
||||
total_tokens: int = 0,
|
||||
total_latency: int = 0,
|
||||
) -> PlanExecutionResult:
|
||||
"""Create a PlanExecutionResult."""
|
||||
return PlanExecutionResult(
|
||||
status=status,
|
||||
results=context,
|
||||
feedback=feedback,
|
||||
feedback_context=plan.to_feedback_context(),
|
||||
completed_steps=[s.id for s in plan.get_completed_steps()],
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency_ms=total_latency,
|
||||
)
|
||||
|
||||
def register_function(self, name: str, func: Callable) -> None:
|
||||
"""Register a function for FUNCTION actions."""
|
||||
self.functions[name] = func
|
||||
self.worker.register_function(name, func)
|
||||
|
||||
def register_tool(self, tool: Tool) -> None:
|
||||
"""Register a tool for TOOL_USE actions."""
|
||||
self.tools[tool.name] = tool
|
||||
self.worker.register_tool(tool)
|
||||
|
||||
def add_evaluation_rule(self, rule) -> None:
|
||||
"""Add an evaluation rule to the judge."""
|
||||
self.judge.add_rule(rule)
|
||||
|
||||
async def _request_approval(
|
||||
self,
|
||||
step: PlanStep,
|
||||
context: dict[str, Any],
|
||||
) -> ApprovalResult | None:
|
||||
"""
|
||||
Request human approval for a step.
|
||||
|
||||
Returns None if no callback is set (execution should pause).
|
||||
"""
|
||||
if self.approval_callback is None:
|
||||
return None
|
||||
|
||||
# Build preview of what will happen
|
||||
preview_parts = []
|
||||
if step.action.tool_name:
|
||||
preview_parts.append(f"Tool: {step.action.tool_name}")
|
||||
if step.action.tool_args:
|
||||
import json
|
||||
|
||||
args_preview = json.dumps(step.action.tool_args, indent=2, default=str)
|
||||
if len(args_preview) > 500:
|
||||
args_preview = args_preview[:500] + "..."
|
||||
preview_parts.append(f"Args: {args_preview}")
|
||||
elif step.action.prompt:
|
||||
prompt_preview = (
|
||||
step.action.prompt[:300] + "..."
|
||||
if len(step.action.prompt) > 300
|
||||
else step.action.prompt
|
||||
)
|
||||
preview_parts.append(f"Prompt: {prompt_preview}")
|
||||
|
||||
# Include step inputs resolved from context (what will be sent/used)
|
||||
relevant_context = {}
|
||||
for input_key, input_value in step.inputs.items():
|
||||
# Resolve variable references like "$email_sequence"
|
||||
if isinstance(input_value, str) and input_value.startswith("$"):
|
||||
context_key = input_value[1:] # Remove $ prefix
|
||||
if context_key in context:
|
||||
relevant_context[input_key] = context[context_key]
|
||||
else:
|
||||
relevant_context[input_key] = input_value
|
||||
|
||||
request = ApprovalRequest(
|
||||
step_id=step.id,
|
||||
step_description=step.description,
|
||||
action_type=step.action.action_type.value,
|
||||
action_details={
|
||||
"tool_name": step.action.tool_name,
|
||||
"tool_args": step.action.tool_args,
|
||||
"prompt": step.action.prompt,
|
||||
},
|
||||
context=relevant_context,
|
||||
approval_message=step.approval_message,
|
||||
preview="\n".join(preview_parts) if preview_parts else None,
|
||||
)
|
||||
|
||||
return self.approval_callback(request)
|
||||
|
||||
def _skip_dependent_steps(self, plan: Plan, rejected_step_id: str) -> None:
|
||||
"""Mark steps that depend on a rejected step as skipped."""
|
||||
for step in plan.steps:
|
||||
if rejected_step_id in step.dependencies:
|
||||
if step.status == StepStatus.PENDING:
|
||||
step.status = StepStatus.SKIPPED
|
||||
step.error = f"Skipped because dependency '{rejected_step_id}' was rejected"
|
||||
# Recursively skip dependents
|
||||
self._skip_dependent_steps(plan, step.id)
|
||||
|
||||
def _apply_modifications(self, step: PlanStep, modifications: dict[str, Any]) -> None:
|
||||
"""Apply human modifications to a step before execution."""
|
||||
# Allow modifying tool args
|
||||
if "tool_args" in modifications and step.action.tool_args:
|
||||
step.action.tool_args.update(modifications["tool_args"])
|
||||
|
||||
# Allow modifying prompt
|
||||
if "prompt" in modifications:
|
||||
step.action.prompt = modifications["prompt"]
|
||||
|
||||
# Allow modifying inputs
|
||||
if "inputs" in modifications:
|
||||
step.inputs.update(modifications["inputs"])
|
||||
|
||||
def set_approval_callback(self, callback: ApprovalCallback) -> None:
|
||||
"""Set the approval callback for HITL steps."""
|
||||
self.approval_callback = callback
|
||||
|
||||
|
||||
# Convenience function for simple execution
|
||||
async def execute_plan(
|
||||
plan: Plan,
|
||||
goal: Goal,
|
||||
runtime: Runtime,
|
||||
llm: LLMProvider | None = None,
|
||||
tools: dict[str, Tool] | None = None,
|
||||
tool_executor: Callable | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> PlanExecutionResult:
|
||||
"""
|
||||
Execute a plan with default configuration.
|
||||
|
||||
Convenience function for simple use cases.
|
||||
"""
|
||||
executor = FlexibleGraphExecutor(
|
||||
runtime=runtime,
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
return await executor.execute_plan(plan, goal, context)
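
A minimal usage sketch (editor's illustration, not part of the module), assuming a configured Runtime and LLMProvider are already in hand and a plan was exported to a hypothetical plan.json; load_export and ExecutionStatus come from framework.graph.plan as shown further below.

import json

from framework.graph.plan import ExecutionStatus, load_export


async def run(runtime, llm) -> None:
    with open("plan.json") as f:          # hypothetical export file
        plan, goal = load_export(json.load(f))

    result = await execute_plan(plan, goal, runtime=runtime, llm=llm)

    if result.status == ExecutionStatus.NEEDS_REPLAN:
        # Feedback goes back to the external planner for a revised plan.
        print(result.feedback)
    else:
        print(result.status, result.results)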
|
||||
@@ -1,406 +0,0 @@
|
||||
"""
|
||||
Hybrid Judge for Evaluating Plan Step Results.
|
||||
|
||||
The HybridJudge evaluates step execution results using:
|
||||
1. Rule-based evaluation (fast, deterministic)
|
||||
2. LLM-based evaluation (fallback for ambiguous cases)
|
||||
|
||||
Escalation path: rules → LLM → human
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.code_sandbox import safe_eval
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.plan import (
|
||||
EvaluationRule,
|
||||
Judgment,
|
||||
JudgmentAction,
|
||||
PlanStep,
|
||||
)
|
||||
from framework.llm.provider import LLMProvider
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuleEvaluationResult:
|
||||
"""Result of rule-based evaluation."""
|
||||
|
||||
is_definitive: bool # True if a rule matched definitively
|
||||
judgment: Judgment | None = None
|
||||
context: dict[str, Any] = field(default_factory=dict)
|
||||
rules_checked: int = 0
|
||||
rule_matched: str | None = None
|
||||
|
||||
|
||||
class HybridJudge:
|
||||
"""
|
||||
Evaluates plan step results using rules first, then LLM fallback.
|
||||
|
||||
Usage:
|
||||
judge = HybridJudge(llm=llm_provider)
|
||||
judge.add_rule(EvaluationRule(
|
||||
id="success_check",
|
||||
condition="result.get('success') == True",
|
||||
action=JudgmentAction.ACCEPT,
|
||||
))
|
||||
|
||||
judgment = await judge.evaluate(step, result, goal)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: LLMProvider | None = None,
|
||||
rules: list[EvaluationRule] | None = None,
|
||||
llm_confidence_threshold: float = 0.7,
|
||||
):
|
||||
"""
|
||||
Initialize the HybridJudge.
|
||||
|
||||
Args:
|
||||
llm: LLM provider for ambiguous cases
|
||||
rules: Initial evaluation rules
|
||||
llm_confidence_threshold: Confidence below this triggers escalation
|
||||
"""
|
||||
self.llm = llm
|
||||
self.rules: list[EvaluationRule] = rules or []
|
||||
self.llm_confidence_threshold = llm_confidence_threshold
|
||||
|
||||
# Sort rules by priority (higher first)
|
||||
self._sort_rules()
|
||||
|
||||
def _sort_rules(self):
|
||||
"""Sort rules by priority."""
|
||||
self.rules.sort(key=lambda r: -r.priority)
|
||||
|
||||
def add_rule(self, rule: EvaluationRule) -> None:
|
||||
"""Add an evaluation rule."""
|
||||
self.rules.append(rule)
|
||||
self._sort_rules()
|
||||
|
||||
def remove_rule(self, rule_id: str) -> bool:
|
||||
"""Remove a rule by ID. Returns True if found and removed."""
|
||||
for i, rule in enumerate(self.rules):
|
||||
if rule.id == rule_id:
|
||||
self.rules.pop(i)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
step: PlanStep,
|
||||
result: Any,
|
||||
goal: Goal,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> Judgment:
|
||||
"""
|
||||
Evaluate a step result.
|
||||
|
||||
Args:
|
||||
step: The executed plan step
|
||||
result: The result of executing the step
|
||||
goal: The goal context for evaluation
|
||||
context: Additional context from previous steps
|
||||
|
||||
Returns:
|
||||
Judgment with action and feedback
|
||||
"""
|
||||
context = context or {}
|
||||
|
||||
# Try rule-based evaluation first
|
||||
rule_result = self._evaluate_rules(step, result, goal, context)
|
||||
|
||||
if rule_result.is_definitive:
|
||||
return rule_result.judgment
|
||||
|
||||
# Fall back to LLM evaluation
|
||||
if self.llm:
|
||||
return await self._evaluate_llm(step, result, goal, context, rule_result)
|
||||
|
||||
# No LLM available - default to accept with low confidence
|
||||
return Judgment(
|
||||
action=JudgmentAction.ACCEPT,
|
||||
reasoning="No definitive rule matched and no LLM available for evaluation",
|
||||
confidence=0.5,
|
||||
llm_used=False,
|
||||
)
|
||||
|
||||
def _evaluate_rules(
|
||||
self,
|
||||
step: PlanStep,
|
||||
result: Any,
|
||||
goal: Goal,
|
||||
context: dict[str, Any],
|
||||
) -> RuleEvaluationResult:
|
||||
"""Evaluate step using rules."""
|
||||
rules_checked = 0
|
||||
|
||||
# Build evaluation context
|
||||
eval_context = {
|
||||
"step": step.model_dump() if hasattr(step, "model_dump") else step,
|
||||
"result": result,
|
||||
"goal": goal.model_dump() if hasattr(goal, "model_dump") else goal,
|
||||
"context": context,
|
||||
"success": isinstance(result, dict) and result.get("success", False),
|
||||
"error": isinstance(result, dict) and result.get("error"),
|
||||
}
|
||||
|
||||
for rule in self.rules:
|
||||
rules_checked += 1
|
||||
|
||||
# Evaluate rule condition
|
||||
eval_result = safe_eval(rule.condition, eval_context)
|
||||
|
||||
if eval_result.success and eval_result.result:
|
||||
# Rule matched!
|
||||
feedback = self._format_feedback(rule.feedback_template, eval_context)
|
||||
|
||||
return RuleEvaluationResult(
|
||||
is_definitive=True,
|
||||
judgment=Judgment(
|
||||
action=rule.action,
|
||||
reasoning=rule.description,
|
||||
feedback=feedback if feedback else None,
|
||||
rule_matched=rule.id,
|
||||
confidence=1.0,
|
||||
llm_used=False,
|
||||
),
|
||||
rules_checked=rules_checked,
|
||||
rule_matched=rule.id,
|
||||
)
|
||||
|
||||
# No rule matched definitively
|
||||
return RuleEvaluationResult(
|
||||
is_definitive=False,
|
||||
context=eval_context,
|
||||
rules_checked=rules_checked,
|
||||
)
|
||||
|
||||
def _format_feedback(
|
||||
self,
|
||||
template: str,
|
||||
context: dict[str, Any],
|
||||
) -> str:
|
||||
"""Format feedback template with context values."""
|
||||
if not template:
|
||||
return ""
|
||||
|
||||
try:
|
||||
return template.format(**context)
|
||||
except (KeyError, ValueError):
|
||||
return template
|
||||
|
||||
async def _evaluate_llm(
|
||||
self,
|
||||
step: PlanStep,
|
||||
result: Any,
|
||||
goal: Goal,
|
||||
context: dict[str, Any],
|
||||
rule_result: RuleEvaluationResult,
|
||||
) -> Judgment:
|
||||
"""Evaluate step using LLM."""
|
||||
system_prompt = self._build_llm_system_prompt(goal)
|
||||
user_prompt = self._build_llm_user_prompt(step, result, context, rule_result)
|
||||
|
||||
try:
|
||||
response = self.llm.complete(
|
||||
messages=[{"role": "user", "content": user_prompt}],
|
||||
system=system_prompt,
|
||||
)
|
||||
|
||||
# Parse LLM response
|
||||
judgment = self._parse_llm_response(response.content)
|
||||
judgment.llm_used = True
|
||||
|
||||
# Check confidence threshold
|
||||
if judgment.confidence < self.llm_confidence_threshold:
|
||||
# Low confidence - escalate
|
||||
return Judgment(
|
||||
action=JudgmentAction.ESCALATE,
|
||||
reasoning=(
|
||||
f"LLM confidence ({judgment.confidence:.2f}) "
|
||||
f"below threshold ({self.llm_confidence_threshold})"
|
||||
),
|
||||
feedback=judgment.feedback,
|
||||
confidence=judgment.confidence,
|
||||
llm_used=True,
|
||||
context={"original_judgment": judgment.model_dump()},
|
||||
)
|
||||
|
||||
return judgment
|
||||
|
||||
except Exception as e:
|
||||
# LLM failed - escalate
|
||||
return Judgment(
|
||||
action=JudgmentAction.ESCALATE,
|
||||
reasoning=f"LLM evaluation failed: {e}",
|
||||
feedback="Human review needed due to LLM error",
|
||||
llm_used=True,
|
||||
)
|
||||
|
||||
def _build_llm_system_prompt(self, goal: Goal) -> str:
|
||||
"""Build system prompt for LLM judge."""
|
||||
return f"""You are a judge evaluating the execution of a plan step.
|
||||
|
||||
GOAL: {goal.description}
|
||||
|
||||
SUCCESS CRITERIA:
|
||||
{chr(10).join(f"- {sc.description}" for sc in goal.success_criteria)}
|
||||
|
||||
CONSTRAINTS:
|
||||
{chr(10).join(f"- {c.description}" for c in goal.constraints)}
|
||||
|
||||
Your task is to evaluate whether the step was executed successfully and decide the next action.
|
||||
|
||||
Respond in this exact format:
|
||||
ACTION: [ACCEPT|RETRY|REPLAN|ESCALATE]
|
||||
CONFIDENCE: [0.0-1.0]
|
||||
REASONING: [Your reasoning]
|
||||
FEEDBACK: [Feedback for retry/replan, or empty if accepting]
|
||||
|
||||
Actions:
|
||||
- ACCEPT: Step completed successfully, continue to next step
|
||||
- RETRY: Step failed but can be retried with feedback
|
||||
- REPLAN: Step failed in a way that requires replanning
|
||||
- ESCALATE: Requires human intervention
|
||||
"""
|
||||
|
||||
def _build_llm_user_prompt(
|
||||
self,
|
||||
step: PlanStep,
|
||||
result: Any,
|
||||
context: dict[str, Any],
|
||||
rule_result: RuleEvaluationResult,
|
||||
) -> str:
|
||||
"""Build user prompt for LLM judge."""
|
||||
return f"""Evaluate this step execution:
|
||||
|
||||
STEP: {step.description}
|
||||
STEP ID: {step.id}
|
||||
ACTION TYPE: {step.action.action_type}
|
||||
EXPECTED OUTPUTS: {step.expected_outputs}
|
||||
|
||||
RESULT:
|
||||
{result}
|
||||
|
||||
CONTEXT FROM PREVIOUS STEPS:
|
||||
{context}
|
||||
|
||||
RULES CHECKED: {rule_result.rules_checked} (none matched definitively)
|
||||
|
||||
Please evaluate and provide your judgment."""
|
||||
|
||||
def _parse_llm_response(self, response: str) -> Judgment:
|
||||
"""Parse LLM response into Judgment."""
|
||||
lines = response.strip().split("\n")
|
||||
|
||||
action = JudgmentAction.ACCEPT
|
||||
confidence = 0.8
|
||||
reasoning = ""
|
||||
feedback = ""
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith("ACTION:"):
|
||||
action_str = line.split(":", 1)[1].strip().upper()
|
||||
try:
|
||||
action = JudgmentAction(action_str.lower())
|
||||
except ValueError:
|
||||
action = JudgmentAction.ESCALATE
|
||||
|
||||
elif line.startswith("CONFIDENCE:"):
|
||||
try:
|
||||
confidence = float(line.split(":", 1)[1].strip())
|
||||
except ValueError:
|
||||
confidence = 0.5
|
||||
|
||||
elif line.startswith("REASONING:"):
|
||||
reasoning = line.split(":", 1)[1].strip()
|
||||
|
||||
elif line.startswith("FEEDBACK:"):
|
||||
feedback = line.split(":", 1)[1].strip()
|
||||
|
||||
return Judgment(
|
||||
action=action,
|
||||
reasoning=reasoning or "LLM evaluation",
|
||||
feedback=feedback if feedback else None,
|
||||
confidence=confidence,
|
||||
)
|
||||
|
||||
|
||||
# Factory function for creating judge with common rules
|
||||
def create_default_judge(llm: LLMProvider | None = None) -> HybridJudge:
|
||||
"""
|
||||
Create a HybridJudge with commonly useful default rules.
|
||||
|
||||
Args:
|
||||
llm: LLM provider for fallback evaluation
|
||||
|
||||
Returns:
|
||||
Configured HybridJudge instance
|
||||
"""
|
||||
judge = HybridJudge(llm=llm)
|
||||
|
||||
# Rule: Accept on explicit success flag
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="explicit_success",
|
||||
description="Step explicitly marked as successful",
|
||||
condition="isinstance(result, dict) and result.get('success') == True",
|
||||
action=JudgmentAction.ACCEPT,
|
||||
priority=100,
|
||||
)
|
||||
)
|
||||
|
||||
# Rule: Retry on transient errors
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="transient_error_retry",
|
||||
description="Transient error that can be retried",
|
||||
condition=(
|
||||
"isinstance(result, dict) and "
|
||||
"result.get('error_type') in ['timeout', 'rate_limit', 'connection_error']"
|
||||
),
|
||||
action=JudgmentAction.RETRY,
|
||||
feedback_template="Transient error: {result[error]}. Please retry.",
|
||||
priority=90,
|
||||
)
|
||||
)
|
||||
|
||||
# Rule: Replan on missing data
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="missing_data_replan",
|
||||
description="Required data not available",
|
||||
condition="isinstance(result, dict) and result.get('error_type') == 'missing_data'",
|
||||
action=JudgmentAction.REPLAN,
|
||||
feedback_template="Missing required data: {result[error]}. Plan needs adjustment.",
|
||||
priority=80,
|
||||
)
|
||||
)
|
||||
|
||||
# Rule: Escalate on security issues
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="security_escalate",
|
||||
description="Security issue detected",
|
||||
condition="isinstance(result, dict) and result.get('error_type') == 'security'",
|
||||
action=JudgmentAction.ESCALATE,
|
||||
feedback_template="Security issue detected: {result[error]}",
|
||||
priority=200,
|
||||
)
|
||||
)
|
||||
|
||||
# Rule: Fail on max retries exceeded
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="max_retries_fail",
|
||||
description="Maximum retries exceeded",
|
||||
condition="step.get('attempts', 0) >= step.get('max_retries', 3)",
|
||||
action=JudgmentAction.REPLAN,
|
||||
feedback_template="Step '{step[id]}' failed after {step[attempts]} attempts",
|
||||
priority=150,
|
||||
)
|
||||
)
|
||||
|
||||
return judge
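
A brief sketch of the default rules in action (editor's illustration; `step` and `goal` are assumed to come from a prior execution). A timeout result matches the transient_error_retry rule without ever touching an LLM.

async def review(step, goal):
    judge = create_default_judge()  # rules only; no LLM fallback configured

    judgment = await judge.evaluate(
        step=step,
        result={"success": False, "error": "request timed out", "error_type": "timeout"},
        goal=goal,
    )
    assert judgment.action == JudgmentAction.RETRY
    assert judgment.rule_matched == "transient_error_retry"
    return judgment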
|
||||
+24
-1342
File diff suppressed because it is too large
@@ -206,7 +206,7 @@ class OutputCleaner:
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
def clean_output(
|
||||
async def clean_output(
|
||||
self,
|
||||
output: dict[str, Any],
|
||||
source_node_id: str,
|
||||
@@ -288,7 +288,7 @@ Return ONLY valid JSON matching the expected schema. No explanations, no markdow
|
||||
f"🧹 Cleaning output from '{source_node_id}' using {self.config.fast_model}"
|
||||
)
|
||||
|
||||
response = self.llm.complete(
|
||||
response = await self.llm.acomplete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
system=(
|
||||
"You clean malformed agent outputs. Return only valid JSON matching the schema."
|
||||
|
||||
@@ -1,513 +0,0 @@
|
||||
"""
|
||||
Plan Data Structures for Flexible Execution.
|
||||
|
||||
Plans are created externally (by Claude Code or another LLM agent) and
|
||||
executed internally by the FlexibleGraphExecutor with Worker-Judge loop.
|
||||
|
||||
The Plan is the contract between the external planner and the executor:
|
||||
- Planner creates a Plan with PlanSteps
|
||||
- Executor runs steps and judges results
|
||||
- If replanning needed, returns feedback to external planner
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ActionType(StrEnum):
|
||||
"""Types of actions a PlanStep can perform."""
|
||||
|
||||
LLM_CALL = "llm_call" # Call LLM for generation
|
||||
TOOL_USE = "tool_use" # Use a registered tool
|
||||
SUB_GRAPH = "sub_graph" # Execute a sub-graph
|
||||
FUNCTION = "function" # Call a Python function
|
||||
CODE_EXECUTION = "code_execution" # Execute dynamic code (sandboxed)
|
||||
|
||||
|
||||
class StepStatus(StrEnum):
|
||||
"""Status of a plan step."""
|
||||
|
||||
PENDING = "pending"
|
||||
AWAITING_APPROVAL = "awaiting_approval" # Waiting for human approval
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
SKIPPED = "skipped"
|
||||
REJECTED = "rejected" # Human rejected execution
|
||||
|
||||
def is_terminal(self) -> bool:
|
||||
"""Check if this status represents a terminal (finished) state.
|
||||
|
||||
Terminal states are states where the step will not execute further,
|
||||
either because it completed successfully or failed/was skipped.
|
||||
"""
|
||||
return self in (
|
||||
StepStatus.COMPLETED,
|
||||
StepStatus.FAILED,
|
||||
StepStatus.SKIPPED,
|
||||
StepStatus.REJECTED,
|
||||
)
|
||||
|
||||
def is_successful(self) -> bool:
|
||||
"""Check if this status represents successful completion."""
|
||||
return self == StepStatus.COMPLETED
|
||||
|
||||
|
||||
class ApprovalDecision(StrEnum):
|
||||
"""Human decision on a step requiring approval."""
|
||||
|
||||
APPROVE = "approve" # Execute as planned
|
||||
REJECT = "reject" # Skip this step
|
||||
MODIFY = "modify" # Execute with modifications
|
||||
ABORT = "abort" # Stop entire execution
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
|
||||
"""Request for human approval before executing a step."""
|
||||
|
||||
step_id: str
|
||||
step_description: str
|
||||
action_type: str
|
||||
action_details: dict[str, Any] = Field(default_factory=dict)
|
||||
context: dict[str, Any] = Field(default_factory=dict)
|
||||
approval_message: str | None = None
|
||||
|
||||
# Preview of what will happen
|
||||
preview: str | None = None
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class ApprovalResult(BaseModel):
|
||||
"""Result of human approval decision."""
|
||||
|
||||
decision: ApprovalDecision
|
||||
reason: str | None = None
|
||||
modifications: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class JudgmentAction(StrEnum):
|
||||
"""Actions the judge can take after evaluating a step."""
|
||||
|
||||
ACCEPT = "accept" # Step completed successfully, continue
|
||||
RETRY = "retry" # Retry the step with feedback
|
||||
REPLAN = "replan" # Return to external planner for new plan
|
||||
ESCALATE = "escalate" # Request human intervention
|
||||
|
||||
|
||||
class ActionSpec(BaseModel):
|
||||
"""
|
||||
Specification for an action to be executed.
|
||||
|
||||
This is the "what to do" part of a PlanStep.
|
||||
"""
|
||||
|
||||
action_type: ActionType
|
||||
|
||||
# For LLM_CALL
|
||||
prompt: str | None = None
|
||||
system_prompt: str | None = None
|
||||
model: str | None = None
|
||||
|
||||
# For TOOL_USE
|
||||
tool_name: str | None = None
|
||||
tool_args: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# For SUB_GRAPH
|
||||
graph_id: str | None = None
|
||||
|
||||
# For FUNCTION
|
||||
function_name: str | None = None
|
||||
function_args: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# For CODE_EXECUTION
|
||||
code: str | None = None
|
||||
language: str = "python"
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class PlanStep(BaseModel):
|
||||
"""
|
||||
A single step in a plan.
|
||||
|
||||
Created by external planner, executed by Worker, evaluated by Judge.
|
||||
"""
|
||||
|
||||
id: str
|
||||
description: str
|
||||
action: ActionSpec
|
||||
|
||||
# Data flow
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Input data for this step (can reference previous step outputs)",
|
||||
)
|
||||
expected_outputs: list[str] = Field(
|
||||
default_factory=list, description="Keys this step should produce"
|
||||
)
|
||||
|
||||
# Dependencies
|
||||
dependencies: list[str] = Field(
|
||||
default_factory=list, description="IDs of steps that must complete before this one"
|
||||
)
|
||||
|
||||
# Human-in-the-loop (HITL)
|
||||
requires_approval: bool = Field(
|
||||
default=False, description="If True, requires human approval before execution"
|
||||
)
|
||||
approval_message: str | None = Field(
|
||||
default=None, description="Message to show human when requesting approval"
|
||||
)
|
||||
|
||||
# Execution state
|
||||
status: StepStatus = StepStatus.PENDING
|
||||
result: Any | None = None
|
||||
error: str | None = None
|
||||
attempts: int = 0
|
||||
max_retries: int = 3
|
||||
|
||||
# Metadata
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
def is_ready(self, terminal_step_ids: set[str]) -> bool:
|
||||
"""Check if this step is ready to execute (all dependencies finished).
|
||||
|
||||
A step is ready when:
|
||||
1. Its status is PENDING (not yet started)
|
||||
2. All its dependencies are in a terminal state (completed, failed, skipped, or rejected)
|
||||
|
||||
Note: This allows dependent steps to become "ready" even if their dependencies
|
||||
failed. The executor should check if any dependencies failed and handle
|
||||
accordingly (e.g., skip the step or mark it as blocked).
|
||||
|
||||
Args:
|
||||
terminal_step_ids: Set of step IDs that are in a terminal state
|
||||
"""
|
||||
if self.status != StepStatus.PENDING:
|
||||
return False
|
||||
return all(dep in terminal_step_ids for dep in self.dependencies)
|
||||
|
||||
|
||||
class Judgment(BaseModel):
|
||||
"""
|
||||
Result of judging a step execution.
|
||||
|
||||
The Judge evaluates step results and decides what to do next.
|
||||
"""
|
||||
|
||||
action: JudgmentAction
|
||||
reasoning: str
|
||||
feedback: str | None = None # For retry/replan - what went wrong
|
||||
|
||||
# For rule-based judgments
|
||||
rule_matched: str | None = None
|
||||
|
||||
# For LLM-based judgments
|
||||
confidence: float = 1.0
|
||||
llm_used: bool = False
|
||||
|
||||
# Context for replanning
|
||||
context: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class EvaluationRule(BaseModel):
|
||||
"""
|
||||
A rule for the HybridJudge to evaluate step results.
|
||||
|
||||
Rules are checked before falling back to LLM evaluation.
|
||||
"""
|
||||
|
||||
id: str
|
||||
description: str
|
||||
|
||||
# Condition (Python expression evaluated with result, step, goal context)
|
||||
condition: str
|
||||
|
||||
# What to do if condition matches
|
||||
action: JudgmentAction
|
||||
feedback_template: str = "" # Can use {result}, {step}, etc.
|
||||
|
||||
# Priority (higher = checked first)
|
||||
priority: int = 0
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class Plan(BaseModel):
|
||||
"""
|
||||
A complete execution plan.
|
||||
|
||||
Created by external planner (Claude Code, etc).
|
||||
Executed by FlexibleGraphExecutor.
|
||||
"""
|
||||
|
||||
id: str
|
||||
goal_id: str
|
||||
description: str
|
||||
|
||||
# Steps to execute
|
||||
steps: list[PlanStep] = Field(default_factory=list)
|
||||
|
||||
# Execution state
|
||||
revision: int = 1 # Incremented on replan
|
||||
current_step_idx: int = 0
|
||||
|
||||
# Accumulated context from execution
|
||||
context: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Metadata
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
created_by: str = "external" # Who created this plan
|
||||
|
||||
# Previous attempt info (for replanning)
|
||||
previous_feedback: str | None = None
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data: str | dict) -> "Plan":
|
||||
"""
|
||||
Load a Plan from exported JSON.
|
||||
|
||||
This handles the output from export_graph() and properly converts
|
||||
action_type strings to ActionType enums.
|
||||
|
||||
Args:
|
||||
data: JSON string or dict from export_graph()
|
||||
|
||||
Returns:
|
||||
Plan object ready for FlexibleGraphExecutor
|
||||
|
||||
Example:
|
||||
# Load from export_graph() output
|
||||
exported = export_graph()
|
||||
plan = Plan.from_json(exported)
|
||||
|
||||
# Load from file
|
||||
with open("plan.json") as f:
|
||||
plan = Plan.from_json(json.load(f))
|
||||
"""
|
||||
import json as json_module
|
||||
|
||||
if isinstance(data, str):
|
||||
data = json_module.loads(data)
|
||||
|
||||
# Handle nested "plan" key from export_graph output
|
||||
if "plan" in data:
|
||||
data = data["plan"]
|
||||
|
||||
# Convert steps
|
||||
steps = []
|
||||
for step_data in data.get("steps", []):
|
||||
action_data = step_data.get("action", {})
|
||||
|
||||
# Convert action_type string to enum
|
||||
action_type_str = action_data.get("action_type", "function")
|
||||
action_type = ActionType(action_type_str)
|
||||
|
||||
action = ActionSpec(
|
||||
action_type=action_type,
|
||||
prompt=action_data.get("prompt"),
|
||||
system_prompt=action_data.get("system_prompt"),
|
||||
tool_name=action_data.get("tool_name"),
|
||||
tool_args=action_data.get("tool_args", {}),
|
||||
function_name=action_data.get("function_name"),
|
||||
function_args=action_data.get("function_args", {}),
|
||||
code=action_data.get("code"),
|
||||
)
|
||||
|
||||
step = PlanStep(
|
||||
id=step_data["id"],
|
||||
description=step_data.get("description", ""),
|
||||
action=action,
|
||||
inputs=step_data.get("inputs", {}),
|
||||
expected_outputs=step_data.get("expected_outputs", []),
|
||||
dependencies=step_data.get("dependencies", []),
|
||||
requires_approval=step_data.get("requires_approval", False),
|
||||
approval_message=step_data.get("approval_message"),
|
||||
)
|
||||
steps.append(step)
|
||||
|
||||
return cls(
|
||||
id=data.get("id", "plan"),
|
||||
goal_id=data.get("goal_id", ""),
|
||||
description=data.get("description", ""),
|
||||
steps=steps,
|
||||
context=data.get("context", {}),
|
||||
revision=data.get("revision", 1),
|
||||
)
|
||||
|
||||
def get_step(self, step_id: str) -> PlanStep | None:
|
||||
"""Get a step by ID."""
|
||||
for step in self.steps:
|
||||
if step.id == step_id:
|
||||
return step
|
||||
return None
|
||||
|
||||
def get_ready_steps(self) -> list[PlanStep]:
|
||||
"""Get all steps that are ready to execute.
|
||||
|
||||
A step is ready when all its dependencies are in terminal states
|
||||
(completed, failed, skipped, or rejected).
|
||||
"""
|
||||
terminal_ids = {s.id for s in self.steps if s.status.is_terminal()}
|
||||
return [s for s in self.steps if s.is_ready(terminal_ids)]
|
||||
|
||||
def get_completed_steps(self) -> list[PlanStep]:
|
||||
"""Get all completed steps."""
|
||||
return [s for s in self.steps if s.status == StepStatus.COMPLETED]
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if all steps are in terminal states (finished executing).
|
||||
|
||||
Returns True when all steps have reached a terminal state, regardless
|
||||
of whether they succeeded or failed. Use has_failed_steps() to check
|
||||
if any steps failed.
|
||||
"""
|
||||
return all(s.status.is_terminal() for s in self.steps)
|
||||
|
||||
def is_successful(self) -> bool:
|
||||
"""Check if all steps completed successfully."""
|
||||
return all(s.status == StepStatus.COMPLETED for s in self.steps)
|
||||
|
||||
def has_failed_steps(self) -> bool:
|
||||
"""Check if any steps failed, were skipped, or were rejected."""
|
||||
return any(
|
||||
s.status in (StepStatus.FAILED, StepStatus.SKIPPED, StepStatus.REJECTED)
|
||||
for s in self.steps
|
||||
)
|
||||
|
||||
def get_failed_steps(self) -> list[PlanStep]:
|
||||
"""Get all steps that failed, were skipped, or were rejected."""
|
||||
return [
|
||||
s
|
||||
for s in self.steps
|
||||
if s.status in (StepStatus.FAILED, StepStatus.SKIPPED, StepStatus.REJECTED)
|
||||
]
|
||||
|
||||
def to_feedback_context(self) -> dict[str, Any]:
|
||||
"""Create context for replanning."""
|
||||
return {
|
||||
"plan_id": self.id,
|
||||
"revision": self.revision,
|
||||
"completed_steps": [
|
||||
{
|
||||
"id": s.id,
|
||||
"description": s.description,
|
||||
"result": s.result,
|
||||
}
|
||||
for s in self.get_completed_steps()
|
||||
],
|
||||
"failed_steps": [
|
||||
{
|
||||
"id": s.id,
|
||||
"description": s.description,
|
||||
"error": s.error,
|
||||
"attempts": s.attempts,
|
||||
}
|
||||
for s in self.steps
|
||||
if s.status == StepStatus.FAILED
|
||||
],
|
||||
"context": self.context,
|
||||
}
|
||||
|
||||
|
||||
class ExecutionStatus(StrEnum):
|
||||
"""Status of plan execution."""
|
||||
|
||||
COMPLETED = "completed"
|
||||
AWAITING_APPROVAL = "awaiting_approval" # Paused for human approval
|
||||
NEEDS_REPLAN = "needs_replan"
|
||||
NEEDS_ESCALATION = "needs_escalation"
|
||||
REJECTED = "rejected" # Human rejected a step
|
||||
ABORTED = "aborted" # Human aborted execution
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class PlanExecutionResult(BaseModel):
|
||||
"""
|
||||
Result of executing a plan.
|
||||
|
||||
Returned to external planner with status and feedback.
|
||||
"""
|
||||
|
||||
status: ExecutionStatus
|
||||
|
||||
# Results from completed steps
|
||||
results: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# For needs_replan - what to tell the planner
|
||||
feedback: str | None = None
|
||||
feedback_context: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Steps that completed before stopping
|
||||
completed_steps: list[str] = Field(default_factory=list)
|
||||
|
||||
# Metrics
|
||||
steps_executed: int = 0
|
||||
total_tokens: int = 0
|
||||
total_latency_ms: int = 0
|
||||
|
||||
# Error info (for failed status)
|
||||
error: str | None = None
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
def load_export(data: str | dict) -> tuple["Plan", Any]:
|
||||
"""
|
||||
Load both Plan and Goal from export_graph() output.
|
||||
|
||||
The export_graph() MCP tool returns both the plan and the goal that was
|
||||
defined and approved during the agent building process. This function
|
||||
loads both so you can use them with FlexibleGraphExecutor.
|
||||
|
||||
Args:
|
||||
data: JSON string or dict from export_graph()
|
||||
|
||||
Returns:
|
||||
Tuple of (Plan, Goal) ready for FlexibleGraphExecutor
|
||||
|
||||
Example:
|
||||
# Load from export_graph() output
|
||||
exported = export_graph()
|
||||
plan, goal = load_export(exported)
|
||||
|
||||
result = await executor.execute_plan(plan, goal, context)
|
||||
"""
|
||||
import json as json_module
|
||||
|
||||
from framework.graph.goal import Goal
|
||||
|
||||
if isinstance(data, str):
|
||||
data = json_module.loads(data)
|
||||
|
||||
# Load plan
|
||||
plan = Plan.from_json(data)
|
||||
|
||||
# Load goal
|
||||
goal_data = data.get("goal", {})
|
||||
if goal_data:
|
||||
goal = Goal.model_validate(goal_data)
|
||||
else:
|
||||
# Fallback: create minimal goal from plan metadata
|
||||
goal = Goal(
|
||||
id=plan.goal_id,
|
||||
name=plan.goal_id,
|
||||
description=plan.description,
|
||||
success_criteria=[],
|
||||
constraints=[],
|
||||
)
|
||||
|
||||
return plan, goal
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Prompt composition for continuous agent mode.
|
||||
|
||||
Composes the three-layer system prompt (onion model) and generates
|
||||
transition markers inserted into the conversation at phase boundaries.
|
||||
|
||||
Layer 1 — Identity (static, defined at agent level, never changes):
|
||||
"You are a thorough research agent. You prefer clarity over jargon..."
|
||||
|
||||
Layer 2 — Narrative (auto-generated from conversation/memory state):
|
||||
"We've finished scoping the project. The user wants to focus on..."
|
||||
|
||||
Layer 3 — Focus (per-node system_prompt, reframed as focus directive):
|
||||
"Your current attention: synthesize findings into a report..."
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.node import NodeSpec, SharedMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compose_system_prompt(
|
||||
identity_prompt: str | None,
|
||||
focus_prompt: str | None,
|
||||
narrative: str | None = None,
|
||||
) -> str:
|
||||
"""Compose the three-layer system prompt.
|
||||
|
||||
Args:
|
||||
identity_prompt: Layer 1 — static agent identity (from GraphSpec).
|
||||
focus_prompt: Layer 3 — per-node focus directive (from NodeSpec.system_prompt).
|
||||
narrative: Layer 2 — auto-generated from conversation state.
|
||||
|
||||
Returns:
|
||||
Composed system prompt containing whichever layers were provided.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
|
||||
# Layer 1: Identity (always first, anchors the personality)
|
||||
if identity_prompt:
|
||||
parts.append(identity_prompt)
|
||||
|
||||
# Layer 2: Narrative (what's happened so far)
|
||||
if narrative:
|
||||
parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}")
|
||||
|
||||
# Layer 3: Focus (current phase directive)
|
||||
if focus_prompt:
|
||||
parts.append(f"\n--- Current Focus ---\n{focus_prompt}")
|
||||
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def build_narrative(
|
||||
memory: SharedMemory,
|
||||
execution_path: list[str],
|
||||
graph: GraphSpec,
|
||||
) -> str:
|
||||
"""Build Layer 2 (narrative) from structured state.
|
||||
|
||||
Deterministic — no LLM call. Reads SharedMemory and execution path
|
||||
to describe what has happened so far. Cheap and fast.
|
||||
|
||||
Args:
|
||||
memory: Current shared memory state.
|
||||
execution_path: List of node IDs visited so far.
|
||||
graph: Graph spec (for node names/descriptions).
|
||||
|
||||
Returns:
|
||||
Narrative string describing the session state.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
|
||||
# Describe execution path
|
||||
if execution_path:
|
||||
phase_descriptions: list[str] = []
|
||||
for node_id in execution_path:
|
||||
node_spec = graph.get_node(node_id)
|
||||
if node_spec:
|
||||
phase_descriptions.append(f"- {node_spec.name}: {node_spec.description}")
|
||||
else:
|
||||
phase_descriptions.append(f"- {node_id}")
|
||||
parts.append("Phases completed:\n" + "\n".join(phase_descriptions))
|
||||
|
||||
# Describe key memory values (skip very long values)
|
||||
all_memory = memory.read_all()
|
||||
if all_memory:
|
||||
memory_lines: list[str] = []
|
||||
for key, value in all_memory.items():
|
||||
if value is None:
|
||||
continue
|
||||
val_str = str(value)
|
||||
if len(val_str) > 200:
|
||||
val_str = val_str[:200] + "..."
|
||||
memory_lines.append(f"- {key}: {val_str}")
|
||||
if memory_lines:
|
||||
parts.append("Current state:\n" + "\n".join(memory_lines))
|
||||
|
||||
return "\n\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def build_transition_marker(
|
||||
previous_node: NodeSpec,
|
||||
next_node: NodeSpec,
|
||||
memory: SharedMemory,
|
||||
cumulative_tool_names: list[str],
|
||||
data_dir: Path | str | None = None,
|
||||
) -> str:
|
||||
"""Build a 'State of the World' transition marker.
|
||||
|
||||
Inserted into the conversation as a user message at phase boundaries.
|
||||
Gives the LLM full situational awareness: what happened, what's stored,
|
||||
what tools are available, and what to focus on next.
|
||||
|
||||
Args:
|
||||
previous_node: NodeSpec of the phase just completed.
|
||||
next_node: NodeSpec of the phase about to start.
|
||||
memory: Current shared memory state.
|
||||
cumulative_tool_names: All tools available (cumulative set).
|
||||
data_dir: Path to spillover data directory.
|
||||
|
||||
Returns:
|
||||
Transition marker message text.
|
||||
"""
|
||||
sections: list[str] = []
|
||||
|
||||
# Header
|
||||
sections.append(f"--- PHASE TRANSITION: {previous_node.name} → {next_node.name} ---")
|
||||
|
||||
# What just completed
|
||||
sections.append(f"\nCompleted: {previous_node.name}")
|
||||
sections.append(f" {previous_node.description}")
|
||||
|
||||
# Outputs in memory
|
||||
all_memory = memory.read_all()
|
||||
if all_memory:
|
||||
memory_lines: list[str] = []
|
||||
for key, value in all_memory.items():
|
||||
if value is None:
|
||||
continue
|
||||
val_str = str(value)
|
||||
if len(val_str) > 300:
|
||||
val_str = val_str[:300] + "..."
|
||||
memory_lines.append(f" {key}: {val_str}")
|
||||
if memory_lines:
|
||||
sections.append("\nOutputs available:\n" + "\n".join(memory_lines))
|
||||
|
||||
# Files in data directory
|
||||
if data_dir:
|
||||
data_path = Path(data_dir)
|
||||
if data_path.exists():
|
||||
files = sorted(data_path.iterdir())
|
||||
if files:
|
||||
file_lines = [
|
||||
f" {f.name} ({f.stat().st_size:,} bytes)" for f in files if f.is_file()
|
||||
]
|
||||
if file_lines:
|
||||
sections.append(
|
||||
"\nData files (use load_data to access):\n" + "\n".join(file_lines)
|
||||
)
|
||||
|
||||
# Available tools
|
||||
if cumulative_tool_names:
|
||||
sections.append("\nAvailable tools: " + ", ".join(sorted(cumulative_tool_names)))
|
||||
|
||||
# Next phase
|
||||
sections.append(f"\nNow entering: {next_node.name}")
|
||||
sections.append(f" {next_node.description}")
|
||||
|
||||
# Reflection prompt (engineered metacognition)
|
||||
sections.append(
|
||||
"\nBefore proceeding, briefly reflect: what went well in the "
|
||||
"previous phase? Are there any gaps or surprises worth noting?"
|
||||
)
|
||||
|
||||
sections.append("\n--- END TRANSITION ---")
|
||||
|
||||
return "\n".join(sections)
|
||||
@@ -145,7 +145,7 @@ class SafeEvalVisitor(ast.NodeVisitor):
|
||||
|
||||
def visit_Attribute(self, node: ast.Attribute) -> Any:
|
||||
# value.attr
|
||||
# STIRCT CHECK: No access to private attributes (starting with _)
|
||||
# STRICT CHECK: No access to private attributes (starting with _)
|
||||
if node.attr.startswith("_"):
|
||||
raise ValueError(f"Access to private attribute '{node.attr}' is not allowed")
|
||||
|
||||
|
||||
@@ -1,620 +0,0 @@
|
||||
"""
|
||||
Worker Node for Executing Plan Steps.
|
||||
|
||||
The Worker executes individual plan steps by dispatching to the
|
||||
appropriate executor based on action type:
|
||||
- LLM calls
|
||||
- Tool usage
|
||||
- Sub-graph execution
|
||||
- Function calls
|
||||
- Code execution (sandboxed)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.code_sandbox import CodeSandbox
|
||||
from framework.graph.plan import (
|
||||
ActionSpec,
|
||||
ActionType,
|
||||
PlanStep,
|
||||
)
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_llm_json_response(text: str) -> tuple[Any | None, str]:
|
||||
"""
|
||||
Parse JSON from LLM response, handling markdown code blocks.
|
||||
|
||||
LLMs often return JSON wrapped in markdown code blocks like:
|
||||
```json
|
||||
{"key": "value"}
|
||||
```
|
||||
|
||||
This function extracts and parses the JSON.
|
||||
|
||||
Args:
|
||||
text: Raw LLM response text
|
||||
|
||||
Returns:
|
||||
Tuple of (parsed_json_or_None, cleaned_text)
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
return None, str(text)
|
||||
|
||||
cleaned = text.strip()
|
||||
|
||||
# Try to extract JSON from markdown code blocks
|
||||
# Pattern: ```json ... ``` or ``` ... ```
|
||||
code_block_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```"
|
||||
matches = re.findall(code_block_pattern, cleaned)
|
||||
|
||||
if matches:
|
||||
# Try to parse each match
|
||||
for match in matches:
|
||||
try:
|
||||
parsed = json.loads(match.strip())
|
||||
return parsed, match.strip()
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug(
|
||||
f"Failed to parse JSON from code block: {e}. "
|
||||
f"Content preview: {match.strip()[:100]}..."
|
||||
)
|
||||
continue
|
||||
|
||||
# No code blocks or parsing failed - try parsing the whole response
|
||||
try:
|
||||
parsed = json.loads(cleaned)
|
||||
return parsed, cleaned
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug(
|
||||
f"Failed to parse entire response as JSON: {e}. Content preview: {cleaned[:100]}..."
|
||||
)
|
||||
|
||||
# Try to find JSON-like content (starts with { or [)
|
||||
json_start_pattern = r"(\{[\s\S]*\}|\[[\s\S]*\])"
|
||||
json_matches = re.findall(json_start_pattern, cleaned)
|
||||
|
||||
for match in json_matches:
|
||||
try:
|
||||
parsed = json.loads(match)
|
||||
return parsed, match
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug(f"Failed to parse JSON pattern: {e}. Content preview: {match[:100]}...")
|
||||
continue
|
||||
|
||||
# Could not parse as JSON - log warning
|
||||
logger.warning(
|
||||
f"Could not parse LLM response as JSON after trying all strategies. "
|
||||
f"Response preview: {cleaned[:200]}..."
|
||||
)
|
||||
return None, cleaned
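
A quick illustration of the parsing fallbacks (editor's sketch):

raw = 'Sure, here you go:\n```json\n{"score": 0.9, "label": "ok"}\n```'
parsed, cleaned = parse_llm_json_response(raw)
# parsed == {"score": 0.9, "label": "ok"}; cleaned is the extracted JSON text.

parsed, cleaned = parse_llm_json_response("no structured data here")
# parsed is None, cleaned is the stripped text, and a warning is logged.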
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepExecutionResult:
|
||||
"""Result of executing a plan step."""
|
||||
|
||||
success: bool
|
||||
outputs: dict[str, Any] = field(default_factory=dict)
|
||||
error: str | None = None
|
||||
error_type: str | None = None # For judge rules: timeout, rate_limit, etc.
|
||||
|
||||
# Metadata
|
||||
tokens_used: int = 0
|
||||
latency_ms: int = 0
|
||||
executor_type: str = ""
|
||||
|
||||
|
||||
class WorkerNode:
|
||||
"""
|
||||
Executes plan steps by dispatching to appropriate executors.
|
||||
|
||||
Usage:
|
||||
worker = WorkerNode(
|
||||
runtime=runtime,
|
||||
llm=llm_provider,
|
||||
tools=tool_registry,
|
||||
)
|
||||
|
||||
result = await worker.execute(step, context)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runtime: Runtime,
|
||||
llm: LLMProvider | None = None,
|
||||
tools: dict[str, Tool] | None = None,
|
||||
tool_executor: Callable | None = None,
|
||||
functions: dict[str, Callable] | None = None,
|
||||
sub_graph_executor: Callable | None = None,
|
||||
sandbox: CodeSandbox | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Worker.
|
||||
|
||||
Args:
|
||||
runtime: Runtime for decision logging
|
||||
llm: LLM provider for LLM_CALL actions
|
||||
tools: Available tools for TOOL_USE actions
|
||||
tool_executor: Function to execute tools
|
||||
functions: Registered functions for FUNCTION actions
|
||||
sub_graph_executor: Function to execute sub-graphs
|
||||
sandbox: Code sandbox for CODE_EXECUTION actions
|
||||
"""
|
||||
self.runtime = runtime
|
||||
self.llm = llm
|
||||
self.tools = tools or {}
|
||||
self.tool_executor = tool_executor
|
||||
self.functions = functions or {}
|
||||
self.sub_graph_executor = sub_graph_executor
|
||||
self.sandbox = sandbox or CodeSandbox()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
step: PlanStep,
|
||||
context: dict[str, Any],
|
||||
) -> StepExecutionResult:
|
||||
"""
|
||||
Execute a plan step.
|
||||
|
||||
Args:
|
||||
step: The step to execute
|
||||
context: Current execution context
|
||||
|
||||
Returns:
|
||||
StepExecutionResult with outputs and status
|
||||
"""
|
||||
# Record decision
|
||||
decision_id = self.runtime.decide(
|
||||
intent=f"Execute plan step: {step.description}",
|
||||
options=[
|
||||
{
|
||||
"id": step.action.action_type.value,
|
||||
"description": f"Execute {step.action.action_type.value} action",
|
||||
"action_type": step.action.action_type.value,
|
||||
}
|
||||
],
|
||||
chosen=step.action.action_type.value,
|
||||
reasoning=f"Step requires {step.action.action_type.value}",
|
||||
context={"step_id": step.id, "inputs": step.inputs},
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Resolve inputs from context
|
||||
resolved_inputs = self._resolve_inputs(step.inputs, context)
|
||||
|
||||
# Dispatch to appropriate executor
|
||||
result = await self._dispatch(step.action, resolved_inputs, context)
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
result.latency_ms = latency_ms
|
||||
|
||||
# Record outcome
|
||||
self.runtime.record_outcome(
|
||||
decision_id=decision_id,
|
||||
success=result.success,
|
||||
result=result.outputs if result.success else result.error,
|
||||
tokens_used=result.tokens_used,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
self.runtime.record_outcome(
|
||||
decision_id=decision_id,
|
||||
success=False,
|
||||
error=str(e),
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_type="exception",
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
def _resolve_inputs(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Resolve input references from context."""
|
||||
resolved = {}
|
||||
|
||||
for key, value in inputs.items():
|
||||
if isinstance(value, str) and value.startswith("$"):
|
||||
# Reference to context variable
|
||||
ref_key = value[1:] # Remove $
|
||||
resolved[key] = context.get(ref_key, value)
|
||||
else:
|
||||
resolved[key] = value
|
||||
|
||||
return resolved
|
||||
|
||||
async def _dispatch(
|
||||
self,
|
||||
action: ActionSpec,
|
||||
inputs: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
) -> StepExecutionResult:
|
||||
"""Dispatch to appropriate executor based on action type."""
|
||||
if action.action_type == ActionType.LLM_CALL:
|
||||
return await self._execute_llm_call(action, inputs, context)
|
||||
|
||||
elif action.action_type == ActionType.TOOL_USE:
|
||||
return await self._execute_tool_use(action, inputs)
|
||||
|
||||
elif action.action_type == ActionType.SUB_GRAPH:
|
||||
return await self._execute_sub_graph(action, inputs, context)
|
||||
|
||||
elif action.action_type == ActionType.FUNCTION:
|
||||
return await self._execute_function(action, inputs)
|
||||
|
||||
elif action.action_type == ActionType.CODE_EXECUTION:
|
||||
return self._execute_code(action, inputs, context)
|
||||
|
||||
else:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error=f"Unknown action type: {action.action_type}",
|
||||
error_type="invalid_action",
|
||||
)
|
||||
|
||||
async def _execute_llm_call(
|
||||
self,
|
||||
action: ActionSpec,
|
||||
inputs: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
) -> StepExecutionResult:
|
||||
"""Execute an LLM call action."""
|
||||
if self.llm is None:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error="No LLM provider configured",
|
||||
error_type="configuration",
|
||||
executor_type="llm_call",
|
||||
)
|
||||
|
||||
try:
|
||||
# Build prompt with context data
|
||||
prompt = action.prompt or ""
|
||||
|
||||
# First try format placeholders (for prompts like "Hello {name}")
|
||||
if inputs:
|
||||
try:
|
||||
prompt = prompt.format(**inputs)
|
||||
except (KeyError, ValueError):
|
||||
pass # Keep original prompt if formatting fails
|
||||
|
||||
# Always append context data so LLM can personalize
|
||||
# This ensures the LLM has access to lead info, company context, etc.
|
||||
if inputs:
|
||||
context_section = "\n\n--- Context Data ---\n"
|
||||
for key, value in inputs.items():
|
||||
if isinstance(value, dict | list):
|
||||
context_section += f"{key}: {json.dumps(value, indent=2)}\n"
|
||||
else:
|
||||
context_section += f"{key}: {value}\n"
|
||||
prompt = prompt + context_section
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
response = self.llm.complete(
|
||||
messages=messages,
|
||||
system=action.system_prompt,
|
||||
)
|
||||
|
||||
# Try to parse JSON from LLM response
|
||||
# LLMs often return JSON wrapped in markdown code blocks
|
||||
parsed_json, _ = parse_llm_json_response(response.content)
|
||||
|
||||
# If JSON was parsed successfully, use it as the result
|
||||
# Otherwise, use the raw text
|
||||
result_value = parsed_json if parsed_json is not None else response.content
|
||||
|
||||
return StepExecutionResult(
|
||||
success=True,
|
||||
outputs={
|
||||
"result": result_value,
|
||||
"response": response.content, # Always keep raw response
|
||||
"parsed_json": parsed_json, # Explicit parsed JSON (or None)
|
||||
},
|
||||
tokens_used=response.input_tokens + response.output_tokens,
|
||||
executor_type="llm_call",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_type = "rate_limit" if "rate" in str(e).lower() else "llm_error"
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
executor_type="llm_call",
|
||||
)
|
||||
|
||||
async def _execute_tool_use(
|
||||
self,
|
||||
action: ActionSpec,
|
||||
inputs: dict[str, Any],
|
||||
) -> StepExecutionResult:
|
||||
"""Execute a tool use action."""
|
||||
tool_name = action.tool_name
|
||||
if not tool_name:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error="No tool name specified",
|
||||
error_type="invalid_action",
|
||||
executor_type="tool_use",
|
||||
)
|
||||
|
||||
# Merge action args with inputs
|
||||
args = {**action.tool_args, **inputs}
|
||||
|
||||
# Resolve any $variable references in the merged args
|
||||
# (tool_args may contain $refs that should be resolved from inputs)
|
||||
resolved_args = {}
|
||||
for key, value in args.items():
|
||||
if isinstance(value, str) and value.startswith("$"):
|
||||
ref_key = value[1:] # Remove $
|
||||
resolved_args[key] = args.get(ref_key, value)
|
||||
else:
|
||||
resolved_args[key] = value
|
||||
args = resolved_args
|
||||
|
||||
# First, check if we have a registered function with this name
|
||||
# This allows simpler tool registration without full Tool/ToolExecutor setup
|
||||
if tool_name in self.functions:
|
||||
try:
|
||||
func = self.functions[tool_name]
|
||||
result = func(**args)
|
||||
|
||||
# Handle async functions
|
||||
if hasattr(result, "__await__"):
|
||||
result = await result
|
||||
|
||||
# If result is already a dict with success/outputs, use it directly
|
||||
if isinstance(result, dict) and "success" in result:
|
||||
return StepExecutionResult(
|
||||
success=result.get("success", False),
|
||||
outputs=result.get("outputs", {}),
|
||||
error=result.get("error"),
|
||||
error_type=result.get("error_type"),
|
||||
executor_type="tool_use",
|
||||
)
|
||||
|
||||
# Otherwise wrap the result
|
||||
return StepExecutionResult(
|
||||
success=True,
|
||||
outputs={"result": result},
|
||||
executor_type="tool_use",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_type="tool_exception",
|
||||
executor_type="tool_use",
|
||||
)
|
||||
|
||||
# Fall back to formal Tool registry
|
||||
if tool_name not in self.tools:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool '{tool_name}' not found",
|
||||
error_type="missing_tool",
|
||||
executor_type="tool_use",
|
||||
)
|
||||
|
||||
if self.tool_executor is None:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error="No tool executor configured",
|
||||
error_type="configuration",
|
||||
executor_type="tool_use",
|
||||
)
|
||||
|
||||
try:
|
||||
# Execute tool via formal executor
|
||||
from framework.llm.provider import ToolUse
|
||||
|
||||
tool_use = ToolUse(
|
||||
id=f"step_{tool_name}",
|
||||
name=tool_name,
|
||||
input=args,
|
||||
)
|
||||
|
||||
result = self.tool_executor(tool_use)
|
||||
|
||||
if result.is_error:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
outputs={},
|
||||
error=result.content,
|
||||
error_type="tool_error",
|
||||
executor_type="tool_use",
|
||||
)
|
||||
|
||||
# Parse JSON result and unpack fields into outputs
|
||||
# Tools return JSON like {"lead_email": "...", "company_name": "..."}
|
||||
# We want each field as a separate output key
|
||||
outputs = {"result": result.content}
|
||||
try:
|
||||
parsed = json.loads(result.content)
|
||||
if isinstance(parsed, dict):
|
||||
# Unpack all fields from the JSON response
|
||||
outputs.update(parsed)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass # Keep result as-is if not valid JSON
|
||||
|
||||
return StepExecutionResult(
|
||||
success=True,
|
||||
outputs=outputs,
|
||||
executor_type="tool_use",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_type="tool_exception",
|
||||
executor_type="tool_use",
|
||||
)
|
||||
|
||||
async def _execute_sub_graph(
|
||||
self,
|
||||
action: ActionSpec,
|
||||
inputs: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
) -> StepExecutionResult:
|
||||
"""Execute a sub-graph action."""
|
||||
if self.sub_graph_executor is None:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error="No sub-graph executor configured",
|
||||
error_type="configuration",
|
||||
executor_type="sub_graph",
|
||||
)
|
||||
|
||||
graph_id = action.graph_id
|
||||
if not graph_id:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error="No graph ID specified",
|
||||
error_type="invalid_action",
|
||||
executor_type="sub_graph",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self.sub_graph_executor(graph_id, inputs, context)
|
||||
|
||||
return StepExecutionResult(
|
||||
success=result.success,
|
||||
outputs=result.output if result.success else {},
|
||||
error=result.error if not result.success else None,
|
||||
tokens_used=result.total_tokens,
|
||||
executor_type="sub_graph",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_type="sub_graph_exception",
|
||||
executor_type="sub_graph",
|
||||
)
|
||||
|
||||
async def _execute_function(
|
||||
self,
|
||||
action: ActionSpec,
|
||||
inputs: dict[str, Any],
|
||||
) -> StepExecutionResult:
|
||||
"""Execute a function action."""
|
||||
func_name = action.function_name
|
||||
if not func_name:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error="No function name specified",
|
||||
error_type="invalid_action",
|
||||
executor_type="function",
|
||||
)
|
||||
|
||||
if func_name not in self.functions:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error=f"Function '{func_name}' not registered",
|
||||
error_type="missing_function",
|
||||
executor_type="function",
|
||||
)
|
||||
|
||||
try:
|
||||
func = self.functions[func_name]
|
||||
|
||||
# Merge action args with inputs
|
||||
args = {**action.function_args, **inputs}
|
||||
|
||||
# Execute function
|
||||
result = func(**args)
|
||||
|
||||
# Handle async functions
|
||||
if hasattr(result, "__await__"):
|
||||
result = await result
|
||||
|
||||
return StepExecutionResult(
|
||||
success=True,
|
||||
outputs={"result": result},
|
||||
executor_type="function",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_type="function_exception",
|
||||
executor_type="function",
|
||||
)
|
||||
|
||||
def _execute_code(
|
||||
self,
|
||||
action: ActionSpec,
|
||||
inputs: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
) -> StepExecutionResult:
|
||||
"""Execute a code action in sandbox."""
|
||||
code = action.code
|
||||
if not code:
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error="No code specified",
|
||||
error_type="invalid_action",
|
||||
executor_type="code_execution",
|
||||
)
|
||||
|
||||
# Merge inputs with context for code
|
||||
code_inputs = {**context, **inputs}
|
||||
|
||||
# Execute in sandbox
|
||||
sandbox_result = self.sandbox.execute(code, code_inputs)
|
||||
|
||||
if sandbox_result.success:
|
||||
return StepExecutionResult(
|
||||
success=True,
|
||||
outputs={
|
||||
"result": sandbox_result.result,
|
||||
**sandbox_result.variables,
|
||||
},
|
||||
executor_type="code_execution",
|
||||
latency_ms=sandbox_result.execution_time_ms,
|
||||
)
|
||||
else:
|
||||
error_type = "security" if "Security" in (sandbox_result.error or "") else "code_error"
|
||||
return StepExecutionResult(
|
||||
success=False,
|
||||
error=sandbox_result.error,
|
||||
error_type=error_type,
|
||||
executor_type="code_execution",
|
||||
latency_ms=sandbox_result.execution_time_ms,
|
||||
)
|
||||
|
||||
def register_function(self, name: str, func: Callable) -> None:
|
||||
"""Register a function for FUNCTION actions."""
|
||||
self.functions[name] = func
|
||||
|
||||
def register_tool(self, tool: Tool) -> None:
|
||||
"""Register a tool for TOOL_USE actions."""
|
||||
self.tools[tool.name] = tool
|
||||
@@ -70,6 +70,7 @@ class AnthropicProvider(LLMProvider):
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
        max_retries: int | None = None,
    ) -> LLMResponse:
        """Generate a completion from Claude (via LiteLLM)."""
        return self._provider.complete(
@@ -79,6 +80,7 @@ class AnthropicProvider(LLMProvider):
            max_tokens=max_tokens,
            response_format=response_format,
            json_mode=json_mode,
            max_retries=max_retries,
        )

    def complete_with_tools(
@@ -97,3 +99,41 @@ class AnthropicProvider(LLMProvider):
            tool_executor=tool_executor,
            max_iterations=max_iterations,
        )

    async def acomplete(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
        max_retries: int | None = None,
    ) -> LLMResponse:
        """Async completion via LiteLLM."""
        return await self._provider.acomplete(
            messages=messages,
            system=system,
            tools=tools,
            max_tokens=max_tokens,
            response_format=response_format,
            json_mode=json_mode,
            max_retries=max_retries,
        )

    async def acomplete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable[[ToolUse], ToolResult],
        max_iterations: int = 10,
    ) -> LLMResponse:
        """Async tool-use loop via LiteLLM."""
        return await self._provider.acomplete_with_tools(
            messages=messages,
            system=system,
            tools=tools,
            tool_executor=tool_executor,
            max_iterations=max_iterations,
        )

+383
-16
@@ -30,6 +30,7 @@ logger = logging.getLogger(__name__)

RATE_LIMIT_MAX_RETRIES = 10
RATE_LIMIT_BACKOFF_BASE = 2  # seconds
RATE_LIMIT_MAX_DELAY = 120  # seconds - cap to prevent absurd waits

# Directory for dumping failed requests
FAILED_REQUESTS_DIR = Path.home() / ".hive" / "failed_requests"
@@ -84,6 +85,91 @@ def _dump_failed_request(
    return str(filepath)


def _compute_retry_delay(
    attempt: int,
    exception: BaseException | None = None,
    backoff_base: int = RATE_LIMIT_BACKOFF_BASE,
    max_delay: int = RATE_LIMIT_MAX_DELAY,
) -> float:
    """Compute retry delay, preferring server-provided Retry-After headers.

    Priority:
    1. retry-after-ms header (milliseconds, float)
    2. retry-after header as seconds (float)
    3. retry-after header as HTTP-date (RFC 7231)
    4. Exponential backoff: backoff_base * 2^attempt

    All values are capped at max_delay seconds.
    """
    if exception is not None:
        response = getattr(exception, "response", None)
        if response is not None:
            headers = getattr(response, "headers", None)
            if headers is not None:
                # Priority 1: retry-after-ms (milliseconds)
                retry_after_ms = headers.get("retry-after-ms")
                if retry_after_ms is not None:
                    try:
                        delay = float(retry_after_ms) / 1000.0
                        return min(max(delay, 0), max_delay)
                    except (ValueError, TypeError):
                        pass

                # Priority 2: retry-after (seconds or HTTP-date)
                retry_after = headers.get("retry-after")
                if retry_after is not None:
                    # Try as seconds (float)
                    try:
                        delay = float(retry_after)
                        return min(max(delay, 0), max_delay)
                    except (ValueError, TypeError):
                        pass

                    # Try as HTTP-date (e.g., "Fri, 31 Dec 2025 23:59:59 GMT")
                    try:
                        from email.utils import parsedate_to_datetime

                        retry_date = parsedate_to_datetime(retry_after)
                        now = datetime.now(retry_date.tzinfo)
                        delay = (retry_date - now).total_seconds()
                        return min(max(delay, 0), max_delay)
                    except (ValueError, TypeError, OverflowError):
                        pass

    # Fallback: exponential backoff
    delay = backoff_base * (2**attempt)
    return min(delay, max_delay)

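# Illustrative only (not part of this commit): with no Retry-After header the helper
# falls back to capped exponential backoff, e.g.
#
#   _compute_retry_delay(0)    # -> 2    (2 * 2**0)
#   _compute_retry_delay(3)    # -> 16   (2 * 2**3)
#   _compute_retry_delay(10)   # -> 120  (2048 capped at RATE_LIMIT_MAX_DELAY)
#
# A response whose headers include "retry-after-ms: 1500" would yield 1.5 instead.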
def _is_stream_transient_error(exc: BaseException) -> bool:
    """Classify whether a streaming exception is transient (recoverable).

    Transient errors (recoverable=True): network issues, server errors, timeouts.
    Permanent errors (recoverable=False): auth, bad request, context window, etc.
    """
    try:
        from litellm.exceptions import (
            APIConnectionError,
            BadGatewayError,
            InternalServerError,
            ServiceUnavailableError,
        )

        transient_types: tuple[type[BaseException], ...] = (
            APIConnectionError,
            InternalServerError,
            BadGatewayError,
            ServiceUnavailableError,
            TimeoutError,
            ConnectionError,
            OSError,
        )
    except ImportError:
        transient_types = (TimeoutError, ConnectionError, OSError)

    return isinstance(exc, transient_types)

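# Illustrative only (not part of this commit): the two helpers above combine into the
# retry decision used by the streaming path later in this diff. Transient errors are
# retried after a computed delay; permanent ones are reported and not retried.
#
#   except Exception as exc:
#       if _is_stream_transient_error(exc) and attempt < RATE_LIMIT_MAX_RETRIES:
#           await asyncio.sleep(_compute_retry_delay(attempt, exception=exc))
#           continue
#       yield StreamErrorEvent(error=str(exc), recoverable=_is_stream_transient_error(exc))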
class LiteLLMProvider(LLMProvider):
    """
    LiteLLM-based LLM provider for multi-provider support.
@@ -150,10 +236,13 @@ class LiteLLMProvider(LLMProvider):
                "LiteLLM is not installed. Please install it with: uv pip install litellm"
            )

    def _completion_with_rate_limit_retry(self, **kwargs: Any) -> Any:
    def _completion_with_rate_limit_retry(
        self, max_retries: int | None = None, **kwargs: Any
    ) -> Any:
        """Call litellm.completion with retry on 429 rate limit errors and empty responses."""
        model = kwargs.get("model", self.model)
        for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
        retries = max_retries if max_retries is not None else RATE_LIMIT_MAX_RETRIES
        for attempt in range(retries + 1):
            try:
                response = litellm.completion(**kwargs)  # type: ignore[union-attr]

@@ -194,22 +283,22 @@ class LiteLLMProvider(LLMProvider):
                        f"Full request dumped to: {dump_path}"
                    )

                    if attempt == RATE_LIMIT_MAX_RETRIES:
                    if attempt == retries:
                        logger.error(
                            f"[retry] GAVE UP on {model} after {RATE_LIMIT_MAX_RETRIES + 1} "
                            f"[retry] GAVE UP on {model} after {retries + 1} "
                            f"attempts — empty response "
                            f"(finish_reason={finish_reason}, "
                            f"choices={len(response.choices) if response.choices else 0})"
                        )
                        return response
                    wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                    wait = _compute_retry_delay(attempt)
                    logger.warning(
                        f"[retry] {model} returned empty response "
                        f"(finish_reason={finish_reason}, "
                        f"choices={len(response.choices) if response.choices else 0}) — "
                        f"likely rate limited or quota exceeded. "
                        f"Retrying in {wait}s "
                        f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                        f"(attempt {attempt + 1}/{retries})"
                    )
                    time.sleep(wait)
                    continue
@@ -225,21 +314,21 @@ class LiteLLMProvider(LLMProvider):
                    error_type="rate_limit",
                    attempt=attempt,
                )
                if attempt == RATE_LIMIT_MAX_RETRIES:
                if attempt == retries:
                    logger.error(
                        f"[retry] GAVE UP on {model} after {RATE_LIMIT_MAX_RETRIES + 1} "
                        f"[retry] GAVE UP on {model} after {retries + 1} "
                        f"attempts — rate limit error: {e!s}. "
                        f"~{token_count} tokens ({token_method}). "
                        f"Full request dumped to: {dump_path}"
                    )
                    raise
                wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                wait = _compute_retry_delay(attempt, exception=e)
                logger.warning(
                    f"[retry] {model} rate limited (429): {e!s}. "
                    f"~{token_count} tokens ({token_method}). "
                    f"Full request dumped to: {dump_path}. "
                    f"Retrying in {wait}s "
                    f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                    f"(attempt {attempt + 1}/{retries})"
                )
                time.sleep(wait)
        # unreachable, but satisfies type checker
@@ -253,6 +342,7 @@ class LiteLLMProvider(LLMProvider):
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
        max_retries: int | None = None,
    ) -> LLMResponse:
        """Generate a completion using LiteLLM."""
        # Prepare messages with system prompt
@@ -293,12 +383,17 @@ class LiteLLMProvider(LLMProvider):
            kwargs["response_format"] = response_format

        # Make the call
        response = self._completion_with_rate_limit_retry(**kwargs)
        response = self._completion_with_rate_limit_retry(max_retries=max_retries, **kwargs)

        # Extract content
        content = response.choices[0].message.content or ""

        # Get usage info
        # Get usage info.
        # NOTE: completion_tokens includes reasoning/thinking tokens for models
        # that use them (o1, gpt-5-mini, etc.). LiteLLM does not reliably expose
        # usage.completion_tokens_details.reasoning_tokens across all providers.
        # This means output_tokens may be inflated for reasoning models.
        # Compaction is unaffected — it uses prompt_tokens (input-side only).
        usage = response.usage
        input_tokens = usage.prompt_tokens if usage else 0
        output_tokens = usage.completion_tokens if usage else 0
@@ -433,6 +528,267 @@ class LiteLLMProvider(LLMProvider):
            raw_response=None,
        )

    # ------------------------------------------------------------------
    # Async variants — non-blocking on the event loop
    # ------------------------------------------------------------------

    async def _acompletion_with_rate_limit_retry(
        self, max_retries: int | None = None, **kwargs: Any
    ) -> Any:
        """Async version of _completion_with_rate_limit_retry.

        Uses litellm.acompletion and asyncio.sleep instead of blocking calls.
        """
        model = kwargs.get("model", self.model)
        retries = max_retries if max_retries is not None else RATE_LIMIT_MAX_RETRIES
        for attempt in range(retries + 1):
            try:
                response = await litellm.acompletion(**kwargs)  # type: ignore[union-attr]

                content = response.choices[0].message.content if response.choices else None
                has_tool_calls = bool(response.choices and response.choices[0].message.tool_calls)
                if not content and not has_tool_calls:
                    messages = kwargs.get("messages", [])
                    last_role = next(
                        (m["role"] for m in reversed(messages) if m.get("role") != "system"),
                        None,
                    )
                    if last_role == "assistant":
                        logger.debug(
                            "[async-retry] Empty response after assistant message — "
                            "expected, not retrying."
                        )
                        return response

                    finish_reason = (
                        response.choices[0].finish_reason if response.choices else "unknown"
                    )
                    token_count, token_method = _estimate_tokens(model, messages)
                    dump_path = _dump_failed_request(
                        model=model,
                        kwargs=kwargs,
                        error_type="empty_response",
                        attempt=attempt,
                    )
                    logger.warning(
                        f"[async-retry] Empty response - {len(messages)} messages, "
                        f"~{token_count} tokens ({token_method}). "
                        f"Full request dumped to: {dump_path}"
                    )

                    if attempt == retries:
                        logger.error(
                            f"[async-retry] GAVE UP on {model} after {retries + 1} "
                            f"attempts — empty response "
                            f"(finish_reason={finish_reason}, "
                            f"choices={len(response.choices) if response.choices else 0})"
                        )
                        return response
                    wait = _compute_retry_delay(attempt)
                    logger.warning(
                        f"[async-retry] {model} returned empty response "
                        f"(finish_reason={finish_reason}, "
                        f"choices={len(response.choices) if response.choices else 0}) — "
                        f"likely rate limited or quota exceeded. "
                        f"Retrying in {wait}s "
                        f"(attempt {attempt + 1}/{retries})"
                    )
                    await asyncio.sleep(wait)
                    continue

                return response
            except RateLimitError as e:
                messages = kwargs.get("messages", [])
                token_count, token_method = _estimate_tokens(model, messages)
                dump_path = _dump_failed_request(
                    model=model,
                    kwargs=kwargs,
                    error_type="rate_limit",
                    attempt=attempt,
                )
                if attempt == retries:
                    logger.error(
                        f"[async-retry] GAVE UP on {model} after {retries + 1} "
                        f"attempts — rate limit error: {e!s}. "
                        f"~{token_count} tokens ({token_method}). "
                        f"Full request dumped to: {dump_path}"
                    )
                    raise
                wait = _compute_retry_delay(attempt, exception=e)
                logger.warning(
                    f"[async-retry] {model} rate limited (429): {e!s}. "
                    f"~{token_count} tokens ({token_method}). "
                    f"Full request dumped to: {dump_path}. "
                    f"Retrying in {wait}s "
                    f"(attempt {attempt + 1}/{retries})"
                )
                await asyncio.sleep(wait)
        raise RuntimeError("Exhausted rate limit retries")

async def acomplete(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Async version of complete(). Uses litellm.acompletion — non-blocking."""
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if system:
|
||||
full_messages.append({"role": "system", "content": system})
|
||||
full_messages.extend(messages)
|
||||
|
||||
if json_mode:
|
||||
json_instruction = "\n\nPlease respond with a valid JSON object."
|
||||
if full_messages and full_messages[0]["role"] == "system":
|
||||
full_messages[0]["content"] += json_instruction
|
||||
else:
|
||||
full_messages.insert(0, {"role": "system", "content": json_instruction.strip()})
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": full_messages,
|
||||
"max_tokens": max_tokens,
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if tools:
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
response = await self._acompletion_with_rate_limit_retry(max_retries=max_retries, **kwargs)
|
||||
|
||||
content = response.choices[0].message.content or ""
|
||||
usage = response.usage
|
||||
input_tokens = usage.prompt_tokens if usage else 0
|
||||
output_tokens = usage.completion_tokens if usage else 0
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=response.model or self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
stop_reason=response.choices[0].finish_reason or "",
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
async def acomplete_with_tools(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
tool_executor: Callable[[ToolUse], ToolResult],
|
||||
max_iterations: int = 10,
|
||||
max_tokens: int = 4096,
|
||||
) -> LLMResponse:
|
||||
"""Async version of complete_with_tools(). Uses litellm.acompletion — non-blocking."""
|
||||
current_messages: list[dict[str, Any]] = []
|
||||
if system:
|
||||
current_messages.append({"role": "system", "content": system})
|
||||
current_messages.extend(messages)
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
openai_tools = [self._tool_to_openai_format(t) for t in tools]
|
||||
|
||||
for _ in range(max_iterations):
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": current_messages,
|
||||
"max_tokens": max_tokens,
|
||||
"tools": openai_tools,
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
response = await self._acompletion_with_rate_limit_retry(**kwargs)
|
||||
|
||||
usage = response.usage
|
||||
if usage:
|
||||
total_input_tokens += usage.prompt_tokens
|
||||
total_output_tokens += usage.completion_tokens
|
||||
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
if choice.finish_reason == "stop" or not message.tool_calls:
|
||||
return LLMResponse(
|
||||
content=message.content or "",
|
||||
model=response.model or self.model,
|
||||
input_tokens=total_input_tokens,
|
||||
output_tokens=total_output_tokens,
|
||||
stop_reason=choice.finish_reason or "stop",
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
current_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in message.tool_calls
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
current_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "Invalid JSON arguments provided to tool.",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
tool_use = ToolUse(
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
input=args,
|
||||
)
|
||||
|
||||
result = tool_executor(tool_use)
|
||||
|
||||
current_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": result.tool_use_id,
|
||||
"content": result.content,
|
||||
}
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
content="Max tool iterations reached",
|
||||
model=self.model,
|
||||
input_tokens=total_input_tokens,
|
||||
output_tokens=total_output_tokens,
|
||||
stop_reason="max_iterations",
|
||||
raw_response=None,
|
||||
)
|
||||
|
||||
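    # Illustrative usage sketch (not part of this commit); the provider constructor
    # arguments, the calc_tool definition, and the executor below are hypothetical:
    #
    #   async def run() -> None:
    #       provider = LiteLLMProvider(...)  # constructor args omitted
    #
    #       def executor(tu: ToolUse) -> ToolResult:
    #           return ToolResult(tool_use_id=tu.id, content="42")
    #
    #       resp = await provider.acomplete_with_tools(
    #           messages=[{"role": "user", "content": "What is 6 * 7? Use the calc tool."}],
    #           system="You are a calculator agent.",
    #           tools=[calc_tool],
    #           tool_executor=executor,
    #       )
    #       print(resp.content, resp.stop_reason)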
def _tool_to_openai_format(self, tool: Tool) -> dict[str, Any]:
|
||||
"""Convert Tool to OpenAI function calling format."""
|
||||
return {
|
||||
@@ -591,7 +947,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
for event in tail_events:
|
||||
yield event
|
||||
return
|
||||
wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
|
||||
wait = _compute_retry_delay(attempt)
|
||||
token_count, token_method = _estimate_tokens(
|
||||
self.model,
|
||||
full_messages,
|
||||
@@ -619,10 +975,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
|
||||
except RateLimitError as e:
|
||||
if attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
|
||||
wait = _compute_retry_delay(attempt, exception=e)
|
||||
logger.warning(
|
||||
f"[stream-retry] {self.model} rate limited (429): {e!s}. "
|
||||
f"Retrying in {wait}s "
|
||||
f"Retrying in {wait:.1f}s "
|
||||
f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
@@ -631,5 +987,16 @@ class LiteLLMProvider(LLMProvider):
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
if _is_stream_transient_error(e) and attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
wait = _compute_retry_delay(attempt, exception=e)
|
||||
logger.warning(
|
||||
f"[stream-retry] {self.model} transient error "
|
||||
f"({type(e).__name__}): {e!s}. "
|
||||
f"Retrying in {wait:.1f}s "
|
||||
f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
recoverable = _is_stream_transient_error(e)
|
||||
yield StreamErrorEvent(error=str(e), recoverable=recoverable)
|
||||
return
|
||||
|
||||
@@ -120,6 +120,7 @@ class MockLLMProvider(LLMProvider):
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Generate a mock completion without calling a real LLM.
|
||||
@@ -182,6 +183,44 @@ class MockLLMProvider(LLMProvider):
|
||||
stop_reason="mock_complete",
|
||||
)
|
||||
|
||||
async def acomplete(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Async mock completion (no I/O, returns immediately)."""
|
||||
return self.complete(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
json_mode=json_mode,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
async def acomplete_with_tools(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
tool_executor: Callable[[ToolUse], ToolResult],
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
"""Async mock tool-use completion (no I/O, returns immediately)."""
|
||||
return self.complete_with_tools(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""LLM Provider abstraction for pluggable LLM backends."""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -65,6 +67,7 @@ class LLMProvider(ABC):
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Generate a completion from the LLM.
|
||||
@@ -79,6 +82,8 @@ class LLMProvider(ABC):
|
||||
- {"type": "json_schema", "json_schema": {"name": "...", "schema": {...}}}
|
||||
for strict JSON schema enforcement
|
||||
json_mode: If True, request structured JSON output from the LLM
|
||||
max_retries: Override retry count for rate-limit/empty-response retries.
|
||||
None uses the provider default.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and metadata
|
||||
@@ -109,6 +114,62 @@ class LLMProvider(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def acomplete(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list["Tool"] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
) -> "LLMResponse":
|
||||
"""Async version of complete(). Non-blocking on the event loop.
|
||||
|
||||
Default implementation offloads the sync complete() to a thread pool.
|
||||
Subclasses SHOULD override for native async I/O.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
partial(
|
||||
self.complete,
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
json_mode=json_mode,
|
||||
max_retries=max_retries,
|
||||
),
|
||||
)
|
||||
|
||||
async def acomplete_with_tools(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list["Tool"],
|
||||
tool_executor: Callable[["ToolUse"], "ToolResult"],
|
||||
max_iterations: int = 10,
|
||||
) -> "LLMResponse":
|
||||
"""Async version of complete_with_tools(). Non-blocking on the event loop.
|
||||
|
||||
Default implementation offloads the sync complete_with_tools() to a thread pool.
|
||||
Subclasses SHOULD override for native async I/O.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
partial(
|
||||
self.complete_with_tools,
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
max_iterations=max_iterations,
|
||||
),
|
||||
)
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -132,7 +193,7 @@ class LLMProvider(ABC):
|
||||
TextEndEvent,
|
||||
)
|
||||
|
||||
response = self.complete(
|
||||
response = await self.acomplete(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
|
||||
@@ -36,12 +36,7 @@ from framework.graph import ( # noqa: E402
|
||||
NodeSpec,
|
||||
SuccessCriterion,
|
||||
)
|
||||
from framework.graph.plan import Plan # noqa: E402
|
||||
|
||||
# Testing framework imports
|
||||
from framework.testing.prompts import ( # noqa: E402
|
||||
PYTEST_TEST_FILE_HEADER,
|
||||
)
|
||||
from framework.testing.prompts import PYTEST_TEST_FILE_HEADER # noqa: E402
|
||||
from framework.utils.io import atomic_write # noqa: E402
|
||||
|
||||
# Initialize MCP server
|
||||
@@ -587,13 +582,12 @@ def add_node(
|
||||
description: Annotated[str, "What this node does"],
|
||||
node_type: Annotated[
|
||||
str,
|
||||
"Type: event_loop (recommended), function, router. "
|
||||
"Deprecated: llm_generate, llm_tool_use (use event_loop instead)",
|
||||
"Type: event_loop (recommended), router.",
|
||||
],
|
||||
input_keys: Annotated[str, "JSON array of keys this node reads from shared memory"],
|
||||
output_keys: Annotated[str, "JSON array of keys this node writes to shared memory"],
|
||||
system_prompt: Annotated[str, "Instructions for LLM nodes"] = "",
|
||||
tools: Annotated[str, "JSON array of tool names for event_loop or llm_tool_use nodes"] = "[]",
|
||||
tools: Annotated[str, "JSON array of tool names for event_loop nodes"] = "[]",
|
||||
routes: Annotated[
|
||||
str, "JSON object mapping conditions to target node IDs for router nodes"
|
||||
] = "{}",
|
||||
@@ -665,24 +659,18 @@ def add_node(
|
||||
errors.append("Node must have an id")
|
||||
if not name:
|
||||
errors.append("Node must have a name")
|
||||
if node_type == "llm_tool_use" and not tools_list:
|
||||
errors.append(f"Node '{node_id}' of type llm_tool_use must specify tools")
|
||||
|
||||
# Reject removed node types
|
||||
if node_type in ("function", "llm_tool_use", "llm_generate"):
|
||||
errors.append(f"Node type '{node_type}' is no longer supported. Use 'event_loop' instead.")
|
||||
|
||||
if node_type == "router" and not routes_dict:
|
||||
errors.append(f"Router node '{node_id}' must specify routes")
|
||||
if node_type in ("llm_generate", "llm_tool_use") and not system_prompt:
|
||||
warnings.append(f"LLM node '{node_id}' should have a system_prompt")
|
||||
|
||||
# EventLoopNode validation
|
||||
if node_type == "event_loop" and not system_prompt:
|
||||
warnings.append(f"Event loop node '{node_id}' should have a system_prompt")
|
||||
|
||||
# Deprecated type warnings
|
||||
if node_type in ("llm_generate", "llm_tool_use"):
|
||||
warnings.append(
|
||||
f"Node type '{node_type}' is deprecated. Use 'event_loop' instead. "
|
||||
"EventLoopNode supports tool use, streaming, and judge-based evaluation."
|
||||
)
|
||||
|
||||
# Warn about client_facing on nodes with tools (likely autonomous work)
|
||||
if node_type == "event_loop" and client_facing and tools_list:
|
||||
warnings.append(
|
||||
@@ -838,8 +826,7 @@ def update_node(
|
||||
description: Annotated[str, "Updated description"] = "",
|
||||
node_type: Annotated[
|
||||
str,
|
||||
"Updated type: event_loop (recommended), function, router. "
|
||||
"Deprecated: llm_generate, llm_tool_use",
|
||||
"Updated type: event_loop (recommended), router.",
|
||||
] = "",
|
||||
input_keys: Annotated[str, "Updated JSON array of input keys"] = "",
|
||||
output_keys: Annotated[str, "Updated JSON array of output keys"] = "",
|
||||
@@ -919,24 +906,19 @@ def update_node(
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
if node.node_type == "llm_tool_use" and not node.tools:
|
||||
errors.append(f"Node '{node_id}' of type llm_tool_use must specify tools")
|
||||
# Reject removed node types
|
||||
if node.node_type in ("function", "llm_tool_use", "llm_generate"):
|
||||
errors.append(
|
||||
f"Node type '{node.node_type}' is no longer supported. Use 'event_loop' instead."
|
||||
)
|
||||
|
||||
if node.node_type == "router" and not node.routes:
|
||||
errors.append(f"Router node '{node_id}' must specify routes")
|
||||
if node.node_type in ("llm_generate", "llm_tool_use") and not node.system_prompt:
|
||||
warnings.append(f"LLM node '{node_id}' should have a system_prompt")
|
||||
|
||||
# EventLoopNode validation
|
||||
if node.node_type == "event_loop" and not node.system_prompt:
|
||||
warnings.append(f"Event loop node '{node_id}' should have a system_prompt")
|
||||
|
||||
# Deprecated type warnings
|
||||
if node.node_type in ("llm_generate", "llm_tool_use"):
|
||||
warnings.append(
|
||||
f"Node type '{node.node_type}' is deprecated. Use 'event_loop' instead. "
|
||||
"EventLoopNode supports tool use, streaming, and judge-based evaluation."
|
||||
)
|
||||
|
||||
# nullable_output_keys must be a subset of output_keys
|
||||
if node.nullable_output_keys:
|
||||
invalid_nullable = [k for k in node.nullable_output_keys if k not in node.output_keys]
|
||||
@@ -1112,11 +1094,11 @@ def validate_graph() -> str:
|
||||
if entry_candidates:
|
||||
reachable = set()
|
||||
|
||||
# For pause/resume agents, start from ALL entry points (including resume)
|
||||
if is_pause_resume_agent:
|
||||
to_visit = list(entry_candidates) # All nodes without incoming edges
|
||||
else:
|
||||
to_visit = [entry_candidates[0]] # Just the primary entry
|
||||
# Start from ALL entry candidates (nodes without incoming edges).
|
||||
# This handles both pause/resume agents and async entry point agents
|
||||
# where multiple nodes have no incoming edges (e.g., a primary entry
|
||||
# node and an event-driven entry node).
|
||||
to_visit = list(entry_candidates)
|
||||
|
||||
while to_visit:
|
||||
current = to_visit.pop()
|
||||
@@ -1390,17 +1372,6 @@ def validate_graph() -> str:
|
||||
f"must be a subset of output_keys {node.output_keys}"
|
||||
)
|
||||
|
||||
# Deprecated node type warnings
|
||||
deprecated_nodes = [
|
||||
{"node_id": n.id, "type": n.node_type, "replacement": "event_loop"}
|
||||
for n in session.nodes
|
||||
if n.node_type in ("llm_generate", "llm_tool_use")
|
||||
]
|
||||
for dn in deprecated_nodes:
|
||||
warnings.append(
|
||||
f"Node '{dn['node_id']}' uses deprecated type '{dn['type']}'. Use 'event_loop' instead."
|
||||
)
|
||||
|
||||
# Warn if all event_loop nodes are client_facing (common misconfiguration)
|
||||
el_nodes = [n for n in session.nodes if n.node_type == "event_loop"]
|
||||
cf_el_nodes = [n for n in el_nodes if n.client_facing]
|
||||
@@ -1436,7 +1407,6 @@ def validate_graph() -> str:
|
||||
"event_loop_nodes": event_loop_nodes,
|
||||
"client_facing_nodes": client_facing_nodes,
|
||||
"feedback_edges": feedback_edges,
|
||||
"deprecated_node_types": deprecated_nodes,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1646,9 +1616,8 @@ def export_graph() -> str:
|
||||
"""
|
||||
Export the validated graph as a GraphSpec for GraphExecutor.
|
||||
|
||||
Exports the complete agent definition including nodes, edges, goal,
|
||||
and evaluation rules. The GraphExecutor runs the graph with dynamic
|
||||
edge traversal and routing logic.
|
||||
Exports the complete agent definition including nodes, edges, and goal.
|
||||
The GraphExecutor runs the graph with dynamic edge traversal and routing logic.
|
||||
|
||||
AUTOMATICALLY WRITES FILES TO DISK:
|
||||
- exports/{agent-name}/agent.json - Full agent specification
|
||||
@@ -1856,7 +1825,6 @@ def export_graph() -> str:
|
||||
"files_written": files_written,
|
||||
"graph": graph_spec,
|
||||
"goal": session.goal.model_dump(),
|
||||
"evaluation_rules": _evaluation_rules,
|
||||
"required_tools": list(all_tools),
|
||||
"node_count": len(session.nodes),
|
||||
"edge_count": len(edges_list),
|
||||
@@ -1966,9 +1934,6 @@ def get_session_status() -> str:
|
||||
"mcp_servers": [s["name"] for s in session.mcp_servers],
|
||||
"event_loop_nodes": [n.id for n in session.nodes if n.node_type == "event_loop"],
|
||||
"client_facing_nodes": [n.id for n in session.nodes if n.client_facing],
|
||||
"deprecated_nodes": [
|
||||
n.id for n in session.nodes if n.node_type in ("llm_generate", "llm_tool_use")
|
||||
],
|
||||
"feedback_edges": [e.id for e in session.edges if e.priority < 0],
|
||||
}
|
||||
)
|
||||
@@ -2139,7 +2104,7 @@ def add_mcp_server(
|
||||
"total_mcp_servers": len(session.mcp_servers),
|
||||
"note": (
|
||||
f"MCP server '{name}' registered with {len(tool_names)} tools. "
|
||||
"These tools can now be used in llm_tool_use nodes."
|
||||
"These tools can now be used in event_loop nodes."
|
||||
),
|
||||
},
|
||||
indent=2,
|
||||
@@ -2240,7 +2205,7 @@ def list_mcp_tools(
|
||||
"success": True,
|
||||
"tools_by_server": all_tools,
|
||||
"total_tools": total_tools,
|
||||
"note": "Use these tool names in the 'tools' parameter when adding llm_tool_use nodes",
|
||||
"note": "Use these tool names in the 'tools' parameter when adding event_loop nodes",
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
@@ -2339,23 +2304,6 @@ def test_node(
|
||||
+ f"Max visits per graph run: {node_spec.max_node_visits}."
|
||||
)
|
||||
|
||||
elif node_spec.node_type in ("llm_generate", "llm_tool_use"):
|
||||
# Legacy LLM node types
|
||||
result["system_prompt"] = node_spec.system_prompt
|
||||
result["available_tools"] = node_spec.tools
|
||||
result["deprecation_warning"] = (
|
||||
f"Node type '{node_spec.node_type}' is deprecated. Use 'event_loop' instead."
|
||||
)
|
||||
|
||||
if mock_llm_response:
|
||||
result["mock_response"] = mock_llm_response
|
||||
result["simulation"] = "LLM would receive prompt and produce response"
|
||||
else:
|
||||
result["simulation"] = "LLM would be called with the system prompt and input data"
|
||||
|
||||
elif node_spec.node_type == "function":
|
||||
result["simulation"] = "Function node would execute deterministic logic"
|
||||
|
||||
# Show memory state after (simulated)
|
||||
result["expected_memory_state"] = {
|
||||
"inputs_available": {k: input_data.get(k, "<not provided>") for k in node_spec.input_keys},
|
||||
@@ -2449,7 +2397,7 @@ def test_graph(
|
||||
"writes": current_node.output_keys,
|
||||
}
|
||||
|
||||
if current_node.node_type in ("llm_generate", "llm_tool_use", "event_loop"):
|
||||
if current_node.node_type == "event_loop":
|
||||
step_info["prompt_preview"] = (
|
||||
current_node.system_prompt[:200] + "..."
|
||||
if current_node.system_prompt and len(current_node.system_prompt) > 200
|
||||
@@ -2520,466 +2468,6 @@ def test_graph(
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FLEXIBLE EXECUTION TOOLS (Worker-Judge Pattern)
|
||||
# =============================================================================
|
||||
|
||||
# Storage for evaluation rules
|
||||
_evaluation_rules: list[dict] = []
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def add_evaluation_rule(
|
||||
rule_id: Annotated[str, "Unique identifier for the rule"],
|
||||
description: Annotated[str, "Human-readable description of what this rule checks"],
|
||||
condition: Annotated[
|
||||
str,
|
||||
"Python expression with result, step, goal context. E.g., 'result.get(\"success\")'",
|
||||
],
|
||||
action: Annotated[str, "Action when rule matches: accept, retry, replan, escalate"],
|
||||
feedback_template: Annotated[
|
||||
str, "Template for feedback message, can use {result}, {step}"
|
||||
] = "",
|
||||
priority: Annotated[int, "Rule priority (higher = checked first)"] = 0,
|
||||
) -> str:
|
||||
"""
|
||||
Add an evaluation rule for the HybridJudge.
|
||||
|
||||
Rules are checked in priority order before falling back to LLM evaluation.
|
||||
Use this to define deterministic success/failure conditions.
|
||||
|
||||
Example conditions:
|
||||
- 'result.get("success") == True' - Check for explicit success flag
|
||||
- 'result.get("error_type") == "timeout"' - Check for specific error type
|
||||
- 'len(result.get("data", [])) > 0' - Check for non-empty data
|
||||
"""
|
||||
global _evaluation_rules
|
||||
|
||||
# Validate action
|
||||
valid_actions = ["accept", "retry", "replan", "escalate"]
|
||||
if action.lower() not in valid_actions:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Invalid action '{action}'. Must be one of: {valid_actions}",
|
||||
}
|
||||
)
|
||||
|
||||
# Check for duplicate
|
||||
if any(r["id"] == rule_id for r in _evaluation_rules):
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Rule '{rule_id}' already exists",
|
||||
}
|
||||
)
|
||||
|
||||
rule = {
|
||||
"id": rule_id,
|
||||
"description": description,
|
||||
"condition": condition,
|
||||
"action": action.lower(),
|
||||
"feedback_template": feedback_template,
|
||||
"priority": priority,
|
||||
}
|
||||
|
||||
_evaluation_rules.append(rule)
|
||||
_evaluation_rules.sort(key=lambda r: -r["priority"])
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"rule": rule,
|
||||
"total_rules": len(_evaluation_rules),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
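# Illustrative only (not part of this commit): a typical rule registration, using
# hypothetical values, might look like
#
#   add_evaluation_rule(
#       rule_id="accept_on_success",
#       description="Accept any step that reports an explicit success flag",
#       condition='result.get("success") == True',
#       action="accept",
#       feedback_template="Step {step} succeeded",
#       priority=10,
#   )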
@mcp.tool()
|
||||
def list_evaluation_rules() -> str:
|
||||
"""List all configured evaluation rules for the HybridJudge."""
|
||||
return json.dumps(
|
||||
{
|
||||
"rules": _evaluation_rules,
|
||||
"total": len(_evaluation_rules),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def remove_evaluation_rule(
|
||||
rule_id: Annotated[str, "ID of the rule to remove"],
|
||||
) -> str:
|
||||
"""Remove an evaluation rule."""
|
||||
global _evaluation_rules
|
||||
|
||||
for i, rule in enumerate(_evaluation_rules):
|
||||
if rule["id"] == rule_id:
|
||||
_evaluation_rules.pop(i)
|
||||
return json.dumps({"success": True, "removed": rule_id})
|
||||
|
||||
return json.dumps({"success": False, "error": f"Rule '{rule_id}' not found"})
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def create_plan(
|
||||
plan_id: Annotated[str, "Unique identifier for the plan"],
|
||||
goal_id: Annotated[str, "ID of the goal this plan achieves"],
|
||||
description: Annotated[str, "Description of what this plan does"],
|
||||
steps: Annotated[
|
||||
str,
|
||||
"JSON array of plan steps with id, description, action, inputs, outputs, deps",
|
||||
],
|
||||
context: Annotated[str, "JSON object with initial context for execution"] = "{}",
|
||||
) -> str:
|
||||
"""
|
||||
Create a plan for flexible execution.
|
||||
|
||||
Plans are executed by the Worker-Judge loop. Each step specifies:
|
||||
- id: Unique step identifier
|
||||
- description: What this step does
|
||||
- action: Object with action_type and parameters
|
||||
- action_type: "llm_call", "tool_use", "function", "code_execution", "sub_graph"
|
||||
- For llm_call: prompt, system_prompt
|
||||
- For tool_use: tool_name, tool_args
|
||||
- For function: function_name, function_args
|
||||
- For code_execution: code
|
||||
- inputs: Dict mapping input names to values or "$variable" references
|
||||
- expected_outputs: List of output keys this step should produce
|
||||
- dependencies: List of step IDs that must complete first (deps)
|
||||
|
||||
Example step:
|
||||
{
|
||||
"id": "step_1",
|
||||
"description": "Fetch user data",
|
||||
"action": {"action_type": "tool_use", "tool_name": "get_user", ...},
|
||||
"inputs": {"user_id": "$input_user_id"},
|
||||
"expected_outputs": ["user_data"],
|
||||
"dependencies": []
|
||||
}
|
||||
"""
|
||||
try:
|
||||
steps_list = json.loads(steps)
|
||||
context_dict = json.loads(context)
|
||||
except json.JSONDecodeError as e:
|
||||
return json.dumps({"success": False, "error": f"Invalid JSON: {e}"})
|
||||
|
||||
# Validate steps
|
||||
errors = []
|
||||
step_ids = set()
|
||||
|
||||
for i, step in enumerate(steps_list):
|
||||
if "id" not in step:
|
||||
errors.append(f"Step {i} missing 'id'")
|
||||
else:
|
||||
if step["id"] in step_ids:
|
||||
errors.append(f"Duplicate step id: {step['id']}")
|
||||
step_ids.add(step["id"])
|
||||
|
||||
if "description" not in step:
|
||||
errors.append(f"Step {i} missing 'description'")
|
||||
|
||||
if "action" not in step:
|
||||
errors.append(f"Step {i} missing 'action'")
|
||||
elif "action_type" not in step.get("action", {}):
|
||||
errors.append(f"Step {i} action missing 'action_type'")
|
||||
|
||||
# Check dependencies exist
|
||||
for dep in step.get("dependencies", []):
|
||||
if dep not in step_ids:
|
||||
errors.append(f"Step {step.get('id', i)} has unknown dependency: {dep}")
|
||||
|
||||
if errors:
|
||||
return json.dumps({"success": False, "errors": errors})
|
||||
|
||||
# Build plan object
|
||||
plan = {
|
||||
"id": plan_id,
|
||||
"goal_id": goal_id,
|
||||
"description": description,
|
||||
"steps": steps_list,
|
||||
"context": context_dict,
|
||||
"revision": 1,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"plan": plan,
|
||||
"step_count": len(steps_list),
|
||||
"note": "Plan created. Use execute_plan to run it with the Worker-Judge loop.",
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def validate_plan(
|
||||
plan_json: Annotated[str, "JSON string of the plan to validate"],
|
||||
) -> str:
|
||||
"""
|
||||
Validate a plan structure before execution.
|
||||
|
||||
Checks:
|
||||
- All required fields present
|
||||
- No circular dependencies
|
||||
- All dependencies reference existing steps
|
||||
- Action types are valid
|
||||
- Context flow: all $variable references can be resolved
|
||||
"""
|
||||
try:
|
||||
plan = json.loads(plan_json)
|
||||
except json.JSONDecodeError as e:
|
||||
return json.dumps({"valid": False, "errors": [f"Invalid JSON: {e}"]})
|
||||
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
# Check required fields
|
||||
required = ["id", "goal_id", "steps"]
|
||||
for field in required:
|
||||
if field not in plan:
|
||||
errors.append(f"Missing required field: {field}")
|
||||
|
||||
if "steps" not in plan:
|
||||
return json.dumps({"valid": False, "errors": errors})
|
||||
|
||||
steps = plan["steps"]
|
||||
step_ids = {s.get("id") for s in steps if "id" in s}
|
||||
steps_by_id = {s.get("id"): s for s in steps}
|
||||
|
||||
# Check each step
|
||||
valid_action_types = ["llm_call", "tool_use", "function", "code_execution", "sub_graph"]
|
||||
|
||||
for i, step in enumerate(steps):
|
||||
step_id = step.get("id", f"step_{i}")
|
||||
|
||||
# Check dependencies
|
||||
for dep in step.get("dependencies", []):
|
||||
if dep not in step_ids:
|
||||
errors.append(f"Step '{step_id}': unknown dependency '{dep}'")
|
||||
|
||||
# Check action type
|
||||
action = step.get("action", {})
|
||||
action_type = action.get("action_type")
|
||||
if action_type and action_type not in valid_action_types:
|
||||
errors.append(f"Step '{step_id}': invalid action_type '{action_type}'")
|
||||
|
||||
# Check action has required params
|
||||
if action_type == "llm_call" and not action.get("prompt"):
|
||||
warnings.append(f"Step '{step_id}': llm_call without prompt")
|
||||
if action_type == "tool_use" and not action.get("tool_name"):
|
||||
errors.append(f"Step '{step_id}': tool_use requires tool_name")
|
||||
if action_type == "code_execution" and not action.get("code"):
|
||||
errors.append(f"Step '{step_id}': code_execution requires code")
|
||||
|
||||
# Check for circular dependencies
|
||||
def has_cycle(step_id: str, visited: set, path: set) -> bool:
|
||||
if step_id in path:
|
||||
return True
|
||||
if step_id in visited:
|
||||
return False
|
||||
|
||||
visited.add(step_id)
|
||||
path.add(step_id)
|
||||
|
||||
step = next((s for s in steps if s.get("id") == step_id), None)
|
||||
if step:
|
||||
for dep in step.get("dependencies", []):
|
||||
if has_cycle(dep, visited, path):
|
||||
return True
|
||||
|
||||
path.remove(step_id)
|
||||
return False
|
||||
|
||||
for step in steps:
|
||||
if has_cycle(step.get("id", ""), set(), set()):
|
||||
errors.append(f"Circular dependency detected involving step '{step.get('id')}'")
|
||||
break
|
||||
|
||||
# === CONTEXT FLOW VALIDATION ===
|
||||
# Compute what keys each step can access (from dependencies' outputs)
|
||||
|
||||
# Build output map (step_id -> expected_outputs)
|
||||
step_outputs: dict[str, set[str]] = {}
|
||||
for step in steps:
|
||||
step_outputs[step.get("id", "")] = set(step.get("expected_outputs", []))
|
||||
|
||||
# Compute available context for each step in topological order
|
||||
available_context: dict[str, set[str]] = {}
|
||||
computed = set()
|
||||
remaining = set(step_ids)
|
||||
|
||||
# Get initial context keys from plan.context
|
||||
initial_context = set(plan.get("context", {}).keys())
|
||||
|
||||
for _ in range(len(steps) * 2):
|
||||
if not remaining:
|
||||
break
|
||||
|
||||
for step_id in list(remaining):
|
||||
step = steps_by_id.get(step_id)
|
||||
if not step:
|
||||
remaining.discard(step_id)
|
||||
continue
|
||||
|
||||
deps = step.get("dependencies", [])
|
||||
|
||||
# Can compute if all dependencies are computed
|
||||
if all(d in computed for d in deps):
|
||||
# Collect outputs from all dependencies (transitive)
|
||||
available = set(initial_context)
|
||||
for dep_id in deps:
|
||||
available.update(step_outputs.get(dep_id, set()))
|
||||
available.update(available_context.get(dep_id, set()))
|
||||
|
||||
available_context[step_id] = available
|
||||
computed.add(step_id)
|
||||
remaining.discard(step_id)
|
||||
break
|
||||
|
||||
# Check each step's inputs can be resolved
|
||||
context_errors = []
|
||||
context_warnings = []
|
||||
|
||||
for step in steps:
|
||||
step_id = step.get("id", "")
|
||||
available = available_context.get(step_id, set())
|
||||
deps = step.get("dependencies", [])
|
||||
inputs = step.get("inputs", {})
|
||||
|
||||
missing_vars = []
|
||||
for _, input_value in inputs.items():
|
||||
# Check $variable references
|
||||
if isinstance(input_value, str) and input_value.startswith("$"):
|
||||
var_name = input_value[1:] # Remove $ prefix
|
||||
if var_name not in available:
|
||||
missing_vars.append(var_name)
|
||||
|
||||
if missing_vars:
|
||||
if not deps:
|
||||
# Entry step - inputs must come from initial context
|
||||
context_warnings.append(
|
||||
f"Step '{step_id}' requires ${missing_vars} from initial context. "
|
||||
f"Ensure these are provided when running the agent: {missing_vars}"
|
||||
)
|
||||
else:
|
||||
# Find which step could provide each missing var
|
||||
suggestions = []
|
||||
for var in missing_vars:
|
||||
producers = [s.get("id") for s in steps if var in s.get("expected_outputs", [])]
|
||||
if producers:
|
||||
suggestions.append(f"${var} is produced by {producers} - add as dependency")
|
||||
else:
|
||||
suggestions.append(
|
||||
f"${var} is not produced by any step - add a step that outputs '{var}'"
|
||||
)
|
||||
|
||||
context_errors.append(
|
||||
f"Step '{step_id}' references ${missing_vars} but deps "
|
||||
f"{deps} don't provide them. Suggestions: {'; '.join(suggestions)}"
|
||||
)
|
||||
|
||||
errors.extend(context_errors)
|
||||
warnings.extend(context_warnings)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"valid": len(errors) == 0,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
"step_count": len(steps),
|
||||
"context_flow": {step_id: list(keys) for step_id, keys in available_context.items()}
|
||||
if available_context
|
||||
else None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def simulate_plan_execution(
|
||||
plan_json: Annotated[str, "JSON string of the plan to simulate"],
|
||||
max_steps: Annotated[int, "Maximum steps to simulate"] = 20,
|
||||
) -> str:
|
||||
"""
|
||||
Simulate plan execution without actually running it.
|
||||
|
||||
Shows the order steps would execute based on dependencies.
|
||||
Useful for understanding the execution flow before running.
|
||||
"""
|
||||
try:
|
||||
plan = json.loads(plan_json)
|
||||
except json.JSONDecodeError as e:
|
||||
return json.dumps({"success": False, "error": f"Invalid JSON: {e}"})
|
||||
|
||||
# Validate first
|
||||
validation = json.loads(validate_plan(plan_json))
|
||||
if not validation["valid"]:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Plan is not valid",
|
||||
"validation_errors": validation["errors"],
|
||||
}
|
||||
)
|
||||
|
||||
steps = plan.get("steps", [])
|
||||
completed = set()
|
||||
execution_order = []
|
||||
iteration = 0
|
||||
|
||||
while len(completed) < len(steps) and iteration < max_steps:
|
||||
iteration += 1
|
||||
|
||||
# Find ready steps
|
||||
ready = []
|
||||
for step in steps:
|
||||
step_id = step.get("id")
|
||||
if step_id in completed:
|
||||
continue
|
||||
deps = set(step.get("dependencies", []))
|
||||
if deps.issubset(completed):
|
||||
ready.append(step)
|
||||
|
||||
if not ready:
|
||||
break
|
||||
|
||||
# Execute first ready step (in real execution, could be parallel)
|
||||
step = ready[0]
|
||||
step_id = step.get("id")
|
||||
|
||||
execution_order.append(
|
||||
{
|
||||
"iteration": iteration,
|
||||
"step_id": step_id,
|
||||
"description": step.get("description"),
|
||||
"action_type": step.get("action", {}).get("action_type"),
|
||||
"dependencies_met": list(step.get("dependencies", [])),
|
||||
"parallel_candidates": [s.get("id") for s in ready[1:]],
|
||||
}
|
||||
)
|
||||
|
||||
completed.add(step_id)
|
||||
|
||||
remaining = [s.get("id") for s in steps if s.get("id") not in completed]
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"execution_order": execution_order,
|
||||
"steps_simulated": len(execution_order),
|
||||
"remaining_steps": remaining,
|
||||
"plan_complete": len(remaining) == 0,
|
||||
"note": (
|
||||
"This is a simulation. Actual execution may differ "
|
||||
"based on step results and judge decisions."
|
||||
),
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
|
||||
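# Illustrative only (not part of this commit): a minimal two-step plan, with hypothetical
# step IDs and tool names, simulates in dependency order.
#
#   plan = {
#       "id": "demo_plan",
#       "goal_id": "demo_goal",
#       "steps": [
#           {"id": "step_1", "description": "Fetch user",
#            "action": {"action_type": "tool_use", "tool_name": "get_user"},
#            "expected_outputs": ["user_data"], "dependencies": []},
#           {"id": "step_2", "description": "Summarize user",
#            "action": {"action_type": "llm_call", "prompt": "Summarize the user"},
#            "inputs": {"user": "$user_data"}, "dependencies": ["step_1"]},
#       ],
#   }
#   simulate_plan_execution(json.dumps(plan))  # execution_order: step_1, then step_2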
# =============================================================================
|
||||
# TESTING TOOLS (Goal-Based Evaluation)
|
||||
# =============================================================================
|
||||
@@ -3713,90 +3201,37 @@ def list_tests(
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PLAN LOADING AND EXECUTION
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def load_plan_from_json(plan_json: str | dict) -> Plan:
|
||||
"""
|
||||
Load a Plan object from exported JSON.
|
||||
|
||||
Args:
|
||||
plan_json: JSON string or dict from export_graph()
|
||||
|
||||
Returns:
|
||||
Plan object ready for FlexibleGraphExecutor
|
||||
"""
|
||||
from framework.graph.plan import Plan
|
||||
|
||||
return Plan.from_json(plan_json)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def load_exported_plan(
|
||||
plan_json: Annotated[str, "JSON string from export_graph() output"],
|
||||
) -> str:
|
||||
"""
|
||||
Validate and load an exported plan, returning its structure.
|
||||
|
||||
Use this to verify a plan can be loaded before execution.
|
||||
"""
|
||||
try:
|
||||
plan = load_plan_from_json(plan_json)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"plan_id": plan.id,
|
||||
"goal_id": plan.goal_id,
|
||||
"description": plan.description,
|
||||
"step_count": len(plan.steps),
|
||||
"steps": [
|
||||
{
|
||||
"id": s.id,
|
||||
"description": s.description,
|
||||
"action_type": s.action.action_type.value,
|
||||
"dependencies": s.dependencies,
|
||||
}
|
||||
for s in plan.steps
|
||||
],
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
return json.dumps({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CREDENTIAL STORE TOOLS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _get_credential_store():
|
||||
"""Get a CredentialStore that checks encrypted files and env vars.
|
||||
"""Get a CredentialStore that checks encrypted files, env vars, and Aden sync.
|
||||
|
||||
Uses CompositeStorage: encrypted file storage (primary) with env var fallback.
|
||||
This ensures credentials stored via `store_credential` AND env vars are both found.
|
||||
Uses CredentialStoreAdapter.default() which handles:
|
||||
- Aden sync + provider index (resolving hashed IDs for OAuth)
|
||||
- CompositeStorage (encrypted primary + env fallback)
|
||||
- Auto-refresh of OAuth tokens
|
||||
- Graceful fallback if Aden is unavailable
|
||||
"""
|
||||
from framework.credentials import CredentialStore
|
||||
from framework.credentials.storage import CompositeStorage, EncryptedFileStorage, EnvVarStorage
|
||||
|
||||
# Build env var mapping from CREDENTIAL_SPECS for the fallback
|
||||
env_mapping: dict[str, str] = {}
|
||||
try:
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
from aden_tools.credentials.store_adapter import CredentialStoreAdapter
|
||||
|
||||
for name, spec in CREDENTIAL_SPECS.items():
|
||||
cred_id = spec.credential_id or name
|
||||
env_mapping[cred_id] = spec.env_var
|
||||
return CredentialStoreAdapter.default().store
|
||||
except ImportError:
|
||||
pass
|
||||
from framework.credentials import CredentialStore
|
||||
from framework.credentials.storage import (
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
|
||||
storage = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping=env_mapping)],
|
||||
)
|
||||
return CredentialStore(storage=storage)
|
||||
storage = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping={})],
|
||||
)
|
||||
return CredentialStore(storage=storage)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
||||
@@ -331,6 +331,20 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
)
|
||||
resume_parser.set_defaults(func=cmd_resume)
|
||||
|
||||
# setup-credentials command
|
||||
setup_creds_parser = subparsers.add_parser(
|
||||
"setup-credentials",
|
||||
help="Interactive credential setup",
|
||||
description="Guide through setting up required credentials for an agent.",
|
||||
)
|
||||
setup_creds_parser.add_argument(
|
||||
"agent_path",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="Path to agent folder (optional - runs general setup if not specified)",
|
||||
)
|
||||
setup_creds_parser.set_defaults(func=cmd_setup_credentials)
def _load_resume_state(
agent_path: str, session_id: str, checkpoint_id: str | None = None
@@ -362,6 +376,7 @@ def _load_resume_state(
except (json.JSONDecodeError, OSError):
return None
return {
"resume_session_id": session_id,
"memory": cp_data.get("shared_memory", {}),
"paused_at": cp_data.get("next_node") or cp_data.get("current_node"),
"execution_path": cp_data.get("execution_path", []),
@@ -379,6 +394,7 @@ def _load_resume_state(
progress = state_data.get("progress", {})
paused_at = progress.get("paused_at") or progress.get("resume_from")
return {
"resume_session_id": session_id,
"memory": state_data.get("memory", {}),
"paused_at": paused_at,
"execution_path": progress.get("path", []),
@@ -386,6 +402,40 @@ def _load_resume_state(
}

def _prompt_before_start(agent_path: str, runner, model: str | None = None):
"""Prompt user to start agent or update credentials.

Returns:
Updated runner if user proceeds, None if user aborts.
"""
from framework.credentials.setup import CredentialSetupSession
from framework.runner import AgentRunner

while True:
print()
try:
choice = input("Press Enter to start agent, or 'u' to update credentials: ").strip()
except (EOFError, KeyboardInterrupt):
print()
return None

if choice == "":
return runner
elif choice.lower() == "u":
session = CredentialSetupSession.from_agent_path(agent_path)
result = session.run_interactive()
if result.success:
# Reload runner with updated credentials
try:
runner = AgentRunner.load(agent_path, model=model)
except Exception as e:
print(f"Error reloading agent: {e}")
return None
# Loop back to prompt again
elif choice.lower() == "q":
return None

def cmd_run(args: argparse.Namespace) -> int:
"""Run an exported agent."""
import logging
@@ -432,11 +482,46 @@ def cmd_run(args: argparse.Namespace) -> int:
)
except CredentialError as e:
print(f"\n{e}", file=sys.stderr)
return
# Offer interactive credential setup if running in a terminal
if sys.stdin.isatty():
print()
try:
choice = input("Would you like to set up credentials now? [Y/n]: ")
choice = choice.strip()
except (EOFError, KeyboardInterrupt):
print()
return
if choice.lower() != "n":
from framework.credentials.setup import CredentialSetupSession

session = CredentialSetupSession.from_agent_path(args.agent_path)
result = session.run_interactive()
if result.success:
# Retry loading with credentials now configured
try:
runner = AgentRunner.load(args.agent_path, model=args.model)
except CredentialError as retry_e:
print(f"\n{retry_e}", file=sys.stderr)
return
except Exception as retry_e:
print(f"Error loading agent: {retry_e}")
return
else:
return
else:
return
else:
return
except Exception as e:
print(f"Error loading agent: {e}")
return

# Prompt before starting (allows credential updates)
if sys.stdin.isatty():
runner = _prompt_before_start(args.agent_path, runner, args.model)
if runner is None:
return

# Force setup inside the loop
if runner._agent_runtime is None:
runner._setup()
@@ -475,11 +560,45 @@ def cmd_run(args: argparse.Namespace) -> int:
)
except CredentialError as e:
print(f"\n{e}", file=sys.stderr)
return 1
# Offer interactive credential setup if running in a terminal
if sys.stdin.isatty():
print()
try:
choice = input("Would you like to set up credentials now? [Y/n]: ").strip()
except (EOFError, KeyboardInterrupt):
print()
return 1
if choice.lower() != "n":
from framework.credentials.setup import CredentialSetupSession

session = CredentialSetupSession.from_agent_path(args.agent_path)
result = session.run_interactive()
if result.success:
# Retry loading with credentials now configured
try:
runner = AgentRunner.load(args.agent_path, model=args.model)
except CredentialError as retry_e:
print(f"\n{retry_e}", file=sys.stderr)
return 1
except Exception as retry_e:
print(f"Error loading agent: {retry_e}")
return 1
else:
return 1
else:
return 1
else:
return 1
except FileNotFoundError as e:
print(f"Error: {e}", file=sys.stderr)
return 1

# Prompt before starting (allows credential updates)
if sys.stdin.isatty() and not args.quiet:
runner = _prompt_before_start(args.agent_path, runner, args.model)
if runner is None:
return 1

# Load session/checkpoint state for resume (headless mode)
session_state = None
resume_session = getattr(args, "resume_session", None)
@@ -1281,7 +1400,35 @@ def cmd_tui(args: argparse.Namespace) -> int:
)
except CredentialError as e:
print(f"\n{e}", file=sys.stderr)
return
# Offer interactive credential setup if running in a terminal
if sys.stdin.isatty():
print()
try:
choice = input("Would you like to set up credentials now? [Y/n]: ").strip()
except (EOFError, KeyboardInterrupt):
print()
return
if choice.lower() != "n":
from framework.credentials.setup import CredentialSetupSession

session = CredentialSetupSession.from_agent_path(agent_path)
result = session.run_interactive()
if result.success:
# Retry loading with credentials now configured
try:
runner = AgentRunner.load(agent_path, model=args.model)
except CredentialError as retry_e:
print(f"\n{retry_e}", file=sys.stderr)
return
except Exception as retry_e:
print(f"Error loading agent: {retry_e}")
return
else:
return
else:
return
else:
return
except Exception as e:
print(f"Error loading agent: {e}")
return
@@ -1447,6 +1594,7 @@ def _select_agent(agents_dir: Path) -> str | None:
for path in agents_dir.iterdir():
if _is_valid_agent_dir(path):
agents.append(path)
agents.sort(key=lambda p: p.name)

if not agents:
print(f"No agents found in {agents_dir}", file=sys.stderr)
@@ -1716,3 +1864,25 @@ def cmd_resume(args: argparse.Namespace) -> int:
if args.tui:
print("Mode: TUI")
return 1

def cmd_setup_credentials(args: argparse.Namespace) -> int:
"""Interactive credential setup for an agent."""
from framework.credentials.setup import CredentialSetupSession

agent_path = getattr(args, "agent_path", None)

if agent_path:
# Setup credentials for a specific agent
session = CredentialSetupSession.from_agent_path(agent_path)
else:
# No agent specified - show usage
print("Usage: hive setup-credentials <agent_path>")
print()
print("Examples:")
print("  hive setup-credentials exports/my-agent")
print("  hive setup-credentials examples/templates/deep_research_agent")
return 1

result = session.run_interactive()
return 0 if result.success else 1

@@ -183,8 +183,11 @@ class MCPClient:
from mcp import ClientSession
from mcp.client.stdio import stdio_client

# Create persistent stdio client context
self._stdio_context = stdio_client(server_params)
# Create persistent stdio client context.
# Redirect server stderr to devnull to prevent raw
# output from leaking behind the TUI.
devnull = open(os.devnull, "w")  # noqa: SIM115
self._stdio_context = stdio_client(server_params, errlog=devnull)
(
self._read_stream,
self._write_stream,

@@ -456,7 +456,7 @@ Respond with JSON only:
}}"""

try:
response = self._llm.complete(
response = await self._llm.acomplete(
messages=[{"role": "user", "content": prompt}],
system="You are a request router. Respond with JSON only.",
max_tokens=256,

+68 -199
@@ -9,6 +9,10 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any

from framework.config import get_hive_config, get_preferred_model
from framework.credentials.validation import (
ensure_credential_key_env as _ensure_credential_key_env,
validate_agent_credentials,
)
from framework.graph import Goal
from framework.graph.edge import (
DEFAULT_MAX_TOKENS,
@@ -21,7 +25,7 @@ from framework.graph.executor import ExecutionResult
from framework.graph.node import NodeSpec
from framework.llm.provider import LLMProvider, Tool
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime
from framework.runtime.execution_stream import EntryPointSpec
from framework.runtime.runtime_log_store import RuntimeLogStore

@@ -31,32 +35,6 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

def _ensure_credential_key_env() -> None:
"""Load HIVE_CREDENTIAL_KEY from shell config if not already in environment.

The setup-credentials skill writes the encryption key to ~/.zshrc or ~/.bashrc.
If the user hasn't sourced their config in the current shell, this reads it
directly so the runner (and any MCP subprocesses it spawns) can unlock the
encrypted credential store.

Only HIVE_CREDENTIAL_KEY is loaded this way — all other secrets (API keys, etc.)
come from the credential store itself.
"""
if os.environ.get("HIVE_CREDENTIAL_KEY"):
return

try:
from aden_tools.credentials.shell_config import check_env_var_in_shell_config

found, value = check_env_var_in_shell_config("HIVE_CREDENTIAL_KEY")
if found and value:
os.environ["HIVE_CREDENTIAL_KEY"] = value
logger.debug("Loaded HIVE_CREDENTIAL_KEY from shell config")
except ImportError:
pass

CLAUDE_CREDENTIALS_FILE = Path.home() / ".claude" / ".credentials.json"

@@ -268,6 +246,7 @@ class AgentRunner:
storage_path: Path | None = None,
model: str | None = None,
intro_message: str = "",
runtime_config: "AgentRuntimeConfig | None" = None,
):
"""
Initialize the runner (use AgentRunner.load() instead).
@@ -280,6 +259,7 @@ class AgentRunner:
storage_path: Path for runtime storage (defaults to temp)
model: Model to use (reads from agent config or ~/.hive/configuration.json if None)
intro_message: Optional greeting shown to user on TUI load
runtime_config: Optional AgentRuntimeConfig (webhook settings, etc.)
"""
self.agent_path = agent_path
self.graph = graph
@@ -287,6 +267,7 @@ class AgentRunner:
self.mock_mode = mock_mode
self.model = model or self._resolve_default_model()
self.intro_message = intro_message
self.runtime_config = runtime_config

# Set up storage
if storage_path:
@@ -331,88 +312,8 @@ class AgentRunner:
"""Check that required credentials are available before spawning MCP servers.

Raises CredentialError with actionable guidance if any are missing.
Uses graph node specs + CREDENTIAL_SPECS — no tool registry needed.
"""
required_tools: set[str] = set()
for node in self.graph.nodes:
if node.tools:
required_tools.update(node.tools)
node_types: set[str] = {node.node_type for node in self.graph.nodes}

try:
from aden_tools.credentials import CREDENTIAL_SPECS

from framework.credentials import CredentialStore
from framework.credentials.storage import (
CompositeStorage,
EncryptedFileStorage,
EnvVarStorage,
)
except ImportError:
return  # aden_tools not installed, skip check

# Build credential store (same logic as validate())
env_mapping = {
(spec.credential_id or name): spec.env_var for name, spec in CREDENTIAL_SPECS.items()
}
storages: list = [EnvVarStorage(env_mapping=env_mapping)]
if os.environ.get("HIVE_CREDENTIAL_KEY"):
storages.insert(0, EncryptedFileStorage())
if len(storages) == 1:
storage = storages[0]
else:
storage = CompositeStorage(primary=storages[0], fallbacks=storages[1:])
store = CredentialStore(storage=storage)

# Build reverse mappings
tool_to_cred: dict[str, str] = {}
node_type_to_cred: dict[str, str] = {}
for cred_name, spec in CREDENTIAL_SPECS.items():
for tool_name in spec.tools:
tool_to_cred[tool_name] = cred_name
for nt in spec.node_types:
node_type_to_cred[nt] = cred_name

missing: list[str] = []
checked: set[str] = set()

# Check tool credentials
for tool_name in sorted(required_tools):
cred_name = tool_to_cred.get(tool_name)
if cred_name is None or cred_name in checked:
continue
checked.add(cred_name)
spec = CREDENTIAL_SPECS[cred_name]
cred_id = spec.credential_id or cred_name
if spec.required and not store.is_available(cred_id):
affected = sorted(t for t in required_tools if t in spec.tools)
entry = f" {spec.env_var} for {', '.join(affected)}"
if spec.help_url:
entry += f"\n Get it at: {spec.help_url}"
missing.append(entry)

# Check node type credentials (e.g., ANTHROPIC_API_KEY for LLM nodes)
for nt in sorted(node_types):
cred_name = node_type_to_cred.get(nt)
if cred_name is None or cred_name in checked:
continue
checked.add(cred_name)
spec = CREDENTIAL_SPECS[cred_name]
cred_id = spec.credential_id or cred_name
if spec.required and not store.is_available(cred_id):
affected_types = sorted(t for t in node_types if t in spec.node_types)
entry = f" {spec.env_var} for {', '.join(affected_types)} nodes"
if spec.help_url:
entry += f"\n Get it at: {spec.help_url}"
missing.append(entry)

if missing:
from framework.credentials.models import CredentialError

lines = ["Missing required credentials:\n"]
lines.extend(missing)
lines.append("\nTo fix: run /hive-credentials in Claude Code.")
raise CredentialError("\n".join(lines))
validate_agent_credentials(self.graph.nodes)

@staticmethod
def _import_agent_module(agent_path: Path):
@@ -510,18 +411,32 @@ class AgentRunner:
intro_message = agent_metadata.intro_message

# Build GraphSpec from module-level variables
graph = GraphSpec(
id=f"{agent_path.name}-graph",
goal_id=goal.id,
version="1.0.0",
entry_node=getattr(agent_module, "entry_node", nodes[0].id),
entry_points=getattr(agent_module, "entry_points", {}),
terminal_nodes=getattr(agent_module, "terminal_nodes", []),
pause_nodes=getattr(agent_module, "pause_nodes", []),
nodes=nodes,
edges=edges,
max_tokens=max_tokens,
)
graph_kwargs: dict = {
"id": f"{agent_path.name}-graph",
"goal_id": goal.id,
"version": "1.0.0",
"entry_node": getattr(agent_module, "entry_node", nodes[0].id),
"entry_points": getattr(agent_module, "entry_points", {}),
"async_entry_points": getattr(agent_module, "async_entry_points", []),
"terminal_nodes": getattr(agent_module, "terminal_nodes", []),
"pause_nodes": getattr(agent_module, "pause_nodes", []),
"nodes": nodes,
"edges": edges,
"max_tokens": max_tokens,
"loop_config": getattr(agent_module, "loop_config", {}),
}
# Only pass optional fields if explicitly defined by the agent module
conversation_mode = getattr(agent_module, "conversation_mode", None)
if conversation_mode is not None:
graph_kwargs["conversation_mode"] = conversation_mode
identity_prompt = getattr(agent_module, "identity_prompt", None)
if identity_prompt is not None:
graph_kwargs["identity_prompt"] = identity_prompt

graph = GraphSpec(**graph_kwargs)
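For reference, a hedged sketch of the module-level variables the loader above looks for via getattr; only the names are taken from this diff, the values are illustrative:

# agent.py (hypothetical exported agent module)
entry_node = "intake"
entry_points = {"start": "intake"}
async_entry_points = []  # read unconditionally, defaults to []
loop_config = {}  # read unconditionally, defaults to {}
conversation_mode = "chat"  # optional; forwarded only when the module defines it (value illustrative)
identity_prompt = "You are the intake assistant."  # optional, illustrative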
# Read runtime config (webhook settings, etc.) if defined
agent_runtime_config = getattr(agent_module, "runtime_config", None)

return cls(
agent_path=agent_path,
@@ -531,6 +446,7 @@ class AgentRunner:
storage_path=storage_path,
model=model,
intro_message=intro_message,
runtime_config=agent_runtime_config,
)

# Fallback: load from agent.json (legacy JSON-based agents)
@@ -686,7 +602,9 @@ class AgentRunner:
else:
# Fall back to environment variable
# First check api_key_env_var from config (set by quickstart)
api_key_env = llm_config.get("api_key_env_var") or self._get_api_key_env_var(self.model)
api_key_env = llm_config.get("api_key_env_var") or self._get_api_key_env_var(
self.model
)
if api_key_env and os.environ.get(api_key_env):
self._llm = LiteLLMProvider(model=self.model)
else:
@@ -784,17 +702,19 @@ class AgentRunner:
)
entry_points.append(ep)

# Single-entry agent with no async entry points: create a default entry point
if not entry_points and self.graph.entry_node:
logger.info("Creating default entry point for single-entry agent")
entry_points.append(
# Always create a primary entry point for the graph's entry node.
# For multi-entry-point agents this ensures the primary path (e.g.
# user-facing rule setup) is reachable alongside async entry points.
if self.graph.entry_node:
entry_points.insert(
0,
EntryPointSpec(
id="default",
name="Default",
entry_node=self.graph.entry_node,
trigger_type="manual",
isolation_level="shared",
)
),
)

# Create AgentRuntime with all entry points
@@ -821,6 +741,7 @@ class AgentRunner:
tool_executor=tool_executor,
runtime_log_store=log_store,
checkpoint_config=checkpoint_config,
config=self.runtime_config,
)

# Pass intro_message through for TUI display
@@ -1143,88 +1064,36 @@ class AgentRunner:
warnings.append(f"Missing tool implementations: {', '.join(missing_tools)}")

# Check credentials for required tools and node types
# Uses CredentialStore (encrypted files + env var fallback)
# Uses CredentialStoreAdapter.default() which includes Aden sync support
missing_credentials = []
try:
from aden_tools.credentials import CREDENTIAL_SPECS
from aden_tools.credentials.store_adapter import CredentialStoreAdapter

from framework.credentials import CredentialStore
from framework.credentials.storage import (
CompositeStorage,
EncryptedFileStorage,
EnvVarStorage,
)

# Build env mapping for credential lookup
env_mapping = {
(spec.credential_id or name): spec.env_var
for name, spec in CREDENTIAL_SPECS.items()
}

# Only use EncryptedFileStorage if the encryption key is configured;
# otherwise just check env vars (avoids generating a throwaway key)
storages: list = [EnvVarStorage(env_mapping=env_mapping)]
if os.environ.get("HIVE_CREDENTIAL_KEY"):
storages.insert(0, EncryptedFileStorage())

if len(storages) == 1:
storage = storages[0]
else:
storage = CompositeStorage(
primary=storages[0],
fallbacks=storages[1:],
)
store = CredentialStore(storage=storage)

# Build reverse mappings
tool_to_cred: dict[str, str] = {}
node_type_to_cred: dict[str, str] = {}
for cred_name, spec in CREDENTIAL_SPECS.items():
for tool_name in spec.tools:
tool_to_cred[tool_name] = cred_name
for nt in spec.node_types:
node_type_to_cred[nt] = cred_name
adapter = CredentialStoreAdapter.default()

# Check tool credentials
checked: set[str] = set()
for tool_name in info.required_tools:
cred_name = tool_to_cred.get(tool_name)
if cred_name is None or cred_name in checked:
continue
checked.add(cred_name)
spec = CREDENTIAL_SPECS[cred_name]
cred_id = spec.credential_id or cred_name
if spec.required and not store.is_available(cred_id):
missing_credentials.append(spec.env_var)
affected_tools = [t for t in info.required_tools if t in spec.tools]
tools_str = ", ".join(affected_tools)
warning_msg = f"Missing {spec.env_var} for {tools_str}"
if spec.help_url:
warning_msg += f"\n Get it at: {spec.help_url}"
warnings.append(warning_msg)
for _cred_name, spec in adapter.get_missing_for_tools(list(info.required_tools)):
missing_credentials.append(spec.env_var)
affected_tools = [t for t in info.required_tools if t in spec.tools]
tools_str = ", ".join(affected_tools)
warning_msg = f"Missing {spec.env_var} for {tools_str}"
if spec.help_url:
warning_msg += f"\n Get it at: {spec.help_url}"
warnings.append(warning_msg)

# Check node type credentials (e.g., ANTHROPIC_API_KEY for LLM nodes)
node_types = list({node.node_type for node in self.graph.nodes})
for nt in node_types:
cred_name = node_type_to_cred.get(nt)
if cred_name is None or cred_name in checked:
continue
checked.add(cred_name)
spec = CREDENTIAL_SPECS[cred_name]
cred_id = spec.credential_id or cred_name
if spec.required and not store.is_available(cred_id):
missing_credentials.append(spec.env_var)
affected_types = [t for t in node_types if t in spec.node_types]
types_str = ", ".join(affected_types)
warning_msg = f"Missing {spec.env_var} for {types_str} nodes"
if spec.help_url:
warning_msg += f"\n Get it at: {spec.help_url}"
warnings.append(warning_msg)
for _cred_name, spec in adapter.get_missing_for_node_types(node_types):
missing_credentials.append(spec.env_var)
affected_types = [t for t in node_types if t in spec.node_types]
types_str = ", ".join(affected_types)
warning_msg = f"Missing {spec.env_var} for {types_str} nodes"
if spec.help_url:
warning_msg += f"\n Get it at: {spec.help_url}"
warnings.append(warning_msg)
except ImportError:
# aden_tools not installed - fall back to direct check
has_llm_nodes = any(
node.node_type in ("llm_generate", "llm_tool_use") for node in self.graph.nodes
)
has_llm_nodes = any(node.node_type == "event_loop" for node in self.graph.nodes)
if has_llm_nodes:
api_key_env = self._get_api_key_env_var(self.model)
if api_key_env and not os.environ.get(api_key_env):
@@ -1306,7 +1175,7 @@ Respond with JSON only:
}}"""

try:
response = eval_llm.complete(
response = await eval_llm.acomplete(
messages=[{"role": "user", "content": prompt}],
system="You are a capability evaluator. Respond with JSON only.",
max_tokens=256,

@@ -7,8 +7,9 @@ while preserving the goal-driven approach.

import asyncio
import logging
import time
from collections.abc import Callable
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any

@@ -39,6 +40,11 @@ class AgentRuntimeConfig:
max_history: int = 1000
execution_result_max: int = 1000
execution_result_ttl_seconds: float | None = None
# Webhook server config (only starts if webhook_routes is non-empty)
webhook_host: str = "127.0.0.1"
webhook_port: int = 8080
webhook_routes: list[dict] = field(default_factory=list)
# Each dict: {"source_id": str, "path": str, "methods": ["POST"], "secret": str|None}
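A hedged example of a webhook-enabled config, using only the field names and dict keys shown above; the path and secret values are placeholders:

# Illustrative only: field names come from AgentRuntimeConfig in this diff.
config = AgentRuntimeConfig(
    webhook_host="127.0.0.1",
    webhook_port=8080,
    webhook_routes=[
        {"source_id": "github", "path": "/webhooks/github", "methods": ["POST"], "secret": None},
    ],
)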
class AgentRuntime:
@@ -150,6 +156,15 @@ class AgentRuntime:
self._entry_points: dict[str, EntryPointSpec] = {}
self._streams: dict[str, ExecutionStream] = {}

# Webhook server (created on start if webhook_routes configured)
self._webhook_server: Any = None
# Event-driven entry point subscriptions
self._event_subscriptions: list[str] = []
# Timer tasks for scheduled entry points
self._timer_tasks: list[asyncio.Task] = []
# Next fire time for each timer entry point (ep_id -> datetime)
self._timer_next_fire: dict[str, float] = {}

# State
self._running = False
self._lock = asyncio.Lock()
@@ -234,6 +249,137 @@ class AgentRuntime:
await stream.start()
self._streams[ep_id] = stream

# Start webhook server if routes are configured
if self._config.webhook_routes:
from framework.runtime.webhook_server import (
WebhookRoute,
WebhookServer,
WebhookServerConfig,
)

wh_config = WebhookServerConfig(
host=self._config.webhook_host,
port=self._config.webhook_port,
)
self._webhook_server = WebhookServer(self._event_bus, wh_config)

for rc in self._config.webhook_routes:
route = WebhookRoute(
source_id=rc["source_id"],
path=rc["path"],
methods=rc.get("methods", ["POST"]),
secret=rc.get("secret"),
)
self._webhook_server.add_route(route)

await self._webhook_server.start()

# Subscribe event-driven entry points to EventBus
from framework.runtime.event_bus import EventType as _ET

for ep_id, spec in self._entry_points.items():
if spec.trigger_type != "event":
continue

tc = spec.trigger_config
event_types = [_ET(et) for et in tc.get("event_types", [])]
if not event_types:
logger.warning(
f"Entry point '{ep_id}' has trigger_type='event' "
"but no event_types in trigger_config"
)
continue

# Capture ep_id in closure
def _make_handler(entry_point_id: str):
async def _on_event(event):
if self._running and entry_point_id in self._streams:
# Run in the same session as the primary entry
# point so memory (e.g. user-defined rules) is
# shared and logs land in one session directory.
session_state = self._get_primary_session_state(
exclude_entry_point=entry_point_id
)
await self.trigger(
entry_point_id,
{"event": event.to_dict()},
session_state=session_state,
)

return _on_event

sub_id = self._event_bus.subscribe(
event_types=event_types,
handler=_make_handler(ep_id),
filter_stream=tc.get("filter_stream"),
filter_node=tc.get("filter_node"),
)
self._event_subscriptions.append(sub_id)

# Start timer-driven entry points
for ep_id, spec in self._entry_points.items():
if spec.trigger_type != "timer":
continue

tc = spec.trigger_config
interval = tc.get("interval_minutes")
if not interval or interval <= 0:
logger.warning(
f"Entry point '{ep_id}' has trigger_type='timer' "
"but no valid interval_minutes in trigger_config"
)
continue

run_immediately = tc.get("run_immediately", False)

def _make_timer(entry_point_id: str, mins: float, immediate: bool):
async def _timer_loop():
interval_secs = mins * 60
if not immediate:
self._timer_next_fire[entry_point_id] = time.monotonic() + interval_secs
await asyncio.sleep(interval_secs)
while self._running:
self._timer_next_fire.pop(entry_point_id, None)
try:
if self._should_skip_timer(entry_point_id):
logger.info(
"Timer '%s' skipped — primary stream busy",
entry_point_id,
)
else:
session_state = self._get_primary_session_state(
exclude_entry_point=entry_point_id
)
await self.trigger(
entry_point_id,
{"event": {"source": "timer", "reason": "scheduled"}},
session_state=session_state,
)
logger.info(
"Timer fired for entry point '%s' (next in %s min)",
entry_point_id,
mins,
)
except Exception:
logger.error(
"Timer trigger failed for '%s'",
entry_point_id,
exc_info=True,
)
self._timer_next_fire[entry_point_id] = time.monotonic() + interval_secs
await asyncio.sleep(interval_secs)

return _timer_loop

task = asyncio.create_task(_make_timer(ep_id, interval, run_immediately)())
self._timer_tasks.append(task)
logger.info(
"Started timer for entry point '%s' every %s min%s",
ep_id,
interval,
" (immediate first run)" if run_immediately else "",
)

self._running = True
logger.info(f"AgentRuntime started with {len(self._streams)} streams")
@@ -243,6 +389,21 @@ class AgentRuntime:
return

async with self._lock:
# Cancel timer tasks
for task in self._timer_tasks:
task.cancel()
self._timer_tasks.clear()

# Unsubscribe event-driven entry points
for sub_id in self._event_subscriptions:
self._event_bus.unsubscribe(sub_id)
self._event_subscriptions.clear()

# Stop webhook server
if self._webhook_server:
await self._webhook_server.stop()
self._webhook_server = None

# Stop all streams
for stream in self._streams.values():
await stream.stop()
@@ -314,6 +475,83 @@ class AgentRuntime:
raise ValueError(f"Entry point '{entry_point_id}' not found")
return await stream.wait_for_completion(exec_id, timeout)

def _should_skip_timer(self, timer_ep_id: str) -> bool:
"""Return True if a non-timer stream is actively running (not waiting for input).

Timers should only fire when the primary stream is idle (blocked
waiting for client input) or has no active execution. This prevents
concurrent pipeline runs that would race on shared memory.
"""
for ep_id, stream in self._streams.items():
if ep_id == timer_ep_id:
continue
spec = self._entry_points.get(ep_id)
if spec and spec.trigger_type == "timer":
continue
if stream.active_execution_ids and not stream.is_awaiting_input:
return True
return False

def _get_primary_session_state(self, exclude_entry_point: str) -> dict[str, Any] | None:
"""Build session_state so an async entry point runs in the primary session.

Looks for an active execution from another stream (the "primary"
session, e.g. the user-facing intake loop) and returns a
``session_state`` dict containing:

- ``resume_session_id``: reuse the same session directory
- ``memory``: only the keys that the async entry node declares
as inputs (e.g. ``rules``, ``max_emails``). Stale outputs
from previous runs (``emails``, ``actions_taken``, …) are
excluded so each trigger starts fresh.

The memory is read from the primary session's ``state.json``
which is kept up-to-date by ``GraphExecutor._write_progress()``
at every node transition.

Returns ``None`` if no primary session is active (the webhook
execution will just create its own session).
"""
import json as _json

# Determine which memory keys the async entry node needs.
allowed_keys: set[str] | None = None
ep_spec = self._entry_points.get(exclude_entry_point)
if ep_spec:
entry_node = self.graph.get_node(ep_spec.entry_node)
if entry_node and entry_node.input_keys:
allowed_keys = set(entry_node.input_keys)

for ep_id, stream in self._streams.items():
if ep_id == exclude_entry_point:
continue
for exec_id in stream.active_execution_ids:
state_path = self._storage.base_path / "sessions" / exec_id / "state.json"
try:
if state_path.exists():
data = _json.loads(state_path.read_text(encoding="utf-8"))
full_memory = data.get("memory", {})
if not full_memory:
continue
# Filter to only input keys so stale outputs
# from previous triggers don't leak through.
if allowed_keys is not None:
memory = {k: v for k, v in full_memory.items() if k in allowed_keys}
else:
memory = full_memory
if memory:
return {
"resume_session_id": exec_id,
"memory": memory,
}
except Exception:
logger.debug(
"Could not read state.json for %s: skipping",
exec_id,
exc_info=True,
)
return None
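For clarity, a sketch of the shape returned above; the keys come from this diff, the values are illustrative:

# session_state handed to trigger() so the async run joins the primary session:
# {
#     "resume_session_id": "session-20240101-abc123",  # existing execution id (illustrative)
#     "memory": {"rules": [...], "max_emails": 50},     # only the entry node's input keys
# }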
async def inject_input(self, node_id: str, content: str) -> bool:
"""Inject user input into a running client-facing node.

@@ -448,6 +686,11 @@ class AgentRuntime:
"""Access the outcome aggregator."""
return self._outcome_aggregator

@property
def webhook_server(self) -> Any:
"""Access the webhook server (None if no webhook entry points)."""
return self._webhook_server

@property
def is_running(self) -> bool:
"""Check if runtime is running."""

@@ -56,6 +56,12 @@ class Runtime:
"""

def __init__(self, storage_path: str | Path):
# Validate and create storage path if needed
path = Path(storage_path) if isinstance(storage_path, str) else storage_path
if not path.exists():
logger.warning(f"Storage path does not exist, creating: {path}")
path.mkdir(parents=True, exist_ok=True)

self.storage = FileStorage(storage_path)
self._current_run: Run | None = None
self._current_node: str = "unknown"
@@ -62,6 +62,23 @@ class EventType(StrEnum):
|
||||
NODE_INTERNAL_OUTPUT = "node_internal_output"
|
||||
NODE_INPUT_BLOCKED = "node_input_blocked"
|
||||
NODE_STALLED = "node_stalled"
|
||||
NODE_TOOL_DOOM_LOOP = "node_tool_doom_loop"
|
||||
|
||||
# Judge decisions
|
||||
JUDGE_VERDICT = "judge_verdict"
|
||||
|
||||
# Output tracking
|
||||
OUTPUT_KEY_SET = "output_key_set"
|
||||
|
||||
# Retry / edge tracking
|
||||
NODE_RETRY = "node_retry"
|
||||
EDGE_TRAVERSED = "edge_traversed"
|
||||
|
||||
# Context management
|
||||
CONTEXT_COMPACTED = "context_compacted"
|
||||
|
||||
# External triggers
|
||||
WEBHOOK_RECEIVED = "webhook_received"
|
||||
|
||||
# Custom events
|
||||
CUSTOM = "custom"
|
||||
@@ -615,6 +632,24 @@ class EventBus:
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_tool_doom_loop(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
description: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit tool doom loop detection event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_TOOL_DOOM_LOOP,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"description": description},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_input_blocked(
|
||||
self,
|
||||
stream_id: str,
|
||||
@@ -633,6 +668,158 @@ class EventBus:
|
||||
)
|
||||
)
|
||||
|
||||
# === JUDGE / OUTPUT / RETRY / EDGE PUBLISHERS ===
|
||||
|
||||
async def emit_judge_verdict(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
action: str,
|
||||
feedback: str = "",
|
||||
judge_type: str = "implicit",
|
||||
iteration: int = 0,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit judge verdict event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.JUDGE_VERDICT,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"action": action,
|
||||
"feedback": feedback,
|
||||
"judge_type": judge_type,
|
||||
"iteration": iteration,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_output_key_set(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
key: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit output key set event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.OUTPUT_KEY_SET,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"key": key},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_retry(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
retry_count: int,
|
||||
max_retries: int,
|
||||
error: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node retry event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_RETRY,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"retry_count": retry_count,
|
||||
"max_retries": max_retries,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_edge_traversed(
|
||||
self,
|
||||
stream_id: str,
|
||||
source_node: str,
|
||||
target_node: str,
|
||||
edge_condition: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit edge traversed event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EDGE_TRAVERSED,
|
||||
stream_id=stream_id,
|
||||
node_id=source_node,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"source_node": source_node,
|
||||
"target_node": target_node,
|
||||
"edge_condition": edge_condition,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_execution_paused(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
reason: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit execution paused event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_PAUSED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"reason": reason},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_execution_resumed(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit execution resumed event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_RESUMED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_webhook_received(
|
||||
self,
|
||||
source_id: str,
|
||||
path: str,
|
||||
method: str,
|
||||
headers: dict[str, str],
|
||||
payload: dict[str, Any],
|
||||
query_params: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Emit webhook received event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.WEBHOOK_RECEIVED,
|
||||
stream_id=source_id,
|
||||
data={
|
||||
"path": path,
|
||||
"method": method,
|
||||
"headers": headers,
|
||||
"payload": payload,
|
||||
"query_params": query_params or {},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# === QUERY OPERATIONS ===
|
||||
|
||||
def get_history(
|
||||
|
||||
@@ -196,6 +196,22 @@ class ExecutionStream:
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def active_execution_ids(self) -> list[str]:
|
||||
"""Return IDs of all currently active executions."""
|
||||
return list(self._active_executions.keys())
|
||||
|
||||
@property
|
||||
def is_awaiting_input(self) -> bool:
|
||||
"""True when an active execution is blocked waiting for client input."""
|
||||
if not self._active_executors:
|
||||
return False
|
||||
for executor in self._active_executors.values():
|
||||
for node in executor.node_registry.values():
|
||||
if getattr(node, "_awaiting_input", False):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _record_execution_result(self, execution_id: str, result: ExecutionResult) -> None:
|
||||
"""Record a completed execution result with retention pruning."""
|
||||
self._execution_results[execution_id] = result
|
||||
@@ -293,8 +309,13 @@ class ExecutionStream:
|
||||
if not self._running:
|
||||
raise RuntimeError(f"ExecutionStream '{self.stream_id}' is not running")
|
||||
|
||||
# Generate execution ID using unified session format
|
||||
if self._session_store:
|
||||
# When resuming, reuse the original session ID so the execution
|
||||
# continues in the same session directory instead of creating a new one.
|
||||
resume_session_id = session_state.get("resume_session_id") if session_state else None
|
||||
|
||||
if resume_session_id:
|
||||
execution_id = resume_session_id
|
||||
elif self._session_store:
|
||||
execution_id = self._session_store.generate_session_id()
|
||||
else:
|
||||
# Fallback to old format if SessionStore not available (shouldn't happen)
|
||||
@@ -337,6 +358,11 @@ class ExecutionStream:
|
||||
"""Run a single execution within the stream."""
|
||||
execution_id = ctx.id
|
||||
|
||||
# When sharing a session with another entry point (resume_session_id),
|
||||
# skip writing initial/final session state — the primary execution
|
||||
# owns the state.json and _write_progress() keeps memory up-to-date.
|
||||
_is_shared_session = bool(ctx.session_state and ctx.session_state.get("resume_session_id"))
|
||||
|
||||
# Acquire semaphore to limit concurrency
|
||||
async with self._semaphore:
|
||||
ctx.status = "running"
|
||||
@@ -399,7 +425,8 @@ class ExecutionStream:
|
||||
self._active_executors[execution_id] = executor
|
||||
|
||||
# Write initial session state
|
||||
await self._write_session_state(execution_id, ctx)
|
||||
if not _is_shared_session:
|
||||
await self._write_session_state(execution_id, ctx)
|
||||
|
||||
# Create modified graph with entry point
|
||||
# We need to override the entry_node to use our entry point
|
||||
@@ -433,8 +460,9 @@ class ExecutionStream:
|
||||
if result.paused_at:
|
||||
ctx.status = "paused"
|
||||
|
||||
# Write final session state
|
||||
await self._write_session_state(execution_id, ctx, result=result)
|
||||
# Write final session state (skip for shared-session executions)
|
||||
if not _is_shared_session:
|
||||
await self._write_session_state(execution_id, ctx, result=result)
|
||||
|
||||
# Emit completion/failure event
|
||||
if self._event_bus:
|
||||
@@ -485,11 +513,14 @@ class ExecutionStream:
|
||||
# Store result with retention
|
||||
self._record_execution_result(execution_id, result)
|
||||
|
||||
# Write session state
|
||||
if has_result and result.paused_at:
|
||||
await self._write_session_state(execution_id, ctx, result=result)
|
||||
else:
|
||||
await self._write_session_state(execution_id, ctx, error="Execution cancelled")
|
||||
# Write session state (skip for shared-session executions)
|
||||
if not _is_shared_session:
|
||||
if has_result and result.paused_at:
|
||||
await self._write_session_state(execution_id, ctx, result=result)
|
||||
else:
|
||||
await self._write_session_state(
|
||||
execution_id, ctx, error="Execution cancelled"
|
||||
)
|
||||
|
||||
# Don't re-raise - we've handled it and saved state
|
||||
|
||||
@@ -506,8 +537,9 @@ class ExecutionStream:
|
||||
),
|
||||
)
|
||||
|
||||
# Write error session state
|
||||
await self._write_session_state(execution_id, ctx, error=str(e))
|
||||
# Write error session state (skip for shared-session executions)
|
||||
if not _is_shared_session:
|
||||
await self._write_session_state(execution_id, ctx, error=str(e))
|
||||
|
||||
# End run with failure (for observability)
|
||||
try:
|
||||
@@ -597,10 +629,22 @@ class ExecutionStream:
|
||||
entry_point=self.entry_spec.id,
|
||||
)
|
||||
else:
|
||||
# Create initial state
|
||||
from framework.schemas.session_state import SessionTimestamps
|
||||
# Create initial state — when resuming, preserve the previous
|
||||
# execution's progress so crashes don't lose track of state.
|
||||
from framework.schemas.session_state import (
|
||||
SessionProgress,
|
||||
SessionTimestamps,
|
||||
)
|
||||
|
||||
now = datetime.now().isoformat()
|
||||
ss = ctx.session_state or {}
|
||||
progress = SessionProgress(
|
||||
current_node=ss.get("paused_at") or ss.get("resume_from"),
|
||||
paused_at=ss.get("paused_at"),
|
||||
resume_from=ss.get("paused_at") or ss.get("resume_from"),
|
||||
path=ss.get("execution_path", []),
|
||||
node_visit_counts=ss.get("node_visit_counts", {}),
|
||||
)
|
||||
state = SessionState(
|
||||
session_id=execution_id,
|
||||
stream_id=self.stream_id,
|
||||
@@ -613,6 +657,8 @@ class ExecutionStream:
|
||||
started_at=ctx.started_at.isoformat(),
|
||||
updated_at=now,
|
||||
),
|
||||
progress=progress,
|
||||
memory=ss.get("memory", {}),
|
||||
input_data=ctx.input_data,
|
||||
)
|
||||
|
||||
@@ -629,20 +675,35 @@ class ExecutionStream:
|
||||
logger.error(f"Failed to write state.json for {execution_id}: {e}")
|
||||
|
||||
def _create_modified_graph(self) -> "GraphSpec":
|
||||
"""Create a graph with the entry point overridden."""
|
||||
# Use the existing graph but override entry_node
|
||||
"""Create a graph with the entry point overridden.
|
||||
|
||||
Preserves the original graph's entry_points and async_entry_points
|
||||
so that validation correctly considers ALL entry nodes reachable.
|
||||
Each stream only executes from its own entry_node, but the full
|
||||
graph must validate with all entry points accounted for.
|
||||
"""
|
||||
from framework.graph.edge import GraphSpec
|
||||
|
||||
# Create a copy with modified entry node
|
||||
# Merge entry points: this stream's entry + original graph's primary
|
||||
# entry + any other entry points. This ensures all nodes are
|
||||
# reachable during validation even though this stream only starts
|
||||
# from self.entry_spec.entry_node.
|
||||
merged_entry_points = {
|
||||
"start": self.entry_spec.entry_node,
|
||||
}
|
||||
# Preserve the original graph's primary entry node
|
||||
if self.graph.entry_node:
|
||||
merged_entry_points["primary"] = self.graph.entry_node
|
||||
# Include any explicitly defined entry points from the graph
|
||||
merged_entry_points.update(self.graph.entry_points)
|
||||
|
||||
return GraphSpec(
|
||||
id=self.graph.id,
|
||||
goal_id=self.graph.goal_id,
|
||||
version=self.graph.version,
|
||||
entry_node=self.entry_spec.entry_node, # Use our entry point
|
||||
entry_points={
|
||||
"start": self.entry_spec.entry_node,
|
||||
**self.graph.entry_points,
|
||||
},
|
||||
entry_points=merged_entry_points,
|
||||
async_entry_points=self.graph.async_entry_points,
|
||||
terminal_nodes=self.graph.terminal_nodes,
|
||||
pause_nodes=self.graph.pause_nodes,
|
||||
nodes=self.graph.nodes,
|
||||
@@ -651,6 +712,9 @@ class ExecutionStream:
|
||||
max_tokens=self.graph.max_tokens,
|
||||
max_steps=self.graph.max_steps,
|
||||
cleanup_llm_model=self.graph.cleanup_llm_model,
|
||||
loop_config=self.graph.loop_config,
|
||||
conversation_mode=self.graph.conversation_mode,
|
||||
identity_prompt=self.graph.identity_prompt,
|
||||
)
|
||||
|
||||
async def wait_for_completion(
|
||||
|
||||
@@ -30,14 +30,14 @@ class NodeStepLog(BaseModel):
|
||||
"""Full tool and LLM details for one step within a node.
|
||||
|
||||
For EventLoopNode, each iteration is a step. For single-step nodes
|
||||
(LLMNode, FunctionNode, RouterNode), step_index is 0.
|
||||
(e.g. RouterNode), step_index is 0.
|
||||
|
||||
OTel-aligned fields (trace_id, span_id, execution_id) enable correlation
|
||||
and future OpenTelemetry export without schema changes.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
node_type: str = "" # "event_loop"|"llm_tool_use"|"llm_generate"|"function"|"router"
|
||||
node_type: str = "" # "event_loop" (the only valid type)
|
||||
step_index: int = 0 # iteration number for event_loop, 0 for single-step nodes
|
||||
llm_text: str = ""
|
||||
tool_calls: list[ToolCallLog] = Field(default_factory=list)
|
||||
|
||||
@@ -64,7 +64,7 @@ def sample_graph():
|
||||
id="process-webhook",
|
||||
name="Process Webhook",
|
||||
description="Process incoming webhook",
|
||||
node_type="llm_generate",
|
||||
node_type="event_loop",
|
||||
input_keys=["webhook_data"],
|
||||
output_keys=["result"],
|
||||
),
|
||||
@@ -72,7 +72,7 @@ def sample_graph():
|
||||
id="process-api",
|
||||
name="Process API Request",
|
||||
description="Process API request",
|
||||
node_type="llm_generate",
|
||||
node_type="event_loop",
|
||||
input_keys=["request_data"],
|
||||
output_keys=["result"],
|
||||
),
|
||||
@@ -538,7 +538,7 @@ class TestGraphSpecValidation:
|
||||
id="valid-node",
|
||||
name="Valid Node",
|
||||
description="A valid node",
|
||||
node_type="llm_generate",
|
||||
node_type="event_loop",
|
||||
input_keys=[],
|
||||
output_keys=[],
|
||||
),
|
||||
|
||||
@@ -0,0 +1,717 @@
|
||||
"""
|
||||
Tests for WebhookServer and event-driven entry points.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac as hmac_mod
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig
|
||||
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
from framework.runtime.webhook_server import (
|
||||
WebhookRoute,
|
||||
WebhookServer,
|
||||
WebhookServerConfig,
|
||||
)
|
||||
|
||||
|
||||
def _make_server(event_bus: EventBus, routes: list[WebhookRoute] | None = None):
|
||||
"""Helper to create a WebhookServer with port=0 for OS-assigned port."""
|
||||
config = WebhookServerConfig(host="127.0.0.1", port=0)
|
||||
server = WebhookServer(event_bus, config)
|
||||
for route in routes or []:
|
||||
server.add_route(route)
|
||||
return server
|
||||
|
||||
|
||||
def _base_url(server: WebhookServer) -> str:
|
||||
"""Get the base URL for a running server."""
|
||||
return f"http://127.0.0.1:{server.port}"
|
||||
|
||||
|
||||
class TestWebhookServerLifecycle:
|
||||
"""Tests for server start/stop."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_stop(self):
|
||||
bus = EventBus()
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="test", path="/webhooks/test", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
|
||||
await server.start()
|
||||
assert server.is_running
|
||||
assert server.port is not None
|
||||
|
||||
await server.stop()
|
||||
assert not server.is_running
|
||||
assert server.port is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_routes_skips_start(self):
|
||||
bus = EventBus()
|
||||
server = _make_server(bus) # no routes
|
||||
|
||||
await server.start()
|
||||
assert not server.is_running
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_when_not_started(self):
|
||||
bus = EventBus()
|
||||
server = _make_server(bus)
|
||||
|
||||
# Should be a no-op, not raise
|
||||
await server.stop()
|
||||
assert not server.is_running
|
||||
|
||||
|
||||
class TestWebhookEventPublishing:
|
||||
"""Tests for HTTP request -> EventBus event publishing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_publishes_webhook_received(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="gh", path="/webhooks/github", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/github",
|
||||
json={"action": "opened", "number": 42},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
body = await resp.json()
|
||||
assert body["status"] == "accepted"
|
||||
|
||||
# Give event bus time to dispatch
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(received) == 1
|
||||
event = received[0]
|
||||
assert event.type == EventType.WEBHOOK_RECEIVED
|
||||
assert event.stream_id == "gh"
|
||||
assert event.data["path"] == "/webhooks/github"
|
||||
assert event.data["method"] == "POST"
|
||||
assert event.data["payload"] == {"action": "opened", "number": 42}
|
||||
assert isinstance(event.data["headers"], dict)
|
||||
assert event.data["query_params"] == {}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_params_included(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="hook", path="/webhooks/hook", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/hook?source=test&v=2",
|
||||
json={"data": "hello"},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["query_params"] == {"source": "test", "v": "2"}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_json_body(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="raw", path="/webhooks/raw", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/raw",
|
||||
data=b"plain text body",
|
||||
headers={"Content-Type": "text/plain"},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["payload"] == {"raw_body": "plain text body"}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_body(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="empty", path="/webhooks/empty", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{_base_url(server)}/webhooks/empty") as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["payload"] == {}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_routes(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="a", path="/webhooks/a", methods=["POST"]),
|
||||
WebhookRoute(source_id="b", path="/webhooks/b", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/a", json={"from": "a"}
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/b", json={"from": "b"}
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(received) == 2
|
||||
stream_ids = {e.stream_id for e in received}
|
||||
assert stream_ids == {"a", "b"}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_stream_subscription(self):
|
||||
"""Subscribers can filter by stream_id (source_id)."""
|
||||
bus = EventBus()
|
||||
a_events = []
|
||||
b_events = []
|
||||
|
||||
async def handle_a(event):
|
||||
a_events.append(event)
|
||||
|
||||
async def handle_b(event):
|
||||
b_events.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handle_a, filter_stream="a")
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handle_b, filter_stream="b")
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="a", path="/webhooks/a", methods=["POST"]),
|
||||
WebhookRoute(source_id="b", path="/webhooks/b", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await session.post(f"{_base_url(server)}/webhooks/a", json={"x": 1})
|
||||
await session.post(f"{_base_url(server)}/webhooks/b", json={"x": 2})
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(a_events) == 1
|
||||
assert a_events[0].data["payload"] == {"x": 1}
|
||||
assert len(b_events) == 1
|
||||
assert b_events[0].data["payload"] == {"x": 2}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
|
||||
class TestHMACVerification:
|
||||
"""Tests for HMAC-SHA256 signature verification."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_signature_accepted(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
secret = "test-secret-key"
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(
|
||||
source_id="secure",
|
||||
path="/webhooks/secure",
|
||||
methods=["POST"],
|
||||
secret=secret,
|
||||
),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
body = json.dumps({"event": "push"}).encode()
|
||||
sig = hmac_mod.new(secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/secure",
|
||||
data=body,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"X-Hub-Signature-256": f"sha256={sig}",
|
||||
},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
assert len(received) == 1
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_signature_rejected(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(
|
||||
source_id="secure",
|
||||
path="/webhooks/secure",
|
||||
methods=["POST"],
|
||||
secret="real-secret",
|
||||
),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/secure",
|
||||
json={"event": "push"},
|
||||
headers={"X-Hub-Signature-256": "sha256=invalidsignature"},
|
||||
) as resp:
|
||||
assert resp.status == 401
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
assert len(received) == 0 # No event published
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_signature_rejected(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(
|
||||
source_id="secure",
|
||||
path="/webhooks/secure",
|
||||
methods=["POST"],
|
||||
secret="my-secret",
|
||||
),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# No X-Hub-Signature-256 header
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/secure",
|
||||
json={"event": "push"},
|
||||
) as resp:
|
||||
assert resp.status == 401
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
assert len(received) == 0
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_secret_skips_verification(self):
|
||||
"""Routes without a secret accept any request."""
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(
|
||||
source_id="open",
|
||||
path="/webhooks/open",
|
||||
methods=["POST"],
|
||||
secret=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/open",
|
||||
json={"data": "test"},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
assert len(received) == 1
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
|
||||
class TestEventDrivenEntryPoints:
|
||||
"""Tests for event-driven entry points wired through AgentRuntime."""
|
||||
|
||||
def _make_graph_and_goal(self):
|
||||
"""Minimal graph + goal for testing entry point triggering."""
|
||||
from framework.graph import Goal
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.goal import SuccessCriterion
|
||||
from framework.graph.node import NodeSpec
|
||||
|
||||
nodes = [
|
||||
NodeSpec(
|
||||
id="process-event",
|
||||
name="Process Event",
|
||||
description="Process incoming event",
|
||||
node_type="event_loop",
|
||||
input_keys=["event"],
|
||||
output_keys=["result"],
|
||||
),
|
||||
]
|
||||
graph = GraphSpec(
|
||||
id="test-graph",
|
||||
goal_id="test-goal",
|
||||
version="1.0.0",
|
||||
entry_node="process-event",
|
||||
entry_points={"start": "process-event"},
|
||||
async_entry_points=[],
|
||||
terminal_nodes=[],
|
||||
pause_nodes=[],
|
||||
nodes=nodes,
|
||||
edges=[],
|
||||
)
|
||||
goal = Goal(
|
||||
id="test-goal",
|
||||
name="Test Goal",
|
||||
description="Test",
|
||||
success_criteria=[
|
||||
SuccessCriterion(
|
||||
id="sc-1",
|
||||
description="Done",
|
||||
metric="done",
|
||||
target="yes",
|
||||
weight=1.0,
|
||||
),
|
||||
],
|
||||
)
|
||||
return graph, goal
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_entry_point_subscribes_to_bus(self):
|
||||
"""Entry point with trigger_type='event' subscribes and triggers on matching events."""
|
||||
graph, goal = self._make_graph_and_goal()
|
||||
|
||||
config = AgentRuntimeConfig(
|
||||
webhook_host="127.0.0.1",
|
||||
webhook_port=0,
|
||||
webhook_routes=[
|
||||
{"source_id": "gh", "path": "/webhooks/github"},
|
||||
],
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
runtime = AgentRuntime(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
storage_path=Path(tmpdir),
|
||||
config=config,
|
||||
)
|
||||
|
||||
runtime.register_entry_point(
|
||||
EntryPointSpec(
|
||||
id="gh-handler",
|
||||
name="GitHub Handler",
|
||||
entry_node="process-event",
|
||||
trigger_type="event",
|
||||
trigger_config={
|
||||
"event_types": ["webhook_received"],
|
||||
"filter_stream": "gh",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
trigger_calls = []
|
||||
|
||||
async def mock_trigger(ep_id, data, **kwargs):
|
||||
trigger_calls.append((ep_id, data))
|
||||
|
||||
with patch.object(runtime, "trigger", side_effect=mock_trigger):
|
||||
await runtime.start()
|
||||
|
||||
try:
|
||||
assert runtime.webhook_server is not None
|
||||
assert runtime.webhook_server.is_running
|
||||
|
||||
port = runtime.webhook_server.port
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://127.0.0.1:{port}/webhooks/github",
|
||||
json={"action": "push", "ref": "main"},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(trigger_calls) == 1
|
||||
ep_id, data = trigger_calls[0]
|
||||
assert ep_id == "gh-handler"
|
||||
assert "event" in data
|
||||
assert data["event"]["type"] == "webhook_received"
|
||||
assert data["event"]["stream_id"] == "gh"
|
||||
assert data["event"]["data"]["payload"] == {
|
||||
"action": "push",
|
||||
"ref": "main",
|
||||
}
|
||||
finally:
|
||||
await runtime.stop()
|
||||
|
||||
assert runtime.webhook_server is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_entry_point_filter_stream(self):
|
||||
"""Entry point only triggers for matching stream_id (source_id)."""
|
||||
graph, goal = self._make_graph_and_goal()
|
||||
|
||||
config = AgentRuntimeConfig(
|
||||
webhook_routes=[
|
||||
{"source_id": "github", "path": "/webhooks/github"},
|
||||
{"source_id": "stripe", "path": "/webhooks/stripe"},
|
||||
],
|
||||
webhook_port=0,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
runtime = AgentRuntime(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
storage_path=Path(tmpdir),
|
||||
config=config,
|
||||
)
|
||||
|
||||
runtime.register_entry_point(
|
||||
EntryPointSpec(
|
||||
id="gh-only",
|
||||
name="GitHub Only",
|
||||
entry_node="process-event",
|
||||
trigger_type="event",
|
||||
trigger_config={
|
||||
"event_types": ["webhook_received"],
|
||||
"filter_stream": "github",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
trigger_calls = []
|
||||
|
||||
async def mock_trigger(ep_id, data, **kwargs):
|
||||
trigger_calls.append((ep_id, data))
|
||||
|
||||
with patch.object(runtime, "trigger", side_effect=mock_trigger):
|
||||
await runtime.start()
|
||||
|
||||
try:
|
||||
port = runtime.webhook_server.port
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# POST to stripe — should NOT trigger
|
||||
await session.post(
|
||||
f"http://127.0.0.1:{port}/webhooks/stripe",
|
||||
json={"type": "payment"},
|
||||
)
|
||||
# POST to github — should trigger
|
||||
await session.post(
|
||||
f"http://127.0.0.1:{port}/webhooks/github",
|
||||
json={"action": "opened"},
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(trigger_calls) == 1
|
||||
assert trigger_calls[0][0] == "gh-only"
|
||||
finally:
|
||||
await runtime.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_webhook_routes_skips_server(self):
|
||||
"""Runtime without webhook_routes does not start a webhook server."""
|
||||
graph, goal = self._make_graph_and_goal()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
runtime = AgentRuntime(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
storage_path=Path(tmpdir),
|
||||
)
|
||||
|
||||
runtime.register_entry_point(
|
||||
EntryPointSpec(
|
||||
id="manual",
|
||||
name="Manual",
|
||||
entry_node="process-event",
|
||||
trigger_type="manual",
|
||||
)
|
||||
)
|
||||
|
||||
await runtime.start()
|
||||
try:
|
||||
assert runtime.webhook_server is None
|
||||
finally:
|
||||
await runtime.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_entry_point_custom_event(self):
|
||||
"""Entry point can subscribe to CUSTOM events, not just webhooks."""
|
||||
graph, goal = self._make_graph_and_goal()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
runtime = AgentRuntime(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
storage_path=Path(tmpdir),
|
||||
)
|
||||
|
||||
runtime.register_entry_point(
|
||||
EntryPointSpec(
|
||||
id="custom-handler",
|
||||
name="Custom Handler",
|
||||
entry_node="process-event",
|
||||
trigger_type="event",
|
||||
trigger_config={
|
||||
"event_types": ["custom"],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
trigger_calls = []
|
||||
|
||||
async def mock_trigger(ep_id, data, **kwargs):
|
||||
trigger_calls.append((ep_id, data))
|
||||
|
||||
with patch.object(runtime, "trigger", side_effect=mock_trigger):
|
||||
await runtime.start()
|
||||
|
||||
try:
|
||||
await runtime.event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CUSTOM,
|
||||
stream_id="some-source",
|
||||
data={"key": "value"},
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(trigger_calls) == 1
|
||||
assert trigger_calls[0][0] == "custom-handler"
|
||||
assert trigger_calls[0][1]["event"]["type"] == "custom"
|
||||
assert trigger_calls[0][1]["event"]["data"]["key"] == "value"
|
||||
finally:
|
||||
await runtime.stop()
|
||||
@@ -0,0 +1,175 @@
"""
Webhook HTTP Server - Receives HTTP requests and publishes them as EventBus events.

Only starts if webhook-type entry points are registered. Uses aiohttp for
a lightweight embedded HTTP server that runs within the existing asyncio loop.
"""

import hashlib
import hmac
import json
import logging
from dataclasses import dataclass

from aiohttp import web

from framework.runtime.event_bus import EventBus

logger = logging.getLogger(__name__)


@dataclass
class WebhookRoute:
    """A registered webhook route derived from an EntryPointSpec."""

    source_id: str
    path: str
    methods: list[str]
    secret: str | None = None  # For HMAC-SHA256 signature verification


@dataclass
class WebhookServerConfig:
    """Configuration for the webhook HTTP server."""

    host: str = "127.0.0.1"
    port: int = 8080


class WebhookServer:
    """
    Embedded HTTP server that receives webhook requests and publishes
    them as WEBHOOK_RECEIVED events on the EventBus.

    The server's only job is: receive HTTP -> publish AgentEvent.
    Subscribers decide what to do with the event.

    Lifecycle:
        server = WebhookServer(event_bus, config)
        server.add_route(WebhookRoute(...))
        await server.start()
        # ... server running ...
        await server.stop()
    """

    def __init__(
        self,
        event_bus: EventBus,
        config: WebhookServerConfig | None = None,
    ):
        self._event_bus = event_bus
        self._config = config or WebhookServerConfig()
        self._routes: dict[str, WebhookRoute] = {}  # path -> route
        self._app: web.Application | None = None
        self._runner: web.AppRunner | None = None
        self._site: web.TCPSite | None = None

    def add_route(self, route: WebhookRoute) -> None:
        """Register a webhook route."""
        self._routes[route.path] = route

    async def start(self) -> None:
        """Start the HTTP server. No-op if no routes registered."""
        if not self._routes:
            logger.debug("No webhook routes registered, skipping server start")
            return

        self._app = web.Application()

        for path, route in self._routes.items():
            for method in route.methods:
                self._app.router.add_route(method, path, self._handle_request)

        self._runner = web.AppRunner(self._app)
        await self._runner.setup()
        self._site = web.TCPSite(
            self._runner,
            self._config.host,
            self._config.port,
        )
        await self._site.start()
        logger.info(
            f"Webhook server started on {self._config.host}:{self._config.port} "
            f"with {len(self._routes)} route(s)"
        )

    async def stop(self) -> None:
        """Stop the HTTP server gracefully."""
        if self._runner:
            await self._runner.cleanup()
        self._runner = None
        self._app = None
        self._site = None
        logger.info("Webhook server stopped")

    async def _handle_request(self, request: web.Request) -> web.Response:
        """Handle an incoming webhook request."""
        path = request.path
        route = self._routes.get(path)

        if route is None:
            return web.json_response({"error": "Not found"}, status=404)

        # Read body
        try:
            body = await request.read()
        except Exception:
            return web.json_response(
                {"error": "Failed to read request body"},
                status=400,
            )

        # Verify HMAC signature if secret is configured
        if route.secret:
            if not self._verify_signature(request, body, route.secret):
                return web.json_response({"error": "Invalid signature"}, status=401)

        # Parse body as JSON (fall back to raw text for non-JSON)
        try:
            payload = json.loads(body) if body else {}
        except (json.JSONDecodeError, ValueError):
            payload = {"raw_body": body.decode("utf-8", errors="replace")}

        # Publish event to bus
        await self._event_bus.emit_webhook_received(
            source_id=route.source_id,
            path=path,
            method=request.method,
            headers=dict(request.headers),
            payload=payload,
            query_params=dict(request.query),
        )

        return web.json_response({"status": "accepted"}, status=202)

    def _verify_signature(
        self,
        request: web.Request,
        body: bytes,
        secret: str,
    ) -> bool:
        """Verify HMAC-SHA256 signature from X-Hub-Signature-256 header."""
        signature_header = request.headers.get("X-Hub-Signature-256", "")
        if not signature_header.startswith("sha256="):
            return False

        expected_sig = signature_header[7:]  # strip "sha256="
        computed_sig = hmac.new(
            secret.encode("utf-8"),
            body,
            hashlib.sha256,
        ).hexdigest()

        return hmac.compare_digest(expected_sig, computed_sig)

    @property
    def is_running(self) -> bool:
        """Check if the server is running."""
        return self._site is not None

    @property
    def port(self) -> int | None:
        """Return the actual listening port (useful when configured with port=0)."""
        if self._site and self._site._server and self._site._server.sockets:
            return self._site._server.sockets[0].getsockname()[1]
        return None
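Side note (not part of the diff): for the `_verify_signature` check above to pass, the sender has to sign the exact raw request body with the shared secret and send the digest in the GitHub-style `X-Hub-Signature-256` header. A minimal client sketch, assuming the server is listening on 127.0.0.1:8080 with a route at /webhooks/secure whose secret is "test-secret-key" (all three values are illustrative):

import asyncio
import hashlib
import hmac
import json

import aiohttp


async def send_signed_webhook() -> None:
    # Assumed values for illustration; they must match the WebhookRoute config.
    url = "http://127.0.0.1:8080/webhooks/secure"
    secret = "test-secret-key"

    body = json.dumps({"event": "push"}).encode()
    # Sign the raw bytes that go on the wire, then prefix with "sha256=".
    signature = hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()

    async with aiohttp.ClientSession() as session:
        async with session.post(
            url,
            data=body,
            headers={
                "Content-Type": "application/json",
                "X-Hub-Signature-256": f"sha256={signature}",
            },
        ) as resp:
            assert resp.status == 202  # accepted and published to the EventBus


if __name__ == "__main__":
    asyncio.run(send_signed_webhook())

Because the signature covers the raw bytes, the client must send the same bytes it signed; re-serializing the JSON after signing would invalidate the digest.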
@@ -156,8 +156,13 @@ class SessionState(BaseModel):
     @computed_field
     @property
     def is_resumable(self) -> bool:
-        """Can this session be resumed?"""
-        return self.status == SessionStatus.PAUSED and self.progress.resume_from is not None
+        """Can this session be resumed?
+
+        Every non-completed session is resumable. If resume_from/paused_at
+        aren't set, the executor falls back to the graph entry point —
+        so we don't gate on those. Even catastrophic failures are resumable.
+        """
+        return self.status != SessionStatus.COMPLETED

     @computed_field
     @property
@@ -279,9 +284,16 @@ class SessionState(BaseModel):

     def to_session_state_dict(self) -> dict[str, Any]:
         """Convert to session_state format for GraphExecutor.execute()."""
+        # Derive resume target: explicit > last node in path > entry point
+        resume_from = (
+            self.progress.resume_from
+            or self.progress.paused_at
+            or (self.progress.path[-1] if self.progress.path else None)
+        )
         return {
-            "paused_at": self.progress.paused_at,
-            "resume_from": self.progress.resume_from,
+            "paused_at": resume_from,
+            "resume_from": resume_from,
             "memory": self.memory,
             "next_node": None,
             "execution_path": self.progress.path,
             "node_visit_counts": self.progress.node_visit_counts,
         }

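For reference (not part of the diff), the resume-target derivation above is a plain or-chain, so precedence matters: an explicit resume_from wins, then the node the session paused at, then the last node on the execution path, and only when all three are empty does the executor fall back to the graph entry point. A standalone sketch of that fallback order, with hypothetical node names:

def derive_resume_from(
    resume_from: str | None,
    paused_at: str | None,
    path: list[str],
) -> str | None:
    # explicit resume target > node we paused at > last node visited
    return resume_from or paused_at or (path[-1] if path else None)


assert derive_resume_from("plan", "draft", ["ingest", "draft"]) == "plan"
assert derive_resume_from(None, "draft", ["ingest", "draft"]) == "draft"
assert derive_resume_from(None, None, ["ingest", "draft"]) == "draft"
assert derive_resume_from(None, None, []) is None  # executor uses the entry point
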
+137
-47
@@ -1,18 +1,18 @@
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Container, Horizontal, Vertical
|
||||
from textual.widgets import Footer, Input, Label
|
||||
from textual.containers import Container, Horizontal
|
||||
from textual.widgets import Footer, Label
|
||||
|
||||
from framework.runtime.agent_runtime import AgentRuntime
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
from framework.tui.widgets.chat_repl import ChatRepl
|
||||
from framework.tui.widgets.graph_view import GraphOverview
|
||||
from framework.tui.widgets.log_pane import LogPane
|
||||
from framework.tui.widgets.selectable_rich_log import SelectableRichLog
|
||||
|
||||
|
||||
@@ -136,28 +136,15 @@ class AdenTUI(App):
|
||||
background: $surface;
|
||||
}
|
||||
|
||||
#left-pane {
|
||||
width: 60%;
|
||||
height: 100%;
|
||||
layout: vertical;
|
||||
background: $surface;
|
||||
}
|
||||
|
||||
GraphOverview {
|
||||
height: 40%;
|
||||
width: 40%;
|
||||
height: 100%;
|
||||
background: $panel;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
LogPane {
|
||||
height: 60%;
|
||||
background: $surface;
|
||||
padding: 0;
|
||||
margin-bottom: 1;
|
||||
}
|
||||
|
||||
ChatRepl {
|
||||
width: 40%;
|
||||
width: 60%;
|
||||
height: 100%;
|
||||
background: $panel;
|
||||
border-left: tall $primary;
|
||||
@@ -180,13 +167,13 @@ class AdenTUI(App):
|
||||
scrollbar-color: $primary;
|
||||
}
|
||||
|
||||
Input {
|
||||
ChatTextArea {
|
||||
background: $surface;
|
||||
border: tall $primary;
|
||||
margin-top: 1;
|
||||
}
|
||||
|
||||
Input:focus {
|
||||
ChatTextArea:focus {
|
||||
border: tall $accent;
|
||||
}
|
||||
|
||||
@@ -208,8 +195,10 @@ class AdenTUI(App):
|
||||
Binding("ctrl+c", "ctrl_c", "Interrupt", show=False, priority=True),
|
||||
Binding("super+c", "ctrl_c", "Copy", show=False, priority=True),
|
||||
Binding("ctrl+s", "screenshot", "Screenshot (SVG)", show=True, priority=True),
|
||||
Binding("ctrl+l", "toggle_logs", "Toggle Logs", show=True, priority=True),
|
||||
Binding("ctrl+z", "pause_execution", "Pause", show=True, priority=True),
|
||||
Binding("ctrl+r", "show_sessions", "Sessions", show=True, priority=True),
|
||||
Binding("ctrl+p", "attach_pdf", "Attach PDF", show=True, priority=True),
|
||||
Binding("tab", "focus_next", "Next Panel", show=True),
|
||||
Binding("shift+tab", "focus_previous", "Previous Panel", show=False),
|
||||
]
|
||||
@@ -223,7 +212,6 @@ class AdenTUI(App):
|
||||
super().__init__()
|
||||
|
||||
self.runtime = runtime
|
||||
self.log_pane = LogPane()
|
||||
self.graph_view = GraphOverview(runtime)
|
||||
self.chat_repl = ChatRepl(runtime, resume_session, resume_checkpoint)
|
||||
self.status_bar = StatusBar(graph_id=runtime.graph.id)
|
||||
@@ -253,11 +241,7 @@ class AdenTUI(App):
|
||||
yield self.status_bar
|
||||
|
||||
yield Horizontal(
|
||||
Vertical(
|
||||
self.log_pane,
|
||||
self.graph_view,
|
||||
id="left-pane",
|
||||
),
|
||||
self.graph_view,
|
||||
self.chat_repl,
|
||||
)
|
||||
|
||||
@@ -328,7 +312,7 @@ class AdenTUI(App):
|
||||
if record.name.startswith(("textual", "LiteLLM", "litellm")):
|
||||
continue
|
||||
|
||||
self.log_pane.write_python_log(record)
|
||||
self.chat_repl.write_python_log(record)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -350,6 +334,14 @@ class AdenTUI(App):
|
||||
EventType.CONSTRAINT_VIOLATION,
|
||||
EventType.STATE_CHANGED,
|
||||
EventType.NODE_INPUT_BLOCKED,
|
||||
EventType.CONTEXT_COMPACTED,
|
||||
EventType.NODE_INTERNAL_OUTPUT,
|
||||
EventType.JUDGE_VERDICT,
|
||||
EventType.OUTPUT_KEY_SET,
|
||||
EventType.NODE_RETRY,
|
||||
EventType.EDGE_TRAVERSED,
|
||||
EventType.EXECUTION_PAUSED,
|
||||
EventType.EXECUTION_RESUMED,
|
||||
]
|
||||
|
||||
_LOG_PANE_EVENTS = frozenset(_EVENT_TYPES) - {
|
||||
@@ -368,15 +360,36 @@ class AdenTUI(App):
|
||||
pass
|
||||
|
||||
async def _handle_event(self, event: AgentEvent) -> None:
|
||||
"""Called from the agent thread — bridge to Textual's main thread."""
|
||||
"""Bridge events to Textual's main thread for UI updates.
|
||||
|
||||
Events may arrive from the agent-execution thread (normal LLM/tool
|
||||
work) or from the Textual thread itself (e.g. webhook server events).
|
||||
``call_from_thread`` requires a *different* thread, so we detect
|
||||
which thread we're on and act accordingly.
|
||||
"""
|
||||
try:
|
||||
self.call_from_thread(self._route_event, event)
|
||||
except Exception:
|
||||
pass
|
||||
if threading.get_ident() == self._thread_id:
|
||||
# Already on Textual's thread — call directly.
|
||||
self._route_event(event)
|
||||
else:
|
||||
# On a different thread — bridge via call_from_thread.
|
||||
self.call_from_thread(self._route_event, event)
|
||||
except Exception as e:
|
||||
logging.getLogger("tui.events").error(
|
||||
"call_from_thread failed for %s (node=%s): %s",
|
||||
event.type.value,
|
||||
event.node_id or "?",
|
||||
e,
|
||||
)
|
||||
|
||||
def _route_event(self, event: AgentEvent) -> None:
|
||||
"""Route incoming events to widgets. Runs on Textual's main thread."""
|
||||
if not self.is_ready:
|
||||
logging.getLogger("tui.events").warning(
|
||||
"Event dropped (not ready): %s node=%s",
|
||||
event.type.value,
|
||||
event.node_id or "?",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -407,6 +420,35 @@ class AdenTUI(App):
|
||||
self.chat_repl.handle_input_requested(
|
||||
event.node_id or event.data.get("node_id", ""),
|
||||
)
|
||||
elif et == EventType.NODE_LOOP_STARTED:
|
||||
self.chat_repl.handle_node_started(event.node_id or "")
|
||||
elif et == EventType.NODE_LOOP_ITERATION:
|
||||
self.chat_repl.handle_loop_iteration(event.data.get("iteration", 0))
|
||||
|
||||
# Track active node in chat_repl for mid-execution input
|
||||
if et == EventType.NODE_LOOP_STARTED:
|
||||
self.chat_repl.handle_node_started(event.node_id or "")
|
||||
elif et == EventType.NODE_LOOP_COMPLETED:
|
||||
self.chat_repl.handle_node_completed(event.node_id or "")
|
||||
|
||||
# Non-client-facing node output → chat repl
|
||||
if et == EventType.NODE_INTERNAL_OUTPUT:
|
||||
content = event.data.get("content", "")
|
||||
if content.strip():
|
||||
self.chat_repl.handle_internal_output(event.node_id or "", content)
|
||||
|
||||
# Execution paused/resumed → chat repl
|
||||
if et == EventType.EXECUTION_PAUSED:
|
||||
reason = event.data.get("reason", "")
|
||||
self.chat_repl.handle_execution_paused(event.node_id or "", reason)
|
||||
elif et == EventType.EXECUTION_RESUMED:
|
||||
self.chat_repl.handle_execution_resumed(event.node_id or "")
|
||||
|
||||
# Goal achieved / constraint violation → chat repl
|
||||
if et == EventType.GOAL_ACHIEVED:
|
||||
self.chat_repl.handle_goal_achieved(event.data)
|
||||
elif et == EventType.CONSTRAINT_VIOLATION:
|
||||
self.chat_repl.handle_constraint_violation(event.data)
|
||||
|
||||
# --- Graph view events ---
|
||||
if et in (
|
||||
@@ -444,6 +486,13 @@ class AdenTUI(App):
|
||||
started=False,
|
||||
)
|
||||
|
||||
# Edge traversal → graph view
|
||||
if et == EventType.EDGE_TRAVERSED:
|
||||
self.graph_view.handle_edge_traversed(
|
||||
event.data.get("source_node", ""),
|
||||
event.data.get("target_node", ""),
|
||||
)
|
||||
|
||||
# --- Status bar events ---
|
||||
if et == EventType.EXECUTION_STARTED:
|
||||
entry_node = event.data.get("entry_node") or (
|
||||
@@ -464,12 +513,36 @@ class AdenTUI(App):
|
||||
self.status_bar.set_node_detail("thinking...")
|
||||
elif et == EventType.NODE_STALLED:
|
||||
self.status_bar.set_node_detail(f"stalled: {event.data.get('reason', '')}")
|
||||
elif et == EventType.CONTEXT_COMPACTED:
|
||||
before = event.data.get("usage_before", "?")
|
||||
after = event.data.get("usage_after", "?")
|
||||
self.status_bar.set_node_detail(f"compacted: {before}% \u2192 {after}%")
|
||||
elif et == EventType.JUDGE_VERDICT:
|
||||
action = event.data.get("action", "?")
|
||||
self.status_bar.set_node_detail(f"judge: {action}")
|
||||
elif et == EventType.OUTPUT_KEY_SET:
|
||||
key = event.data.get("key", "?")
|
||||
self.status_bar.set_node_detail(f"set: {key}")
|
||||
elif et == EventType.NODE_RETRY:
|
||||
retry = event.data.get("retry_count", "?")
|
||||
max_r = event.data.get("max_retries", "?")
|
||||
self.status_bar.set_node_detail(f"retry {retry}/{max_r}")
|
||||
elif et == EventType.EXECUTION_PAUSED:
|
||||
self.status_bar.set_node_detail("paused")
|
||||
elif et == EventType.EXECUTION_RESUMED:
|
||||
self.status_bar.set_node_detail("resumed")
|
||||
|
||||
# --- Log pane events ---
|
||||
# --- Log events (inline in chat) ---
|
||||
if et in self._LOG_PANE_EVENTS:
|
||||
self.log_pane.write_event(event)
|
||||
except Exception:
|
||||
pass
|
||||
self.chat_repl.write_log_event(event)
|
||||
except Exception as e:
|
||||
logging.getLogger("tui.events").error(
|
||||
"Route failed for %s (node=%s): %s",
|
||||
event.type.value,
|
||||
event.node_id or "?",
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def save_screenshot(self, filename: str | None = None) -> str:
|
||||
"""Save a screenshot of the current screen as SVG (viewable in browsers).
|
||||
@@ -504,8 +577,8 @@ class AdenTUI(App):
|
||||
original_chat_border = chat_widget.styles.border_left
|
||||
chat_widget.styles.border_left = ("none", "transparent")
|
||||
|
||||
# Hide all Input widget borders
|
||||
input_widgets = self.query("Input")
|
||||
# Hide all TextArea widget borders
|
||||
input_widgets = self.query("ChatTextArea")
|
||||
original_input_borders = []
|
||||
for input_widget in input_widgets:
|
||||
original_input_borders.append(input_widget.styles.border)
|
||||
@@ -535,6 +608,12 @@ class AdenTUI(App):
|
||||
except Exception as e:
|
||||
self.notify(f"Screenshot failed: {e}", severity="error", timeout=5)
|
||||
|
||||
def action_toggle_logs(self) -> None:
|
||||
"""Toggle inline log display in chat (bound to Ctrl+L)."""
|
||||
self.chat_repl.toggle_logs()
|
||||
mode = "ON" if self.chat_repl._show_logs else "OFF"
|
||||
self.notify(f"Logs {mode}", severity="information", timeout=2)
|
||||
|
||||
def action_pause_execution(self) -> None:
|
||||
"""Immediately pause execution by cancelling task (bound to Ctrl+Z)."""
|
||||
try:
|
||||
@@ -575,19 +654,12 @@ class AdenTUI(App):
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
def action_show_sessions(self) -> None:
|
||||
async def action_show_sessions(self) -> None:
|
||||
"""Show sessions list (bound to Ctrl+R)."""
|
||||
# Send /sessions command to chat input
|
||||
try:
|
||||
chat_repl = self.query_one(ChatRepl)
|
||||
chat_input = chat_repl.query_one("#chat-input", Input)
|
||||
chat_input.value = "/sessions"
|
||||
# Trigger submission
|
||||
self.notify(
|
||||
"💡 Type /sessions in the chat to see all sessions",
|
||||
severity="information",
|
||||
timeout=3,
|
||||
)
|
||||
await chat_repl._submit_input("/sessions")
|
||||
except Exception:
|
||||
self.notify(
|
||||
"Use /sessions command to see all sessions",
|
||||
@@ -595,6 +667,24 @@ class AdenTUI(App):
|
||||
timeout=3,
|
||||
)
|
||||
|
||||
async def action_attach_pdf(self) -> None:
|
||||
"""Open native OS file dialog for PDF selection (bound to Ctrl+P)."""
|
||||
from framework.tui.widgets.file_browser import _has_gui, pick_pdf_file
|
||||
|
||||
if not _has_gui():
|
||||
self.notify(
|
||||
"No GUI available. Use /attach <path> instead.",
|
||||
severity="warning",
|
||||
timeout=5,
|
||||
)
|
||||
return
|
||||
|
||||
self.notify("Opening file dialog...", severity="information", timeout=2)
|
||||
path = await pick_pdf_file()
|
||||
|
||||
if path is not None:
|
||||
self.chat_repl.attach_pdf(path)
|
||||
|
||||
async def on_unmount(self) -> None:
|
||||
"""Cleanup on app shutdown - cancel execution which will save state."""
|
||||
self.is_ready = False
|
||||
|
||||
@@ -15,19 +15,49 @@ Client-facing input:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Vertical
|
||||
from textual.widgets import Input, Label
|
||||
from textual.message import Message
|
||||
from textual.widgets import Label, TextArea
|
||||
|
||||
from framework.runtime.agent_runtime import AgentRuntime
|
||||
from framework.runtime.event_bus import AgentEvent
|
||||
from framework.tui.widgets.log_pane import format_event, format_python_log
|
||||
from framework.tui.widgets.selectable_rich_log import SelectableRichLog as RichLog
|
||||
|
||||
|
||||
class ChatTextArea(TextArea):
|
||||
"""TextArea that submits on Enter and inserts newlines on Shift+Enter."""
|
||||
|
||||
class Submitted(Message):
|
||||
"""Posted when the user presses Enter."""
|
||||
|
||||
def __init__(self, text: str) -> None:
|
||||
super().__init__()
|
||||
self.text = text
|
||||
|
||||
async def _on_key(self, event) -> None:
|
||||
if event.key == "enter":
|
||||
text = self.text.strip()
|
||||
self.clear()
|
||||
if text:
|
||||
self.post_message(self.Submitted(text))
|
||||
event.stop()
|
||||
event.prevent_default()
|
||||
elif event.key == "shift+enter":
|
||||
event.key = "enter"
|
||||
await super()._on_key(event)
|
||||
else:
|
||||
await super()._on_key(event)
|
||||
|
||||
|
||||
class ChatRepl(Vertical):
|
||||
"""Widget for interactive chat/REPL."""
|
||||
|
||||
@@ -56,16 +86,17 @@ class ChatRepl(Vertical):
|
||||
display: none;
|
||||
}
|
||||
|
||||
ChatRepl > Input {
|
||||
ChatRepl > ChatTextArea {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
max-height: 7;
|
||||
dock: bottom;
|
||||
background: $surface;
|
||||
border: tall $primary;
|
||||
margin-top: 1;
|
||||
}
|
||||
|
||||
ChatRepl > Input:focus {
|
||||
ChatRepl > ChatTextArea:focus {
|
||||
border: tall $accent;
|
||||
}
|
||||
"""
|
||||
@@ -83,9 +114,13 @@ class ChatRepl(Vertical):
|
||||
self._waiting_for_input: bool = False
|
||||
self._input_node_id: str | None = None
|
||||
self._pending_ask_question: str = ""
|
||||
self._active_node_id: str | None = None # Currently executing node
|
||||
self._resume_session = resume_session
|
||||
self._resume_checkpoint = resume_checkpoint
|
||||
self._session_index: list[str] = [] # IDs from last listing
|
||||
self._show_logs: bool = False # Clean mode by default
|
||||
self._log_buffer: list[str] = [] # Buffered log lines for backfill on toggle ON
|
||||
self._attached_pdf: dict | None = None # Pending PDF attachment for next message
|
||||
|
||||
# Dedicated event loop for agent execution.
|
||||
# Keeps blocking runtime code (LLM calls, MCP tools) off
|
||||
@@ -108,7 +143,7 @@ class ChatRepl(Vertical):
|
||||
min_width=0,
|
||||
)
|
||||
yield Label("Agent is processing...", id="processing-indicator")
|
||||
yield Input(placeholder="Enter input for agent...", id="chat-input")
|
||||
yield ChatTextArea(id="chat-input", placeholder="Enter input for agent...")
|
||||
|
||||
# Regex for file:// URIs that are NOT already inside Rich [link=...] markup
|
||||
_FILE_URI_RE = re.compile(r"(?<!\[link=)(file://[^\s)\]>*]+)")
|
||||
@@ -131,6 +166,31 @@ class ChatRepl(Vertical):
|
||||
if was_at_bottom:
|
||||
history.scroll_end(animate=False)
|
||||
|
||||
def toggle_logs(self) -> None:
|
||||
"""Toggle inline log display on/off. Backfills buffered logs on toggle ON."""
|
||||
self._show_logs = not self._show_logs
|
||||
if self._show_logs and self._log_buffer:
|
||||
self._write_history("[dim]--- Backfilling logs ---[/dim]")
|
||||
for line in self._log_buffer:
|
||||
self._write_history(line)
|
||||
self._write_history("[dim]--- Live logs ---[/dim]")
|
||||
mode = "ON (dirty)" if self._show_logs else "OFF (clean)"
|
||||
self._write_history(f"[dim]Logs {mode}[/dim]")
|
||||
|
||||
def write_log_event(self, event: AgentEvent) -> None:
|
||||
"""Buffer a formatted agent event. Display inline if logs are ON."""
|
||||
formatted = format_event(event)
|
||||
self._log_buffer.append(formatted)
|
||||
if self._show_logs:
|
||||
self._write_history(formatted)
|
||||
|
||||
def write_python_log(self, record: logging.LogRecord) -> None:
|
||||
"""Buffer a formatted Python log record. Display inline if logs are ON."""
|
||||
formatted = format_python_log(record)
|
||||
self._log_buffer.append(formatted)
|
||||
if self._show_logs:
|
||||
self._write_history(formatted)
|
||||
|
||||
async def _handle_command(self, command: str) -> None:
|
||||
"""Handle slash commands for session and checkpoint operations."""
|
||||
parts = command.split(maxsplit=2)
|
||||
@@ -138,6 +198,9 @@ class ChatRepl(Vertical):
|
||||
|
||||
if cmd == "/help":
|
||||
self._write_history("""[bold cyan]Available Commands:[/bold cyan]
|
||||
[bold]/attach[/bold] - Open file dialog to attach a PDF
|
||||
[bold]/attach[/bold] <file_path> - Attach a PDF from a specific path
|
||||
[bold]/detach[/bold] - Remove the currently attached PDF
|
||||
[bold]/sessions[/bold] - List all sessions for this agent
|
||||
[bold]/sessions[/bold] <session_id> - Show session details and checkpoints
|
||||
[bold]/resume[/bold] - List sessions and pick one to resume
|
||||
@@ -148,12 +211,11 @@ class ChatRepl(Vertical):
|
||||
[bold]/help[/bold] - Show this help message
|
||||
|
||||
[dim]Examples:[/dim]
|
||||
/attach [dim]# Open file picker dialog[/dim]
|
||||
/attach ~/Documents/report.pdf [dim]# Attach a specific PDF[/dim]
|
||||
/detach [dim]# Remove attached PDF[/dim]
|
||||
/sessions [dim]# List all sessions[/dim]
|
||||
/sessions session_20260208_143022 [dim]# Show session details[/dim]
|
||||
/resume [dim]# Show numbered session list[/dim]
|
||||
/resume 1 [dim]# Resume first listed session[/dim]
|
||||
/resume session_20260208_143022 [dim]# Resume by full session ID[/dim]
|
||||
/recover session_20260208_143022 cp_xxx [dim]# Recover from specific checkpoint[/dim]
|
||||
/pause [dim]# Pause (or Ctrl+Z)[/dim]
|
||||
""")
|
||||
elif cmd == "/sessions":
|
||||
@@ -194,6 +256,16 @@ class ChatRepl(Vertical):
|
||||
session_id = parts[1].strip()
|
||||
checkpoint_id = parts[2].strip()
|
||||
await self._cmd_recover(session_id, checkpoint_id)
|
||||
elif cmd == "/attach":
|
||||
file_path = parts[1].strip() if len(parts) > 1 else None
|
||||
await self._cmd_attach(file_path)
|
||||
elif cmd == "/detach":
|
||||
if self._attached_pdf:
|
||||
name = self._attached_pdf["filename"]
|
||||
self._attached_pdf = None
|
||||
self._write_history(f"[dim]Detached: {name}[/dim]")
|
||||
else:
|
||||
self._write_history("[dim]No PDF attached.[/dim]")
|
||||
elif cmd == "/pause":
|
||||
await self._cmd_pause()
|
||||
else:
|
||||
@@ -202,6 +274,63 @@ class ChatRepl(Vertical):
|
||||
"Type [bold]/help[/bold] for available commands"
|
||||
)
|
||||
|
||||
def attach_pdf(self, path: Path) -> None:
|
||||
"""Validate and stage a PDF file for the next message.
|
||||
|
||||
Copies the PDF to ~/.hive/assets/ and stores the path. The agent's
|
||||
pdf_read tool handles text extraction at runtime.
|
||||
|
||||
Called by /attach <path> or by the native file dialog.
|
||||
"""
|
||||
path = Path(path).expanduser().resolve()
|
||||
|
||||
if not path.exists():
|
||||
self._write_history(f"[bold red]Error:[/bold red] File not found: {path}")
|
||||
return
|
||||
if path.suffix.lower() != ".pdf":
|
||||
self._write_history("[bold red]Error:[/bold red] Only PDF files are supported")
|
||||
return
|
||||
|
||||
# Copy to ~/.hive/assets/, deduplicating like a normal filesystem:
|
||||
# resume.pdf → resume(1).pdf → resume(2).pdf
|
||||
assets_dir = Path.home() / ".hive" / "assets"
|
||||
assets_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest = assets_dir / path.name
|
||||
counter = 1
|
||||
while dest.exists():
|
||||
dest = assets_dir / f"{path.stem}({counter}){path.suffix}"
|
||||
counter += 1
|
||||
shutil.copy2(path, dest)
|
||||
|
||||
self._attached_pdf = {
|
||||
"filename": path.name,
|
||||
"path": str(dest),
|
||||
}
|
||||
|
||||
self._write_history(f"[green]Attached:[/green] {path.name}")
|
||||
self._write_history("[dim]PDF will be read by the agent on your next message.[/dim]")
|
||||
|
||||
async def _cmd_attach(self, file_path: str | None = None) -> None:
|
||||
"""Attach a PDF file for context injection into the next message."""
|
||||
if file_path is None:
|
||||
from framework.tui.widgets.file_browser import _has_gui, pick_pdf_file
|
||||
|
||||
if not _has_gui():
|
||||
self._write_history(
|
||||
"[bold yellow]No GUI available.[/bold yellow] "
|
||||
"Provide a path: [bold]/attach /path/to/file.pdf[/bold]"
|
||||
)
|
||||
return
|
||||
|
||||
self._write_history("[dim]Opening file dialog...[/dim]")
|
||||
path = await pick_pdf_file()
|
||||
|
||||
if path is not None:
|
||||
self.attach_pdf(path)
|
||||
return
|
||||
|
||||
self.attach_pdf(Path(file_path))
|
||||
|
||||
async def _cmd_sessions(self, session_id: str | None) -> None:
|
||||
"""List sessions or show details of a specific session."""
|
||||
try:
|
||||
@@ -451,6 +580,7 @@ class ChatRepl(Vertical):
|
||||
if paused_at:
|
||||
# Has paused_at - resume from there
|
||||
resume_session_state = {
|
||||
"resume_session_id": session_id,
|
||||
"paused_at": paused_at,
|
||||
"memory": state.get("memory", {}),
|
||||
"execution_path": progress.get("path", []),
|
||||
@@ -458,8 +588,13 @@ class ChatRepl(Vertical):
|
||||
}
|
||||
resume_info = f"From node: [cyan]{paused_at}[/cyan]"
|
||||
else:
|
||||
# No paused_at - just retry with same input
|
||||
resume_session_state = {}
|
||||
# No paused_at - retry with same input but reuse session directory
|
||||
resume_session_state = {
|
||||
"resume_session_id": session_id,
|
||||
"memory": state.get("memory", {}),
|
||||
"execution_path": progress.get("path", []),
|
||||
"node_visit_counts": progress.get("node_visit_counts", {}),
|
||||
}
|
||||
resume_info = "Retrying with same input"
|
||||
|
||||
# Display resume info
|
||||
@@ -485,7 +620,7 @@ class ChatRepl(Vertical):
|
||||
indicator.display = True
|
||||
|
||||
# Update placeholder
|
||||
chat_input = self.query_one("#chat-input", Input)
|
||||
chat_input = self.query_one("#chat-input", ChatTextArea)
|
||||
chat_input.placeholder = "Commands: /pause, /sessions (agent resuming...)"
|
||||
|
||||
# Trigger execution with resume state
|
||||
@@ -563,6 +698,7 @@ class ChatRepl(Vertical):
|
||||
|
||||
# Create session_state for checkpoint recovery
|
||||
recover_session_state = {
|
||||
"resume_session_id": session_id,
|
||||
"resume_from_checkpoint": checkpoint_id,
|
||||
}
|
||||
|
||||
@@ -572,7 +708,7 @@ class ChatRepl(Vertical):
|
||||
indicator.display = True
|
||||
|
||||
# Update placeholder
|
||||
chat_input = self.query_one("#chat-input", Input)
|
||||
chat_input = self.query_one("#chat-input", ChatTextArea)
|
||||
chat_input.placeholder = "Commands: /pause, /sessions (agent recovering...)"
|
||||
|
||||
# Trigger execution with checkpoint recovery
|
||||
@@ -739,9 +875,12 @@ class ChatRepl(Vertical):
|
||||
# Silently fail - don't block TUI startup
|
||||
pass
|
||||
|
||||
async def on_input_submitted(self, message: Input.Submitted) -> None:
|
||||
"""Handle input submission — either start new execution or inject input."""
|
||||
user_input = message.value.strip()
|
||||
async def on_chat_text_area_submitted(self, message: ChatTextArea.Submitted) -> None:
|
||||
"""Handle chat input submission."""
|
||||
await self._submit_input(message.text)
|
||||
|
||||
async def _submit_input(self, user_input: str) -> None:
|
||||
"""Handle submitted text — either start new execution or inject input."""
|
||||
if not user_input:
|
||||
return
|
||||
|
||||
@@ -749,16 +888,14 @@ class ChatRepl(Vertical):
|
||||
# Commands work during execution, during client-facing input, anytime
|
||||
if user_input.startswith("/"):
|
||||
await self._handle_command(user_input)
|
||||
message.input.value = ""
|
||||
return
|
||||
|
||||
# Client-facing input: route to the waiting node
|
||||
if self._waiting_for_input and self._input_node_id:
|
||||
self._write_history(f"[bold green]You:[/bold green] {user_input}")
|
||||
message.input.value = ""
|
||||
|
||||
# Keep input enabled for commands (but change placeholder)
|
||||
chat_input = self.query_one("#chat-input", Input)
|
||||
chat_input = self.query_one("#chat-input", ChatTextArea)
|
||||
chat_input.placeholder = "Commands: /pause, /sessions (agent processing...)"
|
||||
self._waiting_for_input = False
|
||||
|
||||
@@ -778,16 +915,29 @@ class ChatRepl(Vertical):
|
||||
self._write_history(f"[bold red]Error delivering input:[/bold red] {e}")
|
||||
return
|
||||
|
||||
# Double-submit guard: reject input while an execution is in-flight
|
||||
# Mid-execution input: inject into the active node's conversation
|
||||
if self._current_exec_id is not None and self._active_node_id:
|
||||
self._write_history(f"[bold green]You:[/bold green] {user_input}")
|
||||
node_id = self._active_node_id
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self.runtime.inject_input(node_id, user_input),
|
||||
self._agent_loop,
|
||||
)
|
||||
await asyncio.wrap_future(future)
|
||||
except Exception as e:
|
||||
self._write_history(f"[bold red]Error delivering input:[/bold red] {e}")
|
||||
return
|
||||
|
||||
# Double-submit guard: no active node to inject into
|
||||
if self._current_exec_id is not None:
|
||||
self._write_history("[dim]Agent is still running — please wait.[/dim]")
|
||||
return
|
||||
|
||||
indicator = self.query_one("#processing-indicator", Label)
|
||||
|
||||
# Append user message and clear input
|
||||
# Append user message
|
||||
self._write_history(f"[bold green]You:[/bold green] {user_input}")
|
||||
message.input.value = ""
|
||||
|
||||
try:
|
||||
# Get entry point
|
||||
@@ -813,9 +963,16 @@ class ChatRepl(Vertical):
|
||||
indicator.display = True
|
||||
|
||||
# Keep input enabled for commands during execution
|
||||
chat_input = self.query_one("#chat-input", Input)
|
||||
chat_input = self.query_one("#chat-input", ChatTextArea)
|
||||
chat_input.placeholder = "Commands available: /pause, /sessions, /help"
|
||||
|
||||
# Build input data, injecting attached PDF file path if present
|
||||
input_data = {input_key: user_input}
|
||||
if self._attached_pdf:
|
||||
input_data["pdf_file_path"] = self._attached_pdf["path"]
|
||||
self._write_history(f"[dim]Including PDF: {self._attached_pdf['filename']}[/dim]")
|
||||
self._attached_pdf = None
|
||||
|
||||
# Submit execution to the dedicated agent loop so blocking
|
||||
# runtime code (LLM, MCP tools) never touches Textual's loop.
|
||||
# trigger() returns immediately with an exec_id; the heavy
|
||||
@@ -823,7 +980,7 @@ class ChatRepl(Vertical):
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self.runtime.trigger(
|
||||
entry_point_id=entry_point.id,
|
||||
input_data={input_key: user_input},
|
||||
input_data=input_data,
|
||||
),
|
||||
self._agent_loop,
|
||||
)
|
||||
@@ -834,12 +991,32 @@ class ChatRepl(Vertical):
|
||||
indicator.display = False
|
||||
self._current_exec_id = None
|
||||
# Re-enable input on error
|
||||
chat_input = self.query_one("#chat-input", Input)
|
||||
chat_input = self.query_one("#chat-input", ChatTextArea)
|
||||
chat_input.disabled = False
|
||||
self._write_history(f"[bold red]Error:[/bold red] {e}")
|
||||
|
||||
# -- Event handlers called by app.py _handle_event --
|
||||
|
||||
def handle_node_started(self, node_id: str) -> None:
|
||||
"""Reset streaming state and track active node when a new node begins.
|
||||
|
||||
Flushes any stale ``_streaming_snapshot`` left over from the
|
||||
previous node and resets the processing indicator so the user
|
||||
sees a clean transition between graph nodes.
|
||||
"""
|
||||
self._active_node_id = node_id
|
||||
if self._streaming_snapshot:
|
||||
self._write_history(f"[bold blue]Agent:[/bold blue] {self._streaming_snapshot}")
|
||||
self._streaming_snapshot = ""
|
||||
indicator = self.query_one("#processing-indicator", Label)
|
||||
indicator.update("Thinking...")
|
||||
|
||||
def handle_loop_iteration(self, iteration: int) -> None:
|
||||
"""Flush accumulated streaming text when a new loop iteration starts."""
|
||||
if self._streaming_snapshot:
|
||||
self._write_history(f"[bold blue]Agent:[/bold blue] {self._streaming_snapshot}")
|
||||
self._streaming_snapshot = ""
|
||||
|
||||
def handle_text_delta(self, content: str, snapshot: str) -> None:
|
||||
"""Handle a streaming text token from the LLM."""
|
||||
self._streaming_snapshot = snapshot
|
||||
@@ -867,8 +1044,11 @@ class ChatRepl(Vertical):
|
||||
# Update indicator to show tool activity
|
||||
indicator.update(f"Using tool: {tool_name}...")
|
||||
|
||||
# Write a discrete status line to history
|
||||
self._write_history(f"[dim]Tool: {tool_name}[/dim]")
|
||||
# Buffer and conditionally display tool status line
|
||||
line = f"[dim]Tool: {tool_name}[/dim]"
|
||||
self._log_buffer.append(line)
|
||||
if self._show_logs:
|
||||
self._write_history(line)
|
||||
|
||||
def handle_tool_completed(self, tool_name: str, result: str, is_error: bool) -> None:
|
||||
"""Handle a tool call completing."""
|
||||
@@ -882,9 +1062,12 @@ class ChatRepl(Vertical):
|
||||
preview = preview.replace("\n", " ")
|
||||
|
||||
if is_error:
|
||||
self._write_history(f"[dim red]Tool {tool_name} error: {preview}[/dim red]")
|
||||
line = f"[dim red]Tool {tool_name} error: {preview}[/dim red]"
|
||||
else:
|
||||
self._write_history(f"[dim]Tool {tool_name} result: {preview}[/dim]")
|
||||
line = f"[dim]Tool {tool_name} result: {preview}[/dim]"
|
||||
self._log_buffer.append(line)
|
||||
if self._show_logs:
|
||||
self._write_history(line)
|
||||
|
||||
# Restore thinking indicator
|
||||
indicator = self.query_one("#processing-indicator", Label)
|
||||
@@ -907,10 +1090,12 @@ class ChatRepl(Vertical):
|
||||
self._streaming_snapshot = ""
|
||||
self._waiting_for_input = False
|
||||
self._input_node_id = None
|
||||
self._active_node_id = None
|
||||
self._pending_ask_question = ""
|
||||
self._log_buffer.clear()
|
||||
|
||||
# Re-enable input
|
||||
chat_input = self.query_one("#chat-input", Input)
|
||||
chat_input = self.query_one("#chat-input", ChatTextArea)
|
||||
chat_input.disabled = False
|
||||
chat_input.placeholder = "Enter input for agent..."
|
||||
chat_input.focus()
|
||||
@@ -928,9 +1113,11 @@ class ChatRepl(Vertical):
|
||||
self._waiting_for_input = False
|
||||
self._pending_ask_question = ""
|
||||
self._input_node_id = None
|
||||
self._active_node_id = None
|
||||
self._log_buffer.clear()
|
||||
|
||||
# Re-enable input
|
||||
chat_input = self.query_one("#chat-input", Input)
|
||||
chat_input = self.query_one("#chat-input", ChatTextArea)
|
||||
chat_input.disabled = False
|
||||
chat_input.placeholder = "Enter input for agent..."
|
||||
chat_input.focus()
|
||||
@@ -961,7 +1148,36 @@ class ChatRepl(Vertical):
|
||||
indicator = self.query_one("#processing-indicator", Label)
|
||||
indicator.update("Waiting for your input...")
|
||||
|
||||
chat_input = self.query_one("#chat-input", Input)
|
||||
chat_input = self.query_one("#chat-input", ChatTextArea)
|
||||
chat_input.disabled = False
|
||||
chat_input.placeholder = "Type your response..."
|
||||
chat_input.focus()
|
||||
|
||||
def handle_node_completed(self, node_id: str) -> None:
|
||||
"""Clear active node when it finishes."""
|
||||
if self._active_node_id == node_id:
|
||||
self._active_node_id = None
|
||||
|
||||
def handle_internal_output(self, node_id: str, content: str) -> None:
|
||||
"""Show output from non-client-facing nodes."""
|
||||
self._write_history(f"[dim cyan]⟨{node_id}⟩[/dim cyan] {content}")
|
||||
|
||||
def handle_execution_paused(self, node_id: str, reason: str) -> None:
|
||||
"""Show that execution has been paused."""
|
||||
msg = f"[bold yellow]⏸ Paused[/bold yellow] at [cyan]{node_id}[/cyan]"
|
||||
if reason:
|
||||
msg += f" [dim]({reason})[/dim]"
|
||||
self._write_history(msg)
|
||||
|
||||
def handle_execution_resumed(self, node_id: str) -> None:
|
||||
"""Show that execution has been resumed."""
|
||||
self._write_history(f"[bold green]▶ Resumed[/bold green] from [cyan]{node_id}[/cyan]")
|
||||
|
||||
def handle_goal_achieved(self, data: dict[str, Any]) -> None:
|
||||
"""Show goal achievement prominently."""
|
||||
self._write_history("[bold green]★ Goal achieved![/bold green]")
|
||||
|
||||
def handle_constraint_violation(self, data: dict[str, Any]) -> None:
|
||||
"""Show constraint violation as a warning."""
|
||||
desc = data.get("description", "Unknown constraint")
|
||||
self._write_history(f"[bold red]⚠ Constraint violation:[/bold red] {desc}")
|
||||
|
||||
@@ -0,0 +1,135 @@
"""
Native OS file dialog for PDF selection.

Launches the platform's native file picker (macOS: NSOpenPanel via osascript,
Linux: zenity/kdialog, Windows: PowerShell OpenFileDialog) in a background
thread so Textual's event loop stays responsive.

Falls back to None when no GUI is available (SSH, headless).
"""

import asyncio
import os
import subprocess
import sys
from pathlib import Path


def _has_gui() -> bool:
    """Detect whether a GUI display is available."""
    if sys.platform == "darwin":
        # macOS: GUI is available unless running over SSH without display forwarding.
        return "SSH_CONNECTION" not in os.environ or "DISPLAY" in os.environ
    elif sys.platform == "win32":
        return True
    else:
        # Linux/BSD: Need X11 or Wayland.
        return bool(os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"))


def _linux_file_dialog() -> subprocess.CompletedProcess | None:
    """Try zenity, then kdialog, on Linux. Returns CompletedProcess or None."""
    # Try zenity (GTK)
    try:
        return subprocess.run(
            [
                "zenity",
                "--file-selection",
                "--title=Select a PDF file",
                "--file-filter=PDF files (*.pdf)|*.pdf",
            ],
            capture_output=True,
            text=True,
            timeout=300,
        )
    except FileNotFoundError:
        pass

    # Try kdialog (KDE)
    try:
        return subprocess.run(
            [
                "kdialog",
                "--getopenfilename",
                ".",
                "PDF files (*.pdf)",
            ],
            capture_output=True,
            text=True,
            timeout=300,
        )
    except FileNotFoundError:
        pass

    return None


def _pick_pdf_subprocess() -> Path | None:
    """Run the native file dialog. BLOCKS until user picks or cancels.

    Returns a Path on success, None on cancel or error.
    Must be called from a non-main thread (via asyncio.to_thread).
    """
    try:
        if sys.platform == "darwin":
            result = subprocess.run(
                [
                    "osascript",
                    "-e",
                    'POSIX path of (choose file of type {"com.adobe.pdf"} '
                    'with prompt "Select a PDF file")',
                ],
                capture_output=True,
                text=True,
                timeout=300,
            )
        elif sys.platform == "win32":
            ps_script = (
                "Add-Type -AssemblyName System.Windows.Forms; "
                "$f = New-Object System.Windows.Forms.OpenFileDialog; "
                "$f.Filter = 'PDF files (*.pdf)|*.pdf'; "
                "$f.Title = 'Select a PDF file'; "
                "if ($f.ShowDialog() -eq 'OK') { $f.FileName }"
            )
            result = subprocess.run(
                ["powershell", "-NoProfile", "-Command", ps_script],
                capture_output=True,
                text=True,
                timeout=300,
            )
        else:
            result = _linux_file_dialog()
            if result is None:
                return None

        if result.returncode != 0:
            return None

        path_str = result.stdout.strip()
        if not path_str:
            return None

        path = Path(path_str)
        if path.is_file() and path.suffix.lower() == ".pdf":
            return path

        return None

    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        return None


async def pick_pdf_file() -> Path | None:
    """Open a native OS file dialog to pick a PDF file.

    Non-blocking: runs the dialog subprocess in a background thread via
    asyncio.to_thread(), so the calling event loop stays responsive.

    Returns:
        Path to the selected PDF, or None if the user cancelled,
        no GUI is available, or the dialog command was not found.
    """
    if not _has_gui():
        return None

    return await asyncio.to_thread(_pick_pdf_subprocess)
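A caller sketch (illustrative only; in the actual diff the TUI wires this into the Ctrl+P action and the /attach command). It shows the intended pattern: check _has_gui() first, then await pick_pdf_file() so the dialog subprocess stays off the event loop:

import asyncio

from framework.tui.widgets.file_browser import _has_gui, pick_pdf_file


async def choose_pdf() -> None:
    if not _has_gui():
        print("No GUI available; pass a path explicitly, e.g. /attach ~/doc.pdf")
        return

    path = await pick_pdf_file()  # native dialog runs in a background thread
    if path is None:
        print("Selection cancelled or no PDF chosen")
    else:
        print(f"Selected: {path}")


if __name__ == "__main__":
    asyncio.run(choose_pdf())
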
@@ -1,7 +1,15 @@
|
||||
"""
|
||||
Graph/Tree Overview Widget - Displays real agent graph structure.
|
||||
|
||||
Supports rendering loops (back-edges) via right-side return channels:
|
||||
arrows drawn on the right margin that visually point back up to earlier nodes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import time
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Vertical
|
||||
|
||||
@@ -9,6 +17,17 @@ from framework.runtime.agent_runtime import AgentRuntime
|
||||
from framework.runtime.event_bus import EventType
|
||||
from framework.tui.widgets.selectable_rich_log import SelectableRichLog as RichLog
|
||||
|
||||
# Width of each return-channel column (padding + │ + gap)
|
||||
_CHANNEL_WIDTH = 5
|
||||
|
||||
# Regex to strip Rich markup tags for measuring visible width
|
||||
_MARKUP_RE = re.compile(r"\[/?[^\]]*\]")
|
||||
|
||||
|
||||
def _plain_len(s: str) -> int:
|
||||
"""Return the visible character length of a Rich-markup string."""
|
||||
return len(_MARKUP_RE.sub("", s))
|
||||
|
||||
|
||||
class GraphOverview(Vertical):
|
||||
"""Widget to display Agent execution graph/tree with real data."""
|
||||
@@ -46,6 +65,13 @@ class GraphOverview(Vertical):
|
||||
def on_mount(self) -> None:
|
||||
"""Display initial graph structure."""
|
||||
self._display_graph()
|
||||
# Refresh every 1s so timer countdowns stay current
|
||||
if self.runtime._timer_next_fire is not None:
|
||||
self.set_interval(1.0, self._display_graph)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Graph analysis helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _topo_order(self) -> list[str]:
|
||||
"""BFS from entry_node following edges."""
|
||||
@@ -68,6 +94,39 @@ class GraphOverview(Vertical):
|
||||
visited.append(node.id)
|
||||
return visited
|
||||
|
||||
def _detect_back_edges(self, ordered: list[str]) -> list[dict]:
|
||||
"""Find edges where target appears before (or equal to) source in topo order.
|
||||
|
||||
Returns a list of dicts with keys: edge, source, target, source_idx, target_idx.
|
||||
"""
|
||||
order_idx = {nid: i for i, nid in enumerate(ordered)}
|
||||
back_edges: list[dict] = []
|
||||
for node_id in ordered:
|
||||
for edge in self.runtime.graph.get_outgoing_edges(node_id):
|
||||
target_idx = order_idx.get(edge.target, -1)
|
||||
source_idx = order_idx.get(node_id, -1)
|
||||
if target_idx != -1 and target_idx <= source_idx:
|
||||
back_edges.append(
|
||||
{
|
||||
"edge": edge,
|
||||
"source": node_id,
|
||||
"target": edge.target,
|
||||
"source_idx": source_idx,
|
||||
"target_idx": target_idx,
|
||||
}
|
||||
)
|
||||
return back_edges
|
||||
|
||||
def _is_back_edge(self, source: str, target: str, order_idx: dict[str, int]) -> bool:
|
||||
"""Check whether an edge from *source* to *target* is a back-edge."""
|
||||
si = order_idx.get(source, -1)
|
||||
ti = order_idx.get(target, -1)
|
||||
return ti != -1 and ti <= si
|
||||
|
||||
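In the BFS/topological order a back-edge is any edge whose target does not come strictly after its source — that is exactly what makes a loop. A tiny standalone illustration of the same index comparison used by `_detect_back_edges` and `_is_back_edge` (plain tuples stand in for the runtime's edge objects):

```python
# Topological order produced by the BFS walk: research -> draft -> review
ordered = ["research", "draft", "review"]
order_idx = {nid: i for i, nid in enumerate(ordered)}

# (source, target) pairs; "review" -> "research" closes the loop.
edges = [("research", "draft"), ("draft", "review"), ("review", "research")]


def is_back_edge(source: str, target: str) -> bool:
    si, ti = order_idx.get(source, -1), order_idx.get(target, -1)
    return ti != -1 and ti <= si


back_edges = [(s, t) for s, t in edges if is_back_edge(s, t)]
assert back_edges == [("review", "research")]
```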
# ------------------------------------------------------------------
|
||||
# Line rendering (Pass 1)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _render_node_line(self, node_id: str) -> str:
|
||||
"""Render a single node with status symbol and optional status text."""
|
||||
graph = self.runtime.graph
|
||||
@@ -95,43 +154,349 @@ class GraphOverview(Vertical):
|
||||
suffix = f" [italic]{status}[/italic]" if status else ""
|
||||
return f" {sym} {name}{suffix}"
|
||||
|
||||
def _render_edges(self, node_id: str) -> list[str]:
|
||||
"""Render edge connectors from this node to its targets."""
|
||||
edges = self.runtime.graph.get_outgoing_edges(node_id)
|
||||
if not edges:
|
||||
def _render_edges(self, node_id: str, order_idx: dict[str, int]) -> list[str]:
|
||||
"""Render forward-edge connectors from *node_id*.
|
||||
|
||||
Back-edges are excluded here — they are drawn by the return-channel
|
||||
overlay in Pass 2.
|
||||
"""
|
||||
all_edges = self.runtime.graph.get_outgoing_edges(node_id)
|
||||
if not all_edges:
|
||||
return []
|
||||
if len(edges) == 1:
|
||||
|
||||
# Split into forward and back
|
||||
forward = [e for e in all_edges if not self._is_back_edge(node_id, e.target, order_idx)]
|
||||
|
||||
if not forward:
|
||||
# All edges are back-edges — nothing to render here
|
||||
return []
|
||||
|
||||
if len(forward) == 1:
|
||||
return [" │", " ▼"]
|
||||
|
||||
# Fan-out: show branches
|
||||
lines: list[str] = []
|
||||
for i, edge in enumerate(edges):
|
||||
connector = "└" if i == len(edges) - 1 else "├"
|
||||
for i, edge in enumerate(forward):
|
||||
connector = "└" if i == len(forward) - 1 else "├"
|
||||
cond = ""
|
||||
if edge.condition.value not in ("always", "on_success"):
|
||||
cond = f" [dim]({edge.condition.value})[/dim]"
|
||||
lines.append(f" {connector}──▶ {edge.target}{cond}")
|
||||
return lines
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Return-channel overlay (Pass 2)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _overlay_return_channels(
|
||||
self,
|
||||
lines: list[str],
|
||||
node_line_map: dict[str, int],
|
||||
back_edges: list[dict],
|
||||
available_width: int,
|
||||
) -> list[str]:
|
||||
"""Overlay right-side return channels onto the line buffer.
|
||||
|
||||
Each back-edge gets a vertical channel on the right margin. Channels
|
||||
are allocated left-to-right by increasing span length so that shorter
|
||||
(inner) loops are closer to the graph body and longer (outer) loops are
|
||||
further right.
|
||||
|
||||
If the terminal is too narrow to fit even one channel, we fall back to
|
||||
simple inline ``↺`` annotations instead.
|
||||
"""
|
||||
if not back_edges:
|
||||
return lines
|
||||
|
||||
num_channels = len(back_edges)
|
||||
|
||||
# Sort by span length ascending → inner loops get nearest channel
|
||||
sorted_be = sorted(back_edges, key=lambda b: b["source_idx"] - b["target_idx"])
|
||||
|
||||
# --- Insert dedicated connector lines for back-edge sources ---
|
||||
# Each back-edge source gets a blank line inserted after its node
|
||||
# section (after any forward-edge lines). We process insertions in
|
||||
# reverse order so that earlier indices remain valid.
|
||||
all_node_lines_set = set(node_line_map.values())
|
||||
|
||||
insertions: list[tuple[int, int]] = [] # (insert_after_line, be_index)
|
||||
for be_idx, be in enumerate(sorted_be):
|
||||
source_node_line = node_line_map.get(be["source"])
|
||||
if source_node_line is None:
|
||||
continue
|
||||
# Walk forward to find the last line in this node's section
|
||||
last_section_line = source_node_line
|
||||
for li in range(source_node_line + 1, len(lines)):
|
||||
if li in all_node_lines_set:
|
||||
break
|
||||
last_section_line = li
|
||||
insertions.append((last_section_line, be_idx))
|
||||
|
||||
source_line_for_be: dict[int, int] = {}
|
||||
for insert_after, be_idx in sorted(insertions, reverse=True):
|
||||
insert_at = insert_after + 1
|
||||
lines.insert(insert_at, "") # placeholder for connector
|
||||
source_line_for_be[be_idx] = insert_at
|
||||
# Shift node_line_map entries that come after the insertion point
|
||||
for nid in node_line_map:
|
||||
if node_line_map[nid] > insert_after:
|
||||
node_line_map[nid] += 1
|
||||
# Also shift already-assigned source lines
|
||||
for prev_idx in source_line_for_be:
|
||||
if prev_idx != be_idx and source_line_for_be[prev_idx] > insert_after:
|
||||
source_line_for_be[prev_idx] += 1
|
||||
|
||||
# Recompute max content width after insertions
|
||||
max_content_w = max(_plain_len(ln) for ln in lines) if lines else 0
|
||||
|
||||
# Check if we have room for channels
|
||||
channels_total_w = num_channels * _CHANNEL_WIDTH
|
||||
if max_content_w + channels_total_w + 2 > available_width:
|
||||
return self._inline_back_edge_fallback(lines, node_line_map, back_edges)
|
||||
|
||||
content_pad = max_content_w + 3 # gap between content and first channel
|
||||
|
||||
# Build channel info with final line positions
|
||||
channel_info: list[dict] = []
|
||||
for ch_idx, be in enumerate(sorted_be):
|
||||
target_line = node_line_map.get(be["target"])
|
||||
source_line = source_line_for_be.get(ch_idx)
|
||||
if target_line is None or source_line is None:
|
||||
continue
|
||||
col = content_pad + ch_idx * _CHANNEL_WIDTH
|
||||
channel_info.append(
|
||||
{
|
||||
"target_line": target_line,
|
||||
"source_line": source_line,
|
||||
"col": col,
|
||||
}
|
||||
)
|
||||
|
||||
if not channel_info:
|
||||
return lines
|
||||
|
||||
# Build overlay grid — one row per line, columns for channel area
|
||||
total_width = content_pad + num_channels * _CHANNEL_WIDTH + 1
|
||||
overlay_width = total_width - max_content_w
|
||||
overlays: list[list[str]] = [[" "] * overlay_width for _ in range(len(lines))]
|
||||
|
||||
for ci in channel_info:
|
||||
tl = ci["target_line"]
|
||||
sl = ci["source_line"]
|
||||
col_offset = ci["col"] - max_content_w
|
||||
|
||||
if col_offset < 0 or col_offset >= overlay_width:
|
||||
continue
|
||||
|
||||
# Target line: ◄──...──┐
|
||||
if 0 <= tl < len(overlays):
|
||||
for c in range(col_offset):
|
||||
if overlays[tl][c] == " ":
|
||||
overlays[tl][c] = "─"
|
||||
overlays[tl][col_offset] = "┐"
|
||||
|
||||
# Source line: ──...──┘
|
||||
if 0 <= sl < len(overlays):
|
||||
for c in range(col_offset):
|
||||
if overlays[sl][c] == " ":
|
||||
overlays[sl][c] = "─"
|
||||
overlays[sl][col_offset] = "┘"
|
||||
|
||||
# Vertical lines between target+1 and source-1
|
||||
for li in range(tl + 1, sl):
|
||||
if 0 <= li < len(overlays) and overlays[li][col_offset] == " ":
|
||||
overlays[li][col_offset] = "│"
|
||||
|
||||
# Merge overlays into the line strings
|
||||
result: list[str] = []
|
||||
for i, line in enumerate(lines):
|
||||
pw = _plain_len(line)
|
||||
pad = max_content_w - pw
|
||||
overlay_chars = overlays[i] if i < len(overlays) else []
|
||||
overlay_str = "".join(overlay_chars)
|
||||
overlay_trimmed = overlay_str.rstrip()
|
||||
if overlay_trimmed:
|
||||
is_target_line = any(ci["target_line"] == i for ci in channel_info)
|
||||
if is_target_line:
|
||||
overlay_trimmed = "◄" + overlay_trimmed[1:]
|
||||
|
||||
is_source_line = any(ci["source_line"] == i for ci in channel_info)
|
||||
if is_source_line and not line.strip():
|
||||
# Inserted blank line → build └───┘ connector.
|
||||
# " └" = 3 chars of content prefix, so remaining pad = max_content_w - 3
|
||||
remaining_pad = max_content_w - 3
|
||||
full = list(" " * remaining_pad + overlay_trimmed)
|
||||
# Find the ┘ corner for this source connector
|
||||
corner_pos = -1
|
||||
for ci_s in channel_info:
|
||||
if ci_s["source_line"] == i:
|
||||
corner_pos = remaining_pad + (ci_s["col"] - max_content_w)
|
||||
break
|
||||
# Fill everything up to the corner with ─
|
||||
if corner_pos >= 0:
|
||||
for c in range(corner_pos):
|
||||
if full[c] not in ("│", "┘", "┐"):
|
||||
full[c] = "─"
|
||||
connector = " └" + "".join(full).rstrip()
|
||||
result.append(f"[dim]{connector}[/dim]")
|
||||
continue
|
||||
|
||||
colored_overlay = f"[dim]{' ' * pad}{overlay_trimmed}[/dim]"
|
||||
result.append(f"{line}{colored_overlay}")
|
||||
else:
|
||||
result.append(line)
|
||||
|
||||
return result
|
||||
|
||||
def _inline_back_edge_fallback(
|
||||
self,
|
||||
lines: list[str],
|
||||
node_line_map: dict[str, int],
|
||||
back_edges: list[dict],
|
||||
) -> list[str]:
|
||||
"""Fallback: add inline ↺ annotations when terminal is too narrow for channels."""
|
||||
# Group back-edges by source node
|
||||
source_to_be: dict[str, list[dict]] = {}
|
||||
for be in back_edges:
|
||||
source_to_be.setdefault(be["source"], []).append(be)
|
||||
|
||||
result = list(lines)
|
||||
# Insert annotation lines after each source node's section
|
||||
offset = 0
|
||||
all_node_lines = sorted(node_line_map.values())
|
||||
for source, bes in source_to_be.items():
|
||||
source_line = node_line_map.get(source)
|
||||
if source_line is None:
|
||||
continue
|
||||
# Find end of source node section
|
||||
end_line = source_line
|
||||
for nl in all_node_lines:
|
||||
if nl > source_line:
|
||||
end_line = nl - 1
|
||||
break
|
||||
else:
|
||||
end_line = len(lines) - 1
|
||||
# Insert after last content line of this node's section
|
||||
insert_at = end_line + offset + 1
|
||||
for be in bes:
|
||||
cond = ""
|
||||
edge = be["edge"]
|
||||
if edge.condition.value not in ("always", "on_success"):
|
||||
cond = f" [dim]({edge.condition.value})[/dim]"
|
||||
annotation = f" [yellow]↺[/yellow] {be['target']}{cond}"
|
||||
result.insert(insert_at, annotation)
|
||||
insert_at += 1
|
||||
offset += 1
|
||||
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Main display
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _display_graph(self) -> None:
|
||||
"""Display the graph as an ASCII DAG with edge connectors."""
|
||||
"""Display the graph as an ASCII DAG with edge connectors and loop channels."""
|
||||
display = self.query_one("#graph-display", RichLog)
|
||||
display.clear()
|
||||
|
||||
graph = self.runtime.graph
|
||||
display.write(f"[bold cyan]Agent Graph:[/bold cyan] {graph.id}\n")
|
||||
|
||||
# Render each node in topological order with edges
|
||||
ordered = self._topo_order()
|
||||
order_idx = {nid: i for i, nid in enumerate(ordered)}
|
||||
|
||||
# --- Pass 1: Build line buffer ---
|
||||
lines: list[str] = []
|
||||
node_line_map: dict[str, int] = {}
|
||||
|
||||
for node_id in ordered:
|
||||
display.write(self._render_node_line(node_id))
|
||||
for edge_line in self._render_edges(node_id):
|
||||
display.write(edge_line)
|
||||
node_line_map[node_id] = len(lines)
|
||||
lines.append(self._render_node_line(node_id))
|
||||
for edge_line in self._render_edges(node_id, order_idx):
|
||||
lines.append(edge_line)
|
||||
|
||||
# --- Pass 2: Overlay return channels for back-edges ---
|
||||
back_edges = self._detect_back_edges(ordered)
|
||||
if back_edges:
|
||||
# Try to get actual widget width; default to a reasonable value
|
||||
try:
|
||||
available_width = self.size.width or 60
|
||||
except Exception:
|
||||
available_width = 60
|
||||
lines = self._overlay_return_channels(lines, node_line_map, back_edges, available_width)
|
||||
|
||||
# Write all lines
|
||||
for line in lines:
|
||||
display.write(line)
|
||||
|
||||
# Execution path footer
|
||||
if self.execution_path:
|
||||
display.write("")
|
||||
display.write(f"[dim]Path:[/dim] {' → '.join(self.execution_path[-5:])}")
|
||||
|
||||
# Event sources section
|
||||
self._render_event_sources(display)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event sources display
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _render_event_sources(self, display: RichLog) -> None:
|
||||
"""Render event source info (webhooks, timers) below the graph."""
|
||||
entry_points = self.runtime.get_entry_points()
|
||||
|
||||
# Filter to non-manual entry points (webhooks, timers, events)
|
||||
event_sources = [ep for ep in entry_points if ep.trigger_type not in ("manual",)]
|
||||
if not event_sources:
|
||||
return
|
||||
|
||||
display.write("")
|
||||
display.write("[bold cyan]Event Sources[/bold cyan]")
|
||||
|
||||
config = self.runtime._config
|
||||
|
||||
for ep in event_sources:
|
||||
if ep.trigger_type == "timer":
|
||||
interval = ep.trigger_config.get("interval_minutes", "?")
|
||||
display.write(f" [green]⏱[/green] {ep.name} [dim]→ {ep.entry_node}[/dim]")
|
||||
# Show interval + next fire countdown
|
||||
next_fire = self.runtime._timer_next_fire.get(ep.id)
|
||||
if next_fire is not None:
|
||||
remaining = max(0, next_fire - time.monotonic())
|
||||
mins, secs = divmod(int(remaining), 60)
|
||||
display.write(
|
||||
f" [dim]every {interval} min — next in {mins}m {secs:02d}s[/dim]"
|
||||
)
|
||||
else:
|
||||
display.write(f" [dim]every {interval} min[/dim]")
|
||||
|
||||
elif ep.trigger_type in ("event", "webhook"):
|
||||
display.write(f" [yellow]⚡[/yellow] {ep.name} [dim]→ {ep.entry_node}[/dim]")
|
||||
# Show webhook endpoint if configured
|
||||
route = None
|
||||
for r in config.webhook_routes:
|
||||
src = r.get("source_id", "")
|
||||
if src and src in ep.id:
|
||||
route = r
|
||||
break
|
||||
if not route and config.webhook_routes:
|
||||
# Fall back to first route
|
||||
route = config.webhook_routes[0]
|
||||
|
||||
if route:
|
||||
host = config.webhook_host
|
||||
port = config.webhook_port
|
||||
path = route.get("path", "/webhook")
|
||||
display.write(f" [dim]{host}:{port}{path}[/dim]")
|
||||
else:
|
||||
event_types = ep.trigger_config.get("event_types", [])
|
||||
if event_types:
|
||||
display.write(f" [dim]events: {', '.join(event_types)}[/dim]")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API (called by app.py)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def update_active_node(self, node_id: str) -> None:
|
||||
"""Update the currently active node."""
|
||||
self.active_node = node_id
|
||||
@@ -177,6 +542,8 @@ class GraphOverview(Vertical):
|
||||
def handle_node_loop_completed(self, node_id: str) -> None:
|
||||
"""A node's event loop completed."""
|
||||
self._node_status.pop(node_id, None)
|
||||
if self.active_node == node_id:
|
||||
self.active_node = None
|
||||
self._display_graph()
|
||||
|
||||
def handle_tool_call(self, node_id: str, tool_name: str, *, started: bool) -> None:
|
||||
@@ -192,3 +559,8 @@ class GraphOverview(Vertical):
|
||||
"""Highlight a stalled node."""
|
||||
self._node_status[node_id] = f"[red]stalled: {reason}[/red]"
|
||||
self._display_graph()
|
||||
|
||||
def handle_edge_traversed(self, source_node: str, target_node: str) -> None:
|
||||
"""Highlight an edge being traversed."""
|
||||
self._node_status[source_node] = f"[dim]→ {target_node}[/dim]"
|
||||
self._display_graph()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""
|
||||
Log Pane Widget - Uses RichLog for reliable rendering.
|
||||
Log formatting utilities and LogPane widget.
|
||||
|
||||
The module-level functions (format_event, extract_event_text, format_python_log)
|
||||
can be used by any widget that needs to render log lines without instantiating LogPane.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -11,36 +14,108 @@ from textual.containers import Container
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
from framework.tui.widgets.selectable_rich_log import SelectableRichLog as RichLog
|
||||
|
||||
# --- Module-level formatting constants ---
|
||||
|
||||
EVENT_FORMAT: dict[EventType, tuple[str, str]] = {
|
||||
EventType.EXECUTION_STARTED: (">>", "bold cyan"),
|
||||
EventType.EXECUTION_COMPLETED: ("<<", "bold green"),
|
||||
EventType.EXECUTION_FAILED: ("!!", "bold red"),
|
||||
EventType.TOOL_CALL_STARTED: ("->", "yellow"),
|
||||
EventType.TOOL_CALL_COMPLETED: ("<-", "green"),
|
||||
EventType.NODE_LOOP_STARTED: ("@@", "cyan"),
|
||||
EventType.NODE_LOOP_ITERATION: ("..", "dim"),
|
||||
EventType.NODE_LOOP_COMPLETED: ("@@", "dim"),
|
||||
EventType.NODE_STALLED: ("!!", "bold yellow"),
|
||||
EventType.NODE_INPUT_BLOCKED: ("!!", "yellow"),
|
||||
EventType.GOAL_PROGRESS: ("%%", "blue"),
|
||||
EventType.GOAL_ACHIEVED: ("**", "bold green"),
|
||||
EventType.CONSTRAINT_VIOLATION: ("!!", "bold red"),
|
||||
EventType.STATE_CHANGED: ("~~", "dim"),
|
||||
EventType.CLIENT_INPUT_REQUESTED: ("??", "magenta"),
|
||||
}
|
||||
|
||||
LOG_LEVEL_COLORS: dict[int, str] = {
|
||||
logging.DEBUG: "dim",
|
||||
logging.INFO: "",
|
||||
logging.WARNING: "yellow",
|
||||
logging.ERROR: "red",
|
||||
logging.CRITICAL: "bold red",
|
||||
}
|
||||
|
||||
|
||||
# --- Module-level formatting functions ---
|
||||
|
||||
|
||||
def extract_event_text(event: AgentEvent) -> str:
|
||||
"""Extract human-readable text from an event's data dict."""
|
||||
et = event.type
|
||||
data = event.data
|
||||
|
||||
if et == EventType.EXECUTION_STARTED:
|
||||
return "Execution started"
|
||||
elif et == EventType.EXECUTION_COMPLETED:
|
||||
return "Execution completed"
|
||||
elif et == EventType.EXECUTION_FAILED:
|
||||
return f"Execution FAILED: {data.get('error', 'unknown')}"
|
||||
elif et == EventType.TOOL_CALL_STARTED:
|
||||
return f"Tool call: {data.get('tool_name', 'unknown')}"
|
||||
elif et == EventType.TOOL_CALL_COMPLETED:
|
||||
name = data.get("tool_name", "unknown")
|
||||
if data.get("is_error"):
|
||||
preview = str(data.get("result", ""))[:80]
|
||||
return f"Tool error: {name} - {preview}"
|
||||
return f"Tool done: {name}"
|
||||
elif et == EventType.NODE_LOOP_STARTED:
|
||||
return f"Node started: {event.node_id or 'unknown'}"
|
||||
elif et == EventType.NODE_LOOP_ITERATION:
|
||||
return f"{event.node_id or 'unknown'} iteration {data.get('iteration', '?')}"
|
||||
elif et == EventType.NODE_LOOP_COMPLETED:
|
||||
return f"Node done: {event.node_id or 'unknown'}"
|
||||
elif et == EventType.NODE_STALLED:
|
||||
reason = data.get("reason", "")
|
||||
node = event.node_id or "unknown"
|
||||
return f"Node stalled: {node} - {reason}" if reason else f"Node stalled: {node}"
|
||||
elif et == EventType.NODE_INPUT_BLOCKED:
|
||||
return f"Node input blocked: {event.node_id or 'unknown'}"
|
||||
elif et == EventType.GOAL_PROGRESS:
|
||||
return f"Goal progress: {data.get('progress', '?')}"
|
||||
elif et == EventType.GOAL_ACHIEVED:
|
||||
return "Goal achieved"
|
||||
elif et == EventType.CONSTRAINT_VIOLATION:
|
||||
return f"Constraint violated: {data.get('description', 'unknown')}"
|
||||
elif et == EventType.STATE_CHANGED:
|
||||
return f"State changed: {data.get('key', 'unknown')}"
|
||||
elif et == EventType.CLIENT_INPUT_REQUESTED:
|
||||
return "Waiting for user input"
|
||||
else:
|
||||
return f"{et.value}: {data}"
|
||||
|
||||
|
||||
def format_event(event: AgentEvent) -> str:
|
||||
"""Format an AgentEvent as a Rich markup string with timestamp + symbol."""
|
||||
ts = event.timestamp.strftime("%H:%M:%S")
|
||||
symbol, color = EVENT_FORMAT.get(event.type, ("--", "dim"))
|
||||
text = extract_event_text(event)
|
||||
return f"[dim]{ts}[/dim] [{color}]{symbol} {text}[/{color}]"
|
||||
|
||||
|
||||
def format_python_log(record: logging.LogRecord) -> str:
|
||||
"""Format a Python log record as a Rich markup string with timestamp and severity color."""
|
||||
ts = datetime.fromtimestamp(record.created).strftime("%H:%M:%S")
|
||||
color = LOG_LEVEL_COLORS.get(record.levelno, "")
|
||||
msg = record.getMessage()
|
||||
if color:
|
||||
return f"[dim]{ts}[/dim] [{color}]{record.levelname}[/{color}] {msg}"
|
||||
else:
|
||||
return f"[dim]{ts}[/dim] {record.levelname} {msg}"
|
||||
|
||||
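The module-level formatters can now be reused by any widget without constructing a `LogPane`. A small sketch of feeding a standard `logging.LogRecord` through `format_python_log` — the import path `framework.tui.widgets.log_pane` is an assumption, as the diff does not show this file's name:

```python
import logging

# Assumed module path for the formatters shown above.
from framework.tui.widgets.log_pane import format_python_log

record = logging.LogRecord(
    name="framework.runtime",
    level=logging.WARNING,
    pathname=__file__,
    lineno=1,
    msg="webhook route %s not configured",
    args=("/hooks/issues",),
    exc_info=None,
)

# Yields something like:
# [dim]12:34:56[/dim] [yellow]WARNING[/yellow] webhook route /hooks/issues not configured
print(format_python_log(record))
```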
|
||||
# --- LogPane widget (kept for backward compatibility) ---
|
||||
|
||||
|
||||
class LogPane(Container):
|
||||
"""Widget to display logs with reliable rendering."""
|
||||
|
||||
_EVENT_FORMAT: dict[EventType, tuple[str, str]] = {
|
||||
EventType.EXECUTION_STARTED: (">>", "bold cyan"),
|
||||
EventType.EXECUTION_COMPLETED: ("<<", "bold green"),
|
||||
EventType.EXECUTION_FAILED: ("!!", "bold red"),
|
||||
EventType.TOOL_CALL_STARTED: ("->", "yellow"),
|
||||
EventType.TOOL_CALL_COMPLETED: ("<-", "green"),
|
||||
EventType.NODE_LOOP_STARTED: ("@@", "cyan"),
|
||||
EventType.NODE_LOOP_ITERATION: ("..", "dim"),
|
||||
EventType.NODE_LOOP_COMPLETED: ("@@", "dim"),
|
||||
EventType.NODE_STALLED: ("!!", "bold yellow"),
|
||||
EventType.NODE_INPUT_BLOCKED: ("!!", "yellow"),
|
||||
EventType.GOAL_PROGRESS: ("%%", "blue"),
|
||||
EventType.GOAL_ACHIEVED: ("**", "bold green"),
|
||||
EventType.CONSTRAINT_VIOLATION: ("!!", "bold red"),
|
||||
EventType.STATE_CHANGED: ("~~", "dim"),
|
||||
EventType.CLIENT_INPUT_REQUESTED: ("??", "magenta"),
|
||||
}
|
||||
|
||||
_LOG_LEVEL_COLORS = {
|
||||
logging.DEBUG: "dim",
|
||||
logging.INFO: "",
|
||||
logging.WARNING: "yellow",
|
||||
logging.ERROR: "red",
|
||||
logging.CRITICAL: "bold red",
|
||||
}
|
||||
|
||||
DEFAULT_CSS = """
|
||||
LogPane {
|
||||
width: 100%;
|
||||
@@ -58,84 +133,27 @@ class LogPane(Container):
|
||||
"""
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
# RichLog is designed for log display and doesn't have TextArea's rendering issues
|
||||
yield RichLog(id="main-log", highlight=True, markup=True, auto_scroll=False)
|
||||
|
||||
def write_event(self, event: AgentEvent) -> None:
|
||||
"""Format an AgentEvent with timestamp + symbol and write to the log."""
|
||||
ts = event.timestamp.strftime("%H:%M:%S")
|
||||
symbol, color = self._EVENT_FORMAT.get(event.type, ("--", "dim"))
|
||||
text = self._extract_event_text(event)
|
||||
self.write_log(f"[dim]{ts}[/dim] [{color}]{symbol} {text}[/{color}]")
|
||||
|
||||
def _extract_event_text(self, event: AgentEvent) -> str:
|
||||
"""Extract human-readable text from an event's data dict."""
|
||||
et = event.type
|
||||
data = event.data
|
||||
|
||||
if et == EventType.EXECUTION_STARTED:
|
||||
return "Execution started"
|
||||
elif et == EventType.EXECUTION_COMPLETED:
|
||||
return "Execution completed"
|
||||
elif et == EventType.EXECUTION_FAILED:
|
||||
return f"Execution FAILED: {data.get('error', 'unknown')}"
|
||||
elif et == EventType.TOOL_CALL_STARTED:
|
||||
return f"Tool call: {data.get('tool_name', 'unknown')}"
|
||||
elif et == EventType.TOOL_CALL_COMPLETED:
|
||||
name = data.get("tool_name", "unknown")
|
||||
if data.get("is_error"):
|
||||
preview = str(data.get("result", ""))[:80]
|
||||
return f"Tool error: {name} - {preview}"
|
||||
return f"Tool done: {name}"
|
||||
elif et == EventType.NODE_LOOP_STARTED:
|
||||
return f"Node started: {event.node_id or 'unknown'}"
|
||||
elif et == EventType.NODE_LOOP_ITERATION:
|
||||
return f"{event.node_id or 'unknown'} iteration {data.get('iteration', '?')}"
|
||||
elif et == EventType.NODE_LOOP_COMPLETED:
|
||||
return f"Node done: {event.node_id or 'unknown'}"
|
||||
elif et == EventType.NODE_STALLED:
|
||||
reason = data.get("reason", "")
|
||||
node = event.node_id or "unknown"
|
||||
return f"Node stalled: {node} - {reason}" if reason else f"Node stalled: {node}"
|
||||
elif et == EventType.NODE_INPUT_BLOCKED:
|
||||
return f"Node input blocked: {event.node_id or 'unknown'}"
|
||||
elif et == EventType.GOAL_PROGRESS:
|
||||
return f"Goal progress: {data.get('progress', '?')}"
|
||||
elif et == EventType.GOAL_ACHIEVED:
|
||||
return "Goal achieved"
|
||||
elif et == EventType.CONSTRAINT_VIOLATION:
|
||||
return f"Constraint violated: {data.get('description', 'unknown')}"
|
||||
elif et == EventType.STATE_CHANGED:
|
||||
return f"State changed: {data.get('key', 'unknown')}"
|
||||
elif et == EventType.CLIENT_INPUT_REQUESTED:
|
||||
return "Waiting for user input"
|
||||
else:
|
||||
return f"{et.value}: {data}"
|
||||
self.write_log(format_event(event))
|
||||
|
||||
def write_python_log(self, record: logging.LogRecord) -> None:
|
||||
"""Format a Python log record with timestamp and severity color."""
|
||||
ts = datetime.fromtimestamp(record.created).strftime("%H:%M:%S")
|
||||
color = self._LOG_LEVEL_COLORS.get(record.levelno, "")
|
||||
msg = record.getMessage()
|
||||
if color:
|
||||
self.write_log(f"[dim]{ts}[/dim] [{color}]{record.levelname}[/{color}] {msg}")
|
||||
else:
|
||||
self.write_log(f"[dim]{ts}[/dim] {record.levelname} {msg}")
|
||||
self.write_log(format_python_log(record))
|
||||
|
||||
def write_log(self, message: str) -> None:
|
||||
"""Write a log message to the log pane."""
|
||||
try:
|
||||
# Check if widget is mounted
|
||||
if not self.is_mounted:
|
||||
return
|
||||
|
||||
log = self.query_one("#main-log", RichLog)
|
||||
|
||||
# Check if log is mounted
|
||||
if not log.is_mounted:
|
||||
return
|
||||
|
||||
# Only auto-scroll if user is already at the bottom
|
||||
was_at_bottom = log.is_vertical_scroll_end
|
||||
|
||||
log.write(message)
|
||||
|
||||
@@ -195,12 +195,27 @@ def _copy_to_clipboard(text: str) -> None:
|
||||
try:
|
||||
if sys.platform == "darwin":
|
||||
subprocess.run(["pbcopy"], input=text.encode(), check=True, timeout=5)
|
||||
elif sys.platform.startswith("linux"):
|
||||
elif sys.platform == "win32":
|
||||
subprocess.run(
|
||||
["xclip", "-selection", "clipboard"],
|
||||
input=text.encode(),
|
||||
["clip.exe"],
|
||||
input=text.encode("utf-16le"),
|
||||
check=True,
|
||||
timeout=5,
|
||||
)
|
||||
elif sys.platform.startswith("linux"):
|
||||
try:
|
||||
subprocess.run(
|
||||
["xclip", "-selection", "clipboard"],
|
||||
input=text.encode(),
|
||||
check=True,
|
||||
timeout=5,
|
||||
)
|
||||
except (subprocess.SubprocessError, FileNotFoundError):
|
||||
subprocess.run(
|
||||
["xsel", "--clipboard", "--input"],
|
||||
input=text.encode(),
|
||||
check=True,
|
||||
timeout=5,
|
||||
)
|
||||
except (subprocess.SubprocessError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,5 @@
"""Utility functions for the Hive framework."""

from framework.utils.io import atomic_write

__all__ = ["atomic_write"]
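`framework.utils.io` itself is not part of this diff, so the exact signature of `atomic_write` is not shown. The conventional shape of the pattern — write to a temporary file in the same directory, then `os.replace()` so readers never observe a half-written file — looks roughly like this (a sketch under that assumption, not the project's actual code):

```python
import os
import tempfile
from pathlib import Path


def atomic_write(path: Path, data: str, encoding: str = "utf-8") -> None:
    """Write *data* to *path* atomically via a temp file + os.replace()."""
    path = Path(path)
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=f".{path.name}.", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding=encoding) as fh:
            fh.write(data)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp_name, path)  # atomic rename on POSIX and Windows
    except BaseException:
        os.unlink(tmp_name)
        raise
```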
@@ -20,6 +20,7 @@ dependencies = [

[project.optional-dependencies]
tui = ["textual>=0.75.0"]
webhook = ["aiohttp>=3.9.0"]

[project.scripts]
hive = "framework.cli:main"
@@ -157,39 +157,6 @@ class TestEventLoopOutputKeyOverlap:
|
||||
key_errors = [e for e in errors if "output_key" in e]
|
||||
assert len(key_errors) == 0
|
||||
|
||||
def test_overlapping_keys_non_event_loop_no_error(self):
|
||||
"""Non-event_loop nodes with overlapping keys -> no error (last-wins OK)."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(
|
||||
id="a",
|
||||
name="a",
|
||||
description="Node a",
|
||||
node_type="llm_generate",
|
||||
output_keys=["shared"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="b",
|
||||
name="b",
|
||||
description="Node b",
|
||||
node_type="llm_generate",
|
||||
output_keys=["shared"],
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
key_errors = [e for e in errors if "output_key" in e]
|
||||
assert len(key_errors) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Baseline: no fan-out -> no errors from these rules
|
||||
|
||||
@@ -85,14 +85,14 @@ async def test_direct_key_access_in_conditional_edge():
|
||||
id="score_node",
|
||||
name="ScoreNode",
|
||||
description="Outputs a score",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["score"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="high_score_node",
|
||||
name="HighScoreNode",
|
||||
description="Handles high scores",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
input_keys=["score"],
|
||||
output_keys=["result"],
|
||||
),
|
||||
@@ -153,14 +153,14 @@ async def test_backward_compatibility_output_syntax():
|
||||
id="score_node",
|
||||
name="ScoreNode",
|
||||
description="Outputs a score",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["score"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="consumer_node",
|
||||
name="ConsumerNode",
|
||||
description="Consumer",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
input_keys=["score"],
|
||||
output_keys=["processed"],
|
||||
),
|
||||
@@ -221,14 +221,14 @@ async def test_multiple_keys_in_expression():
|
||||
id="multi_key_node",
|
||||
name="MultiKeyNode",
|
||||
description="Outputs multiple keys",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["x", "y"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="consumer_node",
|
||||
name="ConsumerNode",
|
||||
description="Consumer",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
input_keys=["x", "y"],
|
||||
output_keys=["processed"],
|
||||
),
|
||||
@@ -295,14 +295,14 @@ async def test_negative_case_condition_false():
|
||||
id="low_score_node",
|
||||
name="LowScoreNode",
|
||||
description="Outputs low score",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["score"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="high_score_handler",
|
||||
name="HighScoreHandler",
|
||||
description="Should NOT execute",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
input_keys=["score"],
|
||||
output_keys=["result"],
|
||||
),
|
||||
|
||||
@@ -0,0 +1,538 @@
|
||||
"""Tests for the Continuous Agent architecture (conversation threading + cumulative tools).
|
||||
|
||||
Validates:
|
||||
- conversation_mode="isolated" preserves existing behavior
|
||||
- conversation_mode="continuous" threads one conversation across nodes
|
||||
- Transition markers are inserted at phase boundaries
|
||||
- System prompt updates at each transition (layered prompt composition)
|
||||
- Tools accumulate across nodes in continuous mode
|
||||
- prompt_composer functions work correctly
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeResult, NodeSpec, SharedMemory
|
||||
from framework.graph.prompt_composer import (
|
||||
build_narrative,
|
||||
build_transition_marker,
|
||||
compose_system_prompt,
|
||||
)
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock LLM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockStreamingLLM(LLMProvider):
|
||||
"""Mock LLM that yields pre-programmed StreamEvent sequences."""
|
||||
|
||||
def __init__(self, scenarios: list[list] | None = None):
|
||||
self.scenarios = scenarios or []
|
||||
self._call_index = 0
|
||||
self.stream_calls: list[dict] = []
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator:
|
||||
self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
|
||||
if not self.scenarios:
|
||||
return
|
||||
events = self.scenarios[self._call_index % len(self.scenarios)]
|
||||
self._call_index += 1
|
||||
for event in events:
|
||||
yield event
|
||||
|
||||
def complete(self, messages, system="", **kwargs) -> LLMResponse:
|
||||
return LLMResponse(content="Summary.", model="mock", stop_reason="stop")
|
||||
|
||||
def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
|
||||
return LLMResponse(content="", model="mock", stop_reason="stop")
|
||||
|
||||
|
||||
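A hand-written mock is used here (rather than `MagicMock`) presumably because `stream()` must behave as an async generator the executor can iterate with `async for`; each call replays the next scripted scenario, wrapping with the modulo index if more turns arrive than were scripted. A standalone sketch of the consumption pattern, using strings in place of `StreamEvent` objects:

```python
import asyncio
from collections.abc import AsyncIterator


async def fake_stream(events: list[str]) -> AsyncIterator[str]:
    # Stand-in for MockStreamingLLM.stream(): replay pre-programmed events.
    for event in events:
        yield event


async def main() -> None:
    collected = [ev async for ev in fake_stream(["TextDelta", "ToolCall", "Finish"])]
    assert collected == ["TextDelta", "ToolCall", "Finish"]


asyncio.run(main())
```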
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _set_output_scenario(key: str, value: str) -> list:
|
||||
"""LLM calls set_output then finishes."""
|
||||
return [
|
||||
ToolCallEvent(
|
||||
tool_use_id=f"call_{key}",
|
||||
tool_name="set_output",
|
||||
tool_input={"key": key, "value": value},
|
||||
),
|
||||
FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
|
||||
]
|
||||
|
||||
|
||||
def _text_then_set_output(text: str, key: str, value: str) -> list:
|
||||
"""LLM produces text, then calls set_output, then finishes (2 turns needed)."""
|
||||
return [
|
||||
TextDeltaEvent(content=text, snapshot=text),
|
||||
ToolCallEvent(
|
||||
tool_use_id=f"call_{key}",
|
||||
tool_name="set_output",
|
||||
tool_input={"key": key, "value": value},
|
||||
),
|
||||
FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
|
||||
]
|
||||
|
||||
|
||||
def _text_finish(text: str) -> list:
|
||||
"""LLM produces text and stops (triggers judge)."""
|
||||
return [
|
||||
TextDeltaEvent(content=text, snapshot=text),
|
||||
FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"),
|
||||
]
|
||||
|
||||
|
||||
def _make_runtime():
|
||||
rt = MagicMock(spec=Runtime)
|
||||
rt.start_run = MagicMock(return_value="run_1")
|
||||
rt.end_run = MagicMock()
|
||||
rt.report_problem = MagicMock()
|
||||
rt.decide = MagicMock(return_value="dec_1")
|
||||
rt.record_outcome = MagicMock()
|
||||
rt.set_node = MagicMock()
|
||||
return rt
|
||||
|
||||
|
||||
def _make_goal():
|
||||
return Goal(id="g1", name="test", description="test goal")
|
||||
|
||||
|
||||
def _make_tool(name: str) -> Tool:
|
||||
return Tool(
|
||||
name=name,
|
||||
description=f"Tool {name}",
|
||||
parameters={"type": "object", "properties": {}},
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# prompt_composer unit tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestComposeSystemPrompt:
|
||||
def test_all_layers(self):
|
||||
result = compose_system_prompt(
|
||||
identity_prompt="I am a research agent.",
|
||||
focus_prompt="Focus on writing the report.",
|
||||
narrative="We found 5 sources on topic X.",
|
||||
)
|
||||
assert "I am a research agent." in result
|
||||
assert "Focus on writing the report." in result
|
||||
assert "We found 5 sources on topic X." in result
|
||||
# Identity comes first
|
||||
assert result.index("I am a research agent.") < result.index("Focus on writing")
|
||||
|
||||
def test_identity_only(self):
|
||||
result = compose_system_prompt(identity_prompt="I am an agent.", focus_prompt=None)
|
||||
assert result == "I am an agent."
|
||||
|
||||
def test_focus_only(self):
|
||||
result = compose_system_prompt(identity_prompt=None, focus_prompt="Do the thing.")
|
||||
assert "Current Focus" in result
|
||||
assert "Do the thing." in result
|
||||
|
||||
def test_empty(self):
|
||||
result = compose_system_prompt(identity_prompt=None, focus_prompt=None)
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestBuildNarrative:
|
||||
def test_with_execution_path(self):
|
||||
memory = SharedMemory()
|
||||
memory.write("findings", "some findings")
|
||||
|
||||
node_a = NodeSpec(
|
||||
id="a", name="Research", description="Research the topic", node_type="event_loop"
|
||||
)
|
||||
node_b = NodeSpec(id="b", name="Report", description="Write report", node_type="event_loop")
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="g1",
|
||||
entry_node="a",
|
||||
nodes=[node_a, node_b],
|
||||
edges=[],
|
||||
)
|
||||
|
||||
result = build_narrative(memory, ["a"], graph)
|
||||
assert "Research" in result
|
||||
assert "findings" in result
|
||||
|
||||
def test_empty_state(self):
|
||||
memory = SharedMemory()
|
||||
graph = GraphSpec(id="g1", goal_id="g1", entry_node="a", nodes=[], edges=[])
|
||||
result = build_narrative(memory, [], graph)
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestBuildTransitionMarker:
|
||||
def test_basic_marker(self):
|
||||
prev = NodeSpec(
|
||||
id="research", name="Research", description="Find sources", node_type="event_loop"
|
||||
)
|
||||
next_n = NodeSpec(
|
||||
id="report", name="Report", description="Write report", node_type="event_loop"
|
||||
)
|
||||
memory = SharedMemory()
|
||||
memory.write("findings", "important stuff")
|
||||
|
||||
marker = build_transition_marker(
|
||||
previous_node=prev,
|
||||
next_node=next_n,
|
||||
memory=memory,
|
||||
cumulative_tool_names=["web_search", "save_data"],
|
||||
)
|
||||
|
||||
assert "PHASE TRANSITION" in marker
|
||||
assert "Research" in marker
|
||||
assert "Report" in marker
|
||||
assert "findings" in marker
|
||||
assert "web_search" in marker
|
||||
assert "reflect" in marker.lower()
|
||||
|
||||
|
||||
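The layered prompt mirrors what the continuous executor does at every transition: a stable identity, a per-node focus, and a running narrative of prior phases. A hedged sketch of calling the composer directly — the exact separators and headings of the composed string are not shown in the diff, so only the ordering and containment guaranteed by the tests above are relied on:

```python
from framework.graph.prompt_composer import compose_system_prompt

prompt = compose_system_prompt(
    identity_prompt="You are a thorough research agent.",
    focus_prompt="Write the final report from the gathered findings.",
    narrative="Phase 'Research' produced 5 sources on topic X.",
)

# The tests only guarantee ordering and containment, not exact formatting:
assert prompt.index("thorough research agent") < prompt.index("final report")
assert "5 sources" in prompt
```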
# ===========================================================================
|
||||
# NodeConversation.update_system_prompt
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestUpdateSystemPrompt:
|
||||
def test_update(self):
|
||||
conv = NodeConversation(system_prompt="original")
|
||||
assert conv.system_prompt == "original"
|
||||
conv.update_system_prompt("updated")
|
||||
assert conv.system_prompt == "updated"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Conversation threading through executor
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestContinuousConversation:
|
||||
"""Test that conversation_mode='continuous' threads a single conversation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_isolated_mode_no_conversation_in_result(self):
|
||||
"""In isolated mode, NodeResult.conversation should be None."""
|
||||
runtime = _make_runtime()
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
_set_output_scenario("result", "done"),
|
||||
_text_finish("accepted"),
|
||||
]
|
||||
)
|
||||
|
||||
spec = NodeSpec(
|
||||
id="n1",
|
||||
name="Node1",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
)
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="g1",
|
||||
entry_node="n1",
|
||||
nodes=[spec],
|
||||
edges=[],
|
||||
conversation_mode="isolated",
|
||||
)
|
||||
|
||||
executor = GraphExecutor(runtime=runtime, llm=llm)
|
||||
result = await executor.execute(graph=graph, goal=_make_goal())
|
||||
assert result.success
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_continuous_threads_conversation(self):
|
||||
"""In continuous mode, second node sees messages from first node."""
|
||||
runtime = _make_runtime()
|
||||
|
||||
# Node A: set_output("brief", "the brief"), then finish (accept)
|
||||
# Node B: set_output("report", "the report"), then finish (accept)
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
_text_then_set_output("I'll research this.", "brief", "the brief"),
|
||||
_text_finish(""), # triggers accept for node A (all keys set)
|
||||
_text_then_set_output("Here's the report.", "report", "the report"),
|
||||
_text_finish(""), # triggers accept for node B
|
||||
]
|
||||
)
|
||||
|
||||
node_a = NodeSpec(
|
||||
id="a",
|
||||
name="Intake",
|
||||
description="Gather requirements",
|
||||
node_type="event_loop",
|
||||
output_keys=["brief"],
|
||||
)
|
||||
node_b = NodeSpec(
|
||||
id="b",
|
||||
name="Report",
|
||||
description="Write report",
|
||||
node_type="event_loop",
|
||||
input_keys=["brief"],
|
||||
output_keys=["report"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="g1",
|
||||
entry_node="a",
|
||||
nodes=[node_a, node_b],
|
||||
edges=[EdgeSpec(id="e1", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)],
|
||||
terminal_nodes=["b"],
|
||||
conversation_mode="continuous",
|
||||
identity_prompt="You are a thorough research agent.",
|
||||
)
|
||||
|
||||
executor = GraphExecutor(runtime=runtime, llm=llm)
|
||||
result = await executor.execute(graph=graph, goal=_make_goal())
|
||||
|
||||
assert result.success
|
||||
assert result.path == ["a", "b"]
|
||||
|
||||
# Verify the LLM saw the identity prompt in system messages
|
||||
# The second node's system prompt should contain the identity
|
||||
if len(llm.stream_calls) >= 3:
|
||||
system_at_node_b = llm.stream_calls[2]["system"]
|
||||
assert "thorough research agent" in system_at_node_b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_continuous_transition_marker_present(self):
|
||||
"""Transition marker should appear in messages when switching nodes."""
|
||||
runtime = _make_runtime()
|
||||
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
_text_then_set_output("Research done.", "brief", "the brief"),
|
||||
_text_finish(""),
|
||||
_text_then_set_output("Report done.", "report", "the report"),
|
||||
_text_finish(""),
|
||||
]
|
||||
)
|
||||
|
||||
node_a = NodeSpec(
|
||||
id="a",
|
||||
name="Research",
|
||||
description="Do research",
|
||||
node_type="event_loop",
|
||||
output_keys=["brief"],
|
||||
)
|
||||
node_b = NodeSpec(
|
||||
id="b",
|
||||
name="Report",
|
||||
description="Write report",
|
||||
node_type="event_loop",
|
||||
input_keys=["brief"],
|
||||
output_keys=["report"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="g1",
|
||||
entry_node="a",
|
||||
nodes=[node_a, node_b],
|
||||
edges=[EdgeSpec(id="e1", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)],
|
||||
terminal_nodes=["b"],
|
||||
conversation_mode="continuous",
|
||||
)
|
||||
|
||||
executor = GraphExecutor(runtime=runtime, llm=llm)
|
||||
result = await executor.execute(graph=graph, goal=_make_goal())
|
||||
assert result.success
|
||||
|
||||
# When node B's first LLM call happens, its messages should contain
|
||||
# the transition marker from the executor
|
||||
if len(llm.stream_calls) >= 3:
|
||||
node_b_messages = llm.stream_calls[2]["messages"]
|
||||
all_content = " ".join(
|
||||
m.get("content", "") for m in node_b_messages if isinstance(m.get("content"), str)
|
||||
)
|
||||
assert "PHASE TRANSITION" in all_content
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Cumulative tools
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestCumulativeTools:
|
||||
"""Test that tools accumulate in continuous mode."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_isolated_mode_tools_scoped(self):
|
||||
"""In isolated mode, each node only gets its own declared tools."""
|
||||
runtime = _make_runtime()
|
||||
tool_a = _make_tool("web_search")
|
||||
tool_b = _make_tool("save_data")
|
||||
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
_text_then_set_output("Done.", "brief", "brief"),
|
||||
_text_finish(""),
|
||||
_text_then_set_output("Done.", "report", "report"),
|
||||
_text_finish(""),
|
||||
]
|
||||
)
|
||||
|
||||
node_a = NodeSpec(
|
||||
id="a",
|
||||
name="Research",
|
||||
description="Research",
|
||||
node_type="event_loop",
|
||||
output_keys=["brief"],
|
||||
tools=["web_search"],
|
||||
)
|
||||
node_b = NodeSpec(
|
||||
id="b",
|
||||
name="Report",
|
||||
description="Report",
|
||||
node_type="event_loop",
|
||||
input_keys=["brief"],
|
||||
output_keys=["report"],
|
||||
tools=["save_data"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="g1",
|
||||
entry_node="a",
|
||||
nodes=[node_a, node_b],
|
||||
edges=[EdgeSpec(id="e1", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)],
|
||||
terminal_nodes=["b"],
|
||||
conversation_mode="isolated",
|
||||
)
|
||||
|
||||
executor = GraphExecutor(
|
||||
runtime=runtime,
|
||||
llm=llm,
|
||||
tools=[tool_a, tool_b],
|
||||
)
|
||||
result = await executor.execute(graph=graph, goal=_make_goal())
|
||||
assert result.success
|
||||
|
||||
# In isolated mode, node B should NOT have web_search
|
||||
if len(llm.stream_calls) >= 3:
|
||||
node_b_tools = llm.stream_calls[2].get("tools") or []
|
||||
tool_names = [t.name for t in node_b_tools]
|
||||
assert "save_data" in tool_names or "set_output" in tool_names
|
||||
# web_search should NOT be present (only set_output + save_data)
|
||||
real_tools = [n for n in tool_names if n != "set_output"]
|
||||
assert "web_search" not in real_tools
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_continuous_mode_tools_accumulate(self):
|
||||
"""In continuous mode, node B should have both web_search and save_data."""
|
||||
runtime = _make_runtime()
|
||||
tool_a = _make_tool("web_search")
|
||||
tool_b = _make_tool("save_data")
|
||||
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
_text_then_set_output("Done.", "brief", "brief"),
|
||||
_text_finish(""),
|
||||
_text_then_set_output("Done.", "report", "report"),
|
||||
_text_finish(""),
|
||||
]
|
||||
)
|
||||
|
||||
node_a = NodeSpec(
|
||||
id="a",
|
||||
name="Research",
|
||||
description="Research",
|
||||
node_type="event_loop",
|
||||
output_keys=["brief"],
|
||||
tools=["web_search"],
|
||||
)
|
||||
node_b = NodeSpec(
|
||||
id="b",
|
||||
name="Report",
|
||||
description="Report",
|
||||
node_type="event_loop",
|
||||
input_keys=["brief"],
|
||||
output_keys=["report"],
|
||||
tools=["save_data"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="g1",
|
||||
entry_node="a",
|
||||
nodes=[node_a, node_b],
|
||||
edges=[EdgeSpec(id="e1", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)],
|
||||
terminal_nodes=["b"],
|
||||
conversation_mode="continuous",
|
||||
)
|
||||
|
||||
executor = GraphExecutor(
|
||||
runtime=runtime,
|
||||
llm=llm,
|
||||
tools=[tool_a, tool_b],
|
||||
)
|
||||
result = await executor.execute(graph=graph, goal=_make_goal())
|
||||
assert result.success
|
||||
|
||||
# In continuous mode, node B should have BOTH tools
|
||||
if len(llm.stream_calls) >= 3:
|
||||
node_b_tools = llm.stream_calls[2].get("tools") or []
|
||||
tool_names = [t.name for t in node_b_tools]
|
||||
real_tools = [n for n in tool_names if n != "set_output"]
|
||||
assert "web_search" in real_tools
|
||||
assert "save_data" in real_tools
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Schema field defaults
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSchemaDefaults:
|
||||
def test_graphspec_defaults(self):
|
||||
"""New fields should have safe defaults."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="g1",
|
||||
entry_node="n1",
|
||||
nodes=[],
|
||||
edges=[],
|
||||
)
|
||||
assert graph.conversation_mode == "continuous"
|
||||
assert graph.identity_prompt is None
|
||||
|
||||
def test_nodespec_defaults(self):
|
||||
"""NodeSpec.success_criteria should default to None."""
|
||||
spec = NodeSpec(
|
||||
id="n1",
|
||||
name="test",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
)
|
||||
assert spec.success_criteria is None
|
||||
|
||||
def test_noderesult_defaults(self):
|
||||
"""NodeResult.conversation should default to None."""
|
||||
result = NodeResult(success=True)
|
||||
assert result.conversation is None
|
||||
@@ -0,0 +1,380 @@
|
||||
"""Tests for Level 2 conversation-aware judge.
|
||||
|
||||
Validates:
|
||||
- No success_criteria → Level 0 only (existing behavior)
|
||||
- success_criteria set, good conversation → Level 2 ACCEPT
|
||||
- success_criteria set, poor conversation → Level 2 RETRY with feedback
|
||||
- Custom explicit judge takes priority over Level 2
|
||||
- Level 2 fires only when Level 0 passes (all keys set)
|
||||
- _parse_verdict correctly parses LLM responses
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.conversation_judge import (
|
||||
_parse_verdict,
|
||||
evaluate_phase_completion,
|
||||
)
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeSpec
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock LLM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockStreamingLLM(LLMProvider):
|
||||
"""Mock LLM that yields pre-programmed StreamEvent sequences."""
|
||||
|
||||
def __init__(self, scenarios: list[list] | None = None, complete_response: str = ""):
|
||||
self.scenarios = scenarios or []
|
||||
self._call_index = 0
|
||||
self.stream_calls: list[dict] = []
|
||||
self.complete_response = complete_response
|
||||
self.complete_calls: list[dict] = []
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator:
|
||||
self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
|
||||
if not self.scenarios:
|
||||
return
|
||||
events = self.scenarios[self._call_index % len(self.scenarios)]
|
||||
self._call_index += 1
|
||||
for event in events:
|
||||
yield event
|
||||
|
||||
def complete(self, messages, system="", **kwargs) -> LLMResponse:
|
||||
self.complete_calls.append({"messages": messages, "system": system})
|
||||
return LLMResponse(content=self.complete_response, model="mock", stop_reason="stop")
|
||||
|
||||
def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
|
||||
return LLMResponse(content="", model="mock", stop_reason="stop")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _set_output_scenario(key: str, value: str) -> list:
|
||||
return [
|
||||
ToolCallEvent(
|
||||
tool_use_id=f"call_{key}",
|
||||
tool_name="set_output",
|
||||
tool_input={"key": key, "value": value},
|
||||
),
|
||||
FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
|
||||
]
|
||||
|
||||
|
||||
def _text_then_set_output(text: str, key: str, value: str) -> list:
|
||||
return [
|
||||
TextDeltaEvent(content=text, snapshot=text),
|
||||
ToolCallEvent(
|
||||
tool_use_id=f"call_{key}",
|
||||
tool_name="set_output",
|
||||
tool_input={"key": key, "value": value},
|
||||
),
|
||||
FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
|
||||
]
|
||||
|
||||
|
||||
def _text_finish(text: str) -> list:
|
||||
return [
|
||||
TextDeltaEvent(content=text, snapshot=text),
|
||||
FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"),
|
||||
]
|
||||
|
||||
|
||||
def _make_runtime():
|
||||
rt = MagicMock(spec=Runtime)
|
||||
rt.start_run = MagicMock(return_value="run_1")
|
||||
rt.end_run = MagicMock()
|
||||
rt.report_problem = MagicMock()
|
||||
rt.decide = MagicMock(return_value="dec_1")
|
||||
rt.record_outcome = MagicMock()
|
||||
rt.set_node = MagicMock()
|
||||
return rt
|
||||
|
||||
|
||||
def _make_goal():
|
||||
return Goal(id="g1", name="test", description="test goal")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Unit tests for _parse_verdict
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestParseVerdict:
|
||||
def test_accept(self):
|
||||
v = _parse_verdict("ACTION: ACCEPT\nCONFIDENCE: 0.9\nFEEDBACK:")
|
||||
assert v.action == "ACCEPT"
|
||||
assert v.confidence == 0.9
|
||||
assert v.feedback == ""
|
||||
|
||||
def test_retry_with_feedback(self):
|
||||
v = _parse_verdict("ACTION: RETRY\nCONFIDENCE: 0.6\nFEEDBACK: Research is too shallow.")
|
||||
assert v.action == "RETRY"
|
||||
assert v.confidence == 0.6
|
||||
assert "shallow" in v.feedback
|
||||
|
||||
def test_defaults_on_garbage(self):
|
||||
v = _parse_verdict("some random text\nno structured output")
|
||||
assert v.action == "ACCEPT" # default
|
||||
assert v.confidence == 0.8 # default
|
||||
|
||||
def test_invalid_action_defaults_to_accept(self):
|
||||
v = _parse_verdict("ACTION: ESCALATE\nCONFIDENCE: 0.5")
|
||||
assert v.action == "ACCEPT" # ESCALATE not valid for Level 2
|
||||
|
||||
|
||||
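The tests above pin down `_parse_verdict`'s contract: an ACTION line restricted to ACCEPT/RETRY (anything else, or unstructured text, falls back to ACCEPT), a CONFIDENCE float defaulting to 0.8, and free-text FEEDBACK. A sketch consistent with that contract — not the project's actual implementation, which lives in `framework.graph.conversation_judge`:

```python
import re
from dataclasses import dataclass


@dataclass
class Verdict:
    action: str = "ACCEPT"
    confidence: float = 0.8
    feedback: str = ""


def parse_verdict(text: str) -> Verdict:
    """Parse an ACTION/CONFIDENCE/FEEDBACK judge response, with safe defaults."""
    v = Verdict()
    action = re.search(r"ACTION:\s*(\w+)", text)
    if action and action.group(1).upper() in ("ACCEPT", "RETRY"):
        v.action = action.group(1).upper()
    confidence = re.search(r"CONFIDENCE:\s*([0-9.]+)", text)
    if confidence:
        try:
            v.confidence = float(confidence.group(1))
        except ValueError:
            pass
    feedback = re.search(r"FEEDBACK:\s*(.*)", text, re.DOTALL)
    if feedback:
        v.feedback = feedback.group(1).strip()
    return v


assert parse_verdict("no structure at all").action == "ACCEPT"
assert parse_verdict("ACTION: RETRY\nCONFIDENCE: 0.6\nFEEDBACK: too shallow").feedback == "too shallow"
```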
# ===========================================================================
|
||||
# Unit tests for evaluate_phase_completion
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestEvaluatePhaseCompletion:
    @pytest.mark.asyncio
    async def test_accept_on_good_response(self):
        """LLM says ACCEPT → verdict is ACCEPT."""
        llm = MockStreamingLLM(complete_response="ACTION: ACCEPT\nCONFIDENCE: 0.95\nFEEDBACK:")
        conv = NodeConversation(system_prompt="test")
        await conv.add_user_message("Do research on topic X")
        await conv.add_assistant_message("I found 5 high-quality sources on X.")

        verdict = await evaluate_phase_completion(
            llm=llm,
            conversation=conv,
            phase_name="Research",
            phase_description="Research the topic",
            success_criteria="Find at least 3 credible sources",
            accumulator_state={"findings": "5 sources found"},
        )
        assert verdict.action == "ACCEPT"
        assert verdict.confidence == 0.95

    @pytest.mark.asyncio
    async def test_retry_on_poor_response(self):
        """LLM says RETRY → verdict is RETRY with feedback."""
        llm = MockStreamingLLM(
            complete_response=(
                "ACTION: RETRY\nCONFIDENCE: 0.4\nFEEDBACK: Only found 1 source, need 3."
            )
        )
        conv = NodeConversation(system_prompt="test")
        await conv.add_user_message("Do research")
        await conv.add_assistant_message("I found 1 source.")

        verdict = await evaluate_phase_completion(
            llm=llm,
            conversation=conv,
            phase_name="Research",
            phase_description="Research the topic",
            success_criteria="Find at least 3 credible sources",
            accumulator_state={"findings": "1 source"},
        )
        assert verdict.action == "RETRY"
        assert "1 source" in verdict.feedback

    @pytest.mark.asyncio
    async def test_llm_failure_defaults_to_accept(self):
        """When the LLM call fails, Level 2 should not block (Level 0 already passed)."""
        llm = MockStreamingLLM()
        # Make complete() raise an exception
        llm.complete = MagicMock(side_effect=RuntimeError("LLM unavailable"))

        conv = NodeConversation(system_prompt="test")
        await conv.add_assistant_message("Done.")

        verdict = await evaluate_phase_completion(
            llm=llm,
            conversation=conv,
            phase_name="Test",
            phase_description="Test phase",
            success_criteria="Do the thing",
            accumulator_state={"result": "done"},
        )
        assert verdict.action == "ACCEPT"
        assert verdict.confidence == 0.5

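
# For orientation, a minimal sketch of the contract these tests pin down.
# Illustrative only: the prompt text, the complete() call shape, and the
# fallback verdict below are assumptions, not the framework's actual
# evaluate_phase_completion implementation.
async def _evaluate_phase_completion_sketch(
    llm, conversation, phase_name, phase_description, success_criteria, accumulator_state
):
    from types import SimpleNamespace  # stands in for the framework's verdict type

    # `conversation` is accepted for signature parity; the real judge also
    # includes its transcript in the prompt.
    prompt = (
        f"Phase: {phase_name} -- {phase_description}\n"
        f"Success criteria: {success_criteria}\n"
        f"Accumulated outputs: {accumulator_state}\n"
        "Reply with ACTION / CONFIDENCE / FEEDBACK lines."
    )
    try:
        response = llm.complete(messages=[{"role": "user", "content": prompt}])
    except Exception:
        # Never block on judge failure: Level 0 (output keys) already passed.
        return SimpleNamespace(action="ACCEPT", confidence=0.5, feedback="")
    return _parse_verdict(response.content)

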
# ===========================================================================
# Integration: Level 2 in EventLoopNode implicit judge
# ===========================================================================


class TestLevel2InImplicitJudge:
    @pytest.mark.asyncio
    async def test_no_success_criteria_level0_only(self):
        """Without success_criteria, Level 0 accepts normally (existing behavior)."""
        runtime = _make_runtime()
        llm = MockStreamingLLM(
            scenarios=[
                _set_output_scenario("result", "done"),
                _text_finish("accepted"),
            ]
        )

        spec = NodeSpec(
            id="n1",
            name="Node1",
            description="test",
            node_type="event_loop",
            output_keys=["result"],
            # No success_criteria!
        )
        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="n1",
            nodes=[spec],
            edges=[],
        )

        executor = GraphExecutor(runtime=runtime, llm=llm)
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success
        # LLM.complete should NOT have been called for Level 2
        assert len(llm.complete_calls) == 0

    @pytest.mark.asyncio
    async def test_success_criteria_accept(self):
        """With success_criteria and good work, Level 2 accepts."""
        runtime = _make_runtime()
        llm = MockStreamingLLM(
            scenarios=[
                _text_then_set_output("I did thorough research.", "result", "done"),
                _text_finish(""),  # triggers judge
            ],
            complete_response="ACTION: ACCEPT\nCONFIDENCE: 0.9\nFEEDBACK:",
        )

        spec = NodeSpec(
            id="n1",
            name="Research",
            description="Do research",
            node_type="event_loop",
            output_keys=["result"],
            success_criteria="Provide thorough research with multiple sources.",
        )
        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="n1",
            nodes=[spec],
            edges=[],
        )

        executor = GraphExecutor(runtime=runtime, llm=llm)
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success
        # LLM.complete should have been called for Level 2
        assert len(llm.complete_calls) >= 1

    @pytest.mark.asyncio
    async def test_success_criteria_retry_then_accept(self):
        """Level 2 rejects first attempt, LLM tries again, Level 2 accepts."""
        runtime = _make_runtime()

        # Track complete calls to alternate responses
        complete_responses = [
            "ACTION: RETRY\nCONFIDENCE: 0.4\nFEEDBACK: Need more detail.",
            "ACTION: ACCEPT\nCONFIDENCE: 0.9\nFEEDBACK:",
        ]
        call_count = [0]

        class SequentialLLM(MockStreamingLLM):
            def complete(self, messages, system="", **kwargs):
                idx = call_count[0]
                call_count[0] += 1
                resp = complete_responses[idx % len(complete_responses)]
                return LLMResponse(content=resp, model="mock", stop_reason="stop")

        llm = SequentialLLM(
            scenarios=[
                # Turn 1: set output, then stop → Level 2 RETRY
                _text_then_set_output("Brief research.", "result", "brief"),
                _text_finish(""),  # triggers judge → Level 2 RETRY
                # Turn 2: after retry feedback, set output again, stop → Level 2 ACCEPT
                _text_then_set_output("Much more detailed research.", "result", "detailed"),
                _text_finish(""),  # triggers judge → Level 2 ACCEPT
            ]
        )

        spec = NodeSpec(
            id="n1",
            name="Research",
            description="Do research",
            node_type="event_loop",
            output_keys=["result"],
            success_criteria="Provide thorough research with multiple sources.",
        )
        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="n1",
            nodes=[spec],
            edges=[],
        )

        executor = GraphExecutor(runtime=runtime, llm=llm)
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success
        # Should have had 2 complete calls (first RETRY, second ACCEPT)
        assert call_count[0] >= 2

    @pytest.mark.asyncio
    async def test_level2_only_fires_when_level0_passes(self):
        """Level 2 should NOT fire when output keys are missing."""
        runtime = _make_runtime()

        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: just text, no set_output → Level 0 RETRY (missing keys)
                _text_finish("I did some thinking."),
                # Turn 2: set output → Level 0 ACCEPT, Level 2 check
                _text_then_set_output("Now I have output.", "result", "done"),
                _text_finish(""),  # triggers judge
            ],
            complete_response="ACTION: ACCEPT\nCONFIDENCE: 0.9\nFEEDBACK:",
        )

        spec = NodeSpec(
            id="n1",
            name="Research",
            description="Do research",
            node_type="event_loop",
            output_keys=["result"],
            success_criteria="Provide results.",
        )
        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="n1",
            nodes=[spec],
            edges=[],
        )

        executor = GraphExecutor(runtime=runtime, llm=llm)
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success
        # Level 2 should only fire once (when Level 0 passes)
        assert len(llm.complete_calls) == 1

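
# The four tests above pin down when the implicit judge escalates from the
# structural check to the LLM check. Roughly (an illustrative sketch only --
# names such as accumulator.as_dict() are assumptions, not the exact
# EventLoopNode internals):
#
#   async def _judge_turn(spec, accumulator, llm, conversation):
#       # Level 0: structural -- every declared output key must be present.
#       if not accumulator.has_all_keys(spec.output_keys):
#           return "RETRY"
#       # Level 2: semantic -- only runs when success_criteria is set,
#       # so nodes without criteria never trigger an llm.complete() call.
#       if not spec.success_criteria:
#           return "ACCEPT"
#       verdict = await evaluate_phase_completion(
#           llm=llm,
#           conversation=conversation,
#           phase_name=spec.name,
#           phase_description=spec.description,
#           success_criteria=spec.success_criteria,
#           accumulator_state=accumulator.as_dict(),
#       )
#       return verdict.action
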
@@ -86,6 +86,7 @@ class ScriptableMockLLMProvider(LLMProvider):
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
content="Conversation summary for compaction.",
|
||||
@@ -825,7 +826,7 @@ async def test_event_loop_no_executor_retry(runtime):
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
assert not result.success
|
||||
assert failing_node.attempt_count == 1 # Executor forced max_retries to 0
|
||||
assert failing_node.attempt_count == 3 # Custom nodes keep their max_retries
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
@@ -929,6 +930,7 @@ async def test_context_handoff_between_nodes(runtime):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Hangs in non-interactive shells (client-facing blocks on stdin)")
|
||||
async def test_client_facing_node_streams_output():
|
||||
"""Client-facing node emits CLIENT_OUTPUT_DELTA events."""
|
||||
recorded: list[AgentEvent] = []
|
||||
@@ -1005,11 +1007,20 @@ async def test_internal_node_no_client_output():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_node_graph(runtime):
|
||||
"""function -> event_loop -> function end-to-end."""
|
||||
"""Simple node -> event_loop -> simple node end-to-end."""
|
||||
|
||||
# Function 1: write leads to memory
|
||||
def load_leads(**kwargs):
|
||||
return ["lead_A", "lead_B", "lead_C"]
|
||||
class LoadLeadsNode(NodeProtocol):
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
leads = ["lead_A", "lead_B", "lead_C"]
|
||||
ctx.memory.write("leads", leads)
|
||||
return NodeResult(success=True, output={"leads": leads})
|
||||
|
||||
class FormatOutputNode(NodeProtocol):
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
summary = ctx.input_data.get("summary", ctx.memory.read("summary") or "no summary")
|
||||
report = f"Report: {summary}"
|
||||
ctx.memory.write("report", report)
|
||||
return NodeResult(success=True, output={"report": report})
|
||||
|
||||
# Event loop: process leads, produce summary
|
||||
el_scripts = [
|
||||
@@ -1026,18 +1037,12 @@ async def test_mixed_node_graph(runtime):
|
||||
]
|
||||
el_llm = ScriptableMockLLMProvider(el_scripts)
|
||||
|
||||
# Function 2: format final output
|
||||
def format_output(**kwargs):
|
||||
summary = kwargs.get("summary", "no summary")
|
||||
return f"Report: {summary}"
|
||||
|
||||
# Node specs
|
||||
load_spec = NodeSpec(
|
||||
id="load",
|
||||
name="Load Leads",
|
||||
description="Load lead data",
|
||||
node_type="function",
|
||||
function="load_leads",
|
||||
node_type="event_loop",
|
||||
output_keys=["leads"],
|
||||
)
|
||||
process_spec = NodeSpec(
|
||||
@@ -1045,17 +1050,13 @@ async def test_mixed_node_graph(runtime):
|
||||
name="Process Leads",
|
||||
description="Process leads with LLM",
|
||||
node_type="event_loop",
|
||||
# input_keys left empty: EventLoopNode._check_pause() reads "pause_requested"
|
||||
# from memory, and a restrictive scope would block it. Data flows via input_data.
|
||||
output_keys=["summary"],
|
||||
)
|
||||
format_spec = NodeSpec(
|
||||
id="format",
|
||||
name="Format Output",
|
||||
description="Format final report",
|
||||
node_type="function",
|
||||
function="format_output",
|
||||
# input_keys left empty for same scoping reason with FunctionNode
|
||||
node_type="event_loop",
|
||||
output_keys=["report"],
|
||||
)
|
||||
|
||||
@@ -1076,9 +1077,9 @@ async def test_mixed_node_graph(runtime):
|
||||
goal = Goal(id="test_goal", name="Pipeline Test", description="test full pipeline")
|
||||
|
||||
executor = GraphExecutor(runtime=runtime, llm=el_llm)
|
||||
executor.register_function("load", load_leads)
|
||||
executor.register_node("load", LoadLeadsNode())
|
||||
executor.register_node("process", EventLoopNode(config=LoopConfig(max_iterations=5)))
|
||||
executor.register_function("format", format_output)
|
||||
executor.register_node("format", FormatOutputNode())
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
|
||||
@@ -425,6 +425,7 @@ class TestEventBusLifecycle:
|
||||
assert EventType.NODE_LOOP_COMPLETED in received_events
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Hangs in non-interactive shells (client-facing blocks on stdin)")
|
||||
async def test_client_facing_uses_client_output_delta(self, runtime, memory):
|
||||
"""client_facing=True should emit CLIENT_OUTPUT_DELTA instead of LLM_TEXT_DELTA."""
|
||||
spec = NodeSpec(
|
||||
@@ -475,6 +476,7 @@ class TestClientFacingBlocking:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Hangs in non-interactive shells (client-facing blocks on stdin)")
|
||||
async def test_text_only_no_blocking(self, runtime, memory, client_spec):
|
||||
"""client_facing + text-only (no ask_user) should NOT block."""
|
||||
llm = MockStreamingLLM(
|
||||
@@ -630,6 +632,7 @@ class TestClientFacingBlocking:
|
||||
assert received[0].type == EventType.CLIENT_INPUT_REQUESTED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Hangs in non-interactive shells (client-facing blocks on stdin)")
|
||||
async def test_ask_user_with_real_tools(self, runtime, memory):
|
||||
"""ask_user alongside real tool calls still triggers blocking."""
|
||||
spec = NodeSpec(
|
||||
@@ -993,3 +996,697 @@ class TestOutputAccumulator:
|
||||
assert acc.get("key1") == "val1"
|
||||
assert acc.get("key2") == "val2"
|
||||
assert acc.has_all_keys(["key1", "key2"]) is True
|
||||
|
||||
|
||||
# ===========================================================================
# Transient error retry (ITEM 2)
# ===========================================================================


class ErrorThenSuccessLLM(LLMProvider):
    """LLM that raises on the first N calls, then succeeds.

    Used to test the retry-with-backoff wrapper around _run_single_turn().
    """

    def __init__(self, error: Exception, fail_count: int, success_scenario: list):
        self.error = error
        self.fail_count = fail_count
        self.success_scenario = success_scenario
        self._call_index = 0

    async def stream(self, messages, system="", tools=None, max_tokens=4096):
        call_num = self._call_index
        self._call_index += 1
        if call_num < self.fail_count:
            raise self.error
        for event in self.success_scenario:
            yield event

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(content="ok", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")

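
# The tests below exercise the retry wrapper this mock feeds. As a rough,
# illustrative sketch (method, event, and config names follow the tests; the
# real wrapper inside EventLoopNode may differ):
#
#   async def _run_single_turn_with_retry(self, ctx):
#       attempts = self.config.max_stream_retries
#       for retry_count in range(attempts + 1):
#           try:
#               return await self._run_single_turn(ctx)
#           except Exception as exc:
#               if not self._is_transient_error(exc) or retry_count == attempts:
#                   raise
#               self._emit(EventType.NODE_RETRY, {"retry_count": retry_count + 1})
#               await asyncio.sleep(self.config.stream_retry_backoff_base * 2 ** retry_count)

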
class TestTransientErrorRetry:
|
||||
"""Test retry-with-backoff for transient LLM errors in EventLoopNode."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transient_error_retries_then_succeeds(self, runtime, node_spec, memory):
|
||||
"""A transient error on the first try should retry and succeed."""
|
||||
node_spec.output_keys = []
|
||||
llm = ErrorThenSuccessLLM(
|
||||
error=ConnectionError("connection reset"),
|
||||
fail_count=1,
|
||||
success_scenario=text_scenario("success"),
|
||||
)
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
config=LoopConfig(
|
||||
max_iterations=5,
|
||||
max_stream_retries=3,
|
||||
stream_retry_backoff_base=0.01, # fast for tests
|
||||
),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
assert result.success is True
|
||||
assert llm._call_index == 2 # 1 failure + 1 success
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permanent_error_no_retry(self, runtime, node_spec, memory):
|
||||
"""A permanent error (ValueError) should NOT be retried."""
|
||||
node_spec.output_keys = []
|
||||
llm = ErrorThenSuccessLLM(
|
||||
error=ValueError("bad request: invalid model"),
|
||||
fail_count=1,
|
||||
success_scenario=text_scenario("success"),
|
||||
)
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
config=LoopConfig(
|
||||
max_iterations=5,
|
||||
max_stream_retries=3,
|
||||
stream_retry_backoff_base=0.01,
|
||||
),
|
||||
)
|
||||
with pytest.raises(ValueError, match="bad request"):
|
||||
await node.execute(ctx)
|
||||
assert llm._call_index == 1 # only tried once
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transient_error_exhausts_retries(self, runtime, node_spec, memory):
|
||||
"""Transient errors that exhaust retries should raise."""
|
||||
node_spec.output_keys = []
|
||||
llm = ErrorThenSuccessLLM(
|
||||
error=TimeoutError("request timed out"),
|
||||
fail_count=100, # always fails
|
||||
success_scenario=text_scenario("unreachable"),
|
||||
)
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
config=LoopConfig(
|
||||
max_iterations=5,
|
||||
max_stream_retries=2,
|
||||
stream_retry_backoff_base=0.01,
|
||||
),
|
||||
)
|
||||
with pytest.raises(TimeoutError, match="request timed out"):
|
||||
await node.execute(ctx)
|
||||
assert llm._call_index == 3 # 1 initial + 2 retries
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_error_event_retried_as_runtime_error(self, runtime, node_spec, memory):
|
||||
"""StreamErrorEvent(recoverable=False) raises RuntimeError caught by retry."""
|
||||
node_spec.output_keys = []
|
||||
|
||||
# Scenario: non-recoverable StreamErrorEvent with transient keywords
|
||||
error_scenario = [
|
||||
StreamErrorEvent(
|
||||
error="Stream error: 503 service unavailable",
|
||||
recoverable=False,
|
||||
)
|
||||
]
|
||||
success_scenario = text_scenario("recovered")
|
||||
|
||||
call_index = 0
|
||||
|
||||
class StreamErrorThenSuccessLLM(LLMProvider):
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096):
|
||||
nonlocal call_index
|
||||
idx = call_index
|
||||
call_index += 1
|
||||
if idx == 0:
|
||||
for event in error_scenario:
|
||||
yield event
|
||||
else:
|
||||
for event in success_scenario:
|
||||
yield event
|
||||
|
||||
def complete(self, messages, system="", **kwargs):
|
||||
return LLMResponse(
|
||||
content="ok",
|
||||
model="mock",
|
||||
stop_reason="stop",
|
||||
)
|
||||
|
||||
def complete_with_tools(
|
||||
self,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
tool_executor,
|
||||
**kwargs,
|
||||
):
|
||||
return LLMResponse(
|
||||
content="",
|
||||
model="mock",
|
||||
stop_reason="stop",
|
||||
)
|
||||
|
||||
llm = StreamErrorThenSuccessLLM()
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
config=LoopConfig(
|
||||
max_iterations=5,
|
||||
max_stream_retries=3,
|
||||
stream_retry_backoff_base=0.01,
|
||||
),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
assert result.success is True
|
||||
assert call_index == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_emits_event_bus_event(self, runtime, node_spec, memory):
|
||||
"""Retry should emit NODE_RETRY event on the event bus."""
|
||||
node_spec.output_keys = []
|
||||
llm = ErrorThenSuccessLLM(
|
||||
error=ConnectionError("network down"),
|
||||
fail_count=1,
|
||||
success_scenario=text_scenario("ok"),
|
||||
)
|
||||
bus = EventBus()
|
||||
retry_events = []
|
||||
bus.subscribe(
|
||||
event_types=[EventType.NODE_RETRY],
|
||||
handler=lambda e: retry_events.append(e),
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
event_bus=bus,
|
||||
config=LoopConfig(
|
||||
max_iterations=5,
|
||||
max_stream_retries=3,
|
||||
stream_retry_backoff_base=0.01,
|
||||
),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
assert result.success is True
|
||||
assert len(retry_events) == 1
|
||||
assert retry_events[0].data["retry_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recoverable_stream_error_retried_not_silent(self, runtime, node_spec, memory):
|
||||
"""Recoverable StreamErrorEvent with empty response should raise ConnectionError.
|
||||
|
||||
Previously, recoverable stream errors were silently swallowed,
|
||||
producing empty responses that the judge retried — creating an
|
||||
infinite loop of 50+ empty-response iterations. Now they raise
|
||||
ConnectionError so the outer transient-error retry handles them
|
||||
with proper backoff.
|
||||
"""
|
||||
node_spec.output_keys = ["result"]
|
||||
|
||||
call_index = 0
|
||||
|
||||
class RecoverableErrorThenSuccessLLM(LLMProvider):
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096):
|
||||
nonlocal call_index
|
||||
idx = call_index
|
||||
call_index += 1
|
||||
if idx == 0:
|
||||
# Recoverable error with no content
|
||||
yield StreamErrorEvent(
|
||||
error="503 service unavailable",
|
||||
recoverable=True,
|
||||
)
|
||||
elif idx == 1:
|
||||
# Success: set output
|
||||
for event in tool_call_scenario(
|
||||
"set_output", {"key": "result", "value": "done"}
|
||||
):
|
||||
yield event
|
||||
else:
|
||||
# Subsequent calls: text-only (no more tool calls)
|
||||
for event in text_scenario("done"):
|
||||
yield event
|
||||
|
||||
def complete(self, messages, system="", **kwargs):
|
||||
return LLMResponse(content="ok", model="mock", stop_reason="stop")
|
||||
|
||||
def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs):
|
||||
return LLMResponse(content="", model="mock", stop_reason="stop")
|
||||
|
||||
llm = RecoverableErrorThenSuccessLLM()
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
config=LoopConfig(
|
||||
max_iterations=5,
|
||||
max_stream_retries=3,
|
||||
stream_retry_backoff_base=0.01,
|
||||
),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
assert result.success is True
|
||||
assert result.output.get("result") == "done"
|
||||
# call 0: recoverable error → ConnectionError raised → outer retry
|
||||
# call 1: set_output tool call succeeds
|
||||
# call 2: inner tool loop re-invokes LLM after tool result → text "done"
|
||||
assert call_index == 3
|
||||
|
||||
|
||||
class TestIsTransientError:
    """Unit tests for _is_transient_error() classification."""

    def test_timeout_error(self):
        assert EventLoopNode._is_transient_error(TimeoutError("timed out")) is True

    def test_connection_error(self):
        assert EventLoopNode._is_transient_error(ConnectionError("reset")) is True

    def test_os_error(self):
        assert EventLoopNode._is_transient_error(OSError("network unreachable")) is True

    def test_value_error_not_transient(self):
        assert EventLoopNode._is_transient_error(ValueError("bad input")) is False

    def test_type_error_not_transient(self):
        assert EventLoopNode._is_transient_error(TypeError("wrong type")) is False

    def test_runtime_error_with_transient_keywords(self):
        check = EventLoopNode._is_transient_error
        assert check(RuntimeError("Stream error: 429 rate limit")) is True
        assert check(RuntimeError("Stream error: 503")) is True
        assert check(RuntimeError("Stream error: connection reset")) is True
        assert check(RuntimeError("Stream error: timeout exceeded")) is True

    def test_runtime_error_without_transient_keywords(self):
        assert EventLoopNode._is_transient_error(RuntimeError("authentication failed")) is False
        assert EventLoopNode._is_transient_error(RuntimeError("invalid JSON in response")) is False

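
# For reference, a classification that satisfies the tests above. Illustrative
# only: the keyword list and structure are assumptions, not EventLoopNode's
# actual _is_transient_error implementation.
def _is_transient_error_sketch(exc: Exception) -> bool:
    # Network/IO failures are inherently retryable.
    if isinstance(exc, (TimeoutError, ConnectionError, OSError)):
        return True
    # RuntimeErrors wrapping stream failures are retryable only when the
    # message mentions a transient condition.
    if isinstance(exc, RuntimeError):
        msg = str(exc).lower()
        return any(kw in msg for kw in ("rate limit", "429", "503", "connection", "timeout"))
    return False

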
# ===========================================================================
# Tool doom loop detection (ITEM 1)
# ===========================================================================


class TestFingerprintToolCalls:
    """Unit tests for _fingerprint_tool_calls()."""

    def test_basic_fingerprint(self):
        results = [
            {"tool_name": "search", "tool_input": {"q": "hello"}},
        ]
        fps = EventLoopNode._fingerprint_tool_calls(results)
        assert len(fps) == 1
        assert fps[0][0] == "search"
        # Args should be JSON with sort_keys
        assert fps[0][1] == '{"q": "hello"}'

    def test_order_sensitive(self):
        r1 = [
            {"tool_name": "search", "tool_input": {"q": "a"}},
            {"tool_name": "fetch", "tool_input": {"url": "b"}},
        ]
        r2 = [
            {"tool_name": "fetch", "tool_input": {"url": "b"}},
            {"tool_name": "search", "tool_input": {"q": "a"}},
        ]
        assert EventLoopNode._fingerprint_tool_calls(r1) != (
            EventLoopNode._fingerprint_tool_calls(r2)
        )

    def test_sort_keys_deterministic(self):
        r1 = [{"tool_name": "t", "tool_input": {"b": 2, "a": 1}}]
        r2 = [{"tool_name": "t", "tool_input": {"a": 1, "b": 2}}]
        assert EventLoopNode._fingerprint_tool_calls(r1) == EventLoopNode._fingerprint_tool_calls(
            r2
        )

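
# For reference, a fingerprinting scheme consistent with the tests above
# (illustrative sketch; EventLoopNode's real helper may differ in details):
def _fingerprint_tool_calls_sketch(results: list[dict]) -> list[tuple[str, str]]:
    import json

    # One (tool_name, canonical JSON args) pair per call, order preserved so
    # the same calls issued in a different order yield a different fingerprint.
    return [
        (r["tool_name"], json.dumps(r.get("tool_input", {}), sort_keys=True))
        for r in results
    ]

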
class TestIsToolDoomLoop:
    """Unit tests for _is_tool_doom_loop()."""

    def test_below_threshold(self):
        node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
        fp = [("search", '{"q": "hello"}')]
        is_doom, _ = node._is_tool_doom_loop([fp, fp])
        assert is_doom is False

    def test_at_threshold_identical(self):
        node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
        fp = [("search", '{"q": "hello"}')]
        is_doom, desc = node._is_tool_doom_loop([fp, fp, fp])
        assert is_doom is True
        assert "search" in desc

    def test_different_args_no_doom(self):
        node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
        fp1 = [("search", '{"q": "a"}')]
        fp2 = [("search", '{"q": "b"}')]
        fp3 = [("search", '{"q": "c"}')]
        is_doom, _ = node._is_tool_doom_loop([fp1, fp2, fp3])
        assert is_doom is False

    def test_disabled_via_config(self):
        node = EventLoopNode(
            config=LoopConfig(tool_doom_loop_enabled=False),
        )
        fp = [("search", '{"q": "hello"}')]
        is_doom, _ = node._is_tool_doom_loop([fp, fp, fp])
        assert is_doom is False

    def test_empty_fingerprints_no_doom(self):
        node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
        is_doom, _ = node._is_tool_doom_loop([[], [], []])
        assert is_doom is False

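
# For reference, detection logic consistent with the tests above. Illustrative
# only: the real _is_tool_doom_loop is a method on EventLoopNode and reads its
# enabled flag and threshold from LoopConfig.
def _is_tool_doom_loop_sketch(
    history: list[list[tuple[str, str]]],
    threshold: int = 3,
    enabled: bool = True,
) -> tuple[bool, str]:
    if not enabled or threshold <= 0 or len(history) < threshold:
        return False, ""
    window = history[-threshold:]
    # Doom loop = the last `threshold` turns issued the exact same, non-empty
    # sequence of (tool, args) calls.
    if not window[0] or any(turn != window[0] for turn in window[1:]):
        return False, ""
    tools = ", ".join(name for name, _ in window[0])
    return True, f"identical tool calls repeated {threshold}x: {tools}"

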
class ToolRepeatLLM(LLMProvider):
|
||||
"""LLM that produces identical tool calls across outer iterations.
|
||||
|
||||
Alternates: even calls → tool call, odd calls → text (exits inner loop).
|
||||
This ensures each outer iteration = 2 LLM calls with 1 tool executed.
|
||||
After tool_turns outer iterations, always returns text.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: dict,
|
||||
tool_turns: int,
|
||||
final_text: str = "done",
|
||||
):
|
||||
self.tool_name = tool_name
|
||||
self.tool_input = tool_input
|
||||
self.tool_turns = tool_turns
|
||||
self.final_text = final_text
|
||||
self._call_index = 0
|
||||
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096):
|
||||
idx = self._call_index
|
||||
self._call_index += 1
|
||||
# Which outer iteration we're in (2 calls per iteration)
|
||||
outer_iter = idx // 2
|
||||
is_tool_call = (idx % 2 == 0) and outer_iter < self.tool_turns
|
||||
if is_tool_call:
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=f"call_{outer_iter}",
|
||||
tool_name=self.tool_name,
|
||||
tool_input=self.tool_input,
|
||||
)
|
||||
yield FinishEvent(
|
||||
stop_reason="tool_calls",
|
||||
input_tokens=10,
|
||||
output_tokens=5,
|
||||
model="mock",
|
||||
)
|
||||
else:
|
||||
# Unique text per call to avoid stall detection
|
||||
text = f"{self.final_text} (call {idx})"
|
||||
yield TextDeltaEvent(content=text, snapshot=text)
|
||||
yield FinishEvent(
|
||||
stop_reason="stop",
|
||||
input_tokens=10,
|
||||
output_tokens=5,
|
||||
model="mock",
|
||||
)
|
||||
|
||||
def complete(self, messages, system="", **kwargs) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
content="ok",
|
||||
model="mock",
|
||||
stop_reason="stop",
|
||||
)
|
||||
|
||||
def complete_with_tools(
|
||||
self,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
tool_executor,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
content="",
|
||||
model="mock",
|
||||
stop_reason="stop",
|
||||
)
|
||||
|
||||
|
||||
class TestToolDoomLoopIntegration:
|
||||
"""Integration tests for doom loop detection in execute().
|
||||
|
||||
Uses ToolRepeatLLM: returns tool calls for first N calls, then text.
|
||||
Each outer iteration = 2 LLM calls (tool call + text exit for inner loop).
|
||||
logged_tool_calls accumulates across inner iterations.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doom_loop_injects_warning(
|
||||
self,
|
||||
runtime,
|
||||
node_spec,
|
||||
memory,
|
||||
):
|
||||
"""3 identical tool call turns should inject a warning."""
|
||||
node_spec.output_keys = []
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
eval_count = 0
|
||||
|
||||
async def judge_eval(*args, **kwargs):
|
||||
nonlocal eval_count
|
||||
eval_count += 1
|
||||
if eval_count >= 4:
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
return JudgeVerdict(action="RETRY")
|
||||
|
||||
judge.evaluate = judge_eval
|
||||
|
||||
# 3 tool calls (6 LLM calls: tool+text each), then 1 text
|
||||
llm = ToolRepeatLLM("search", {"q": "hello"}, tool_turns=3)
|
||||
|
||||
def tool_exec(tool_use: ToolUse) -> ToolResult:
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
content="result",
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
ctx = build_ctx(
|
||||
runtime,
|
||||
node_spec,
|
||||
memory,
|
||||
llm,
|
||||
tools=[Tool(name="search", description="s", parameters={})],
|
||||
)
|
||||
node = EventLoopNode(
|
||||
judge=judge,
|
||||
tool_executor=tool_exec,
|
||||
config=LoopConfig(
|
||||
max_iterations=10,
|
||||
tool_doom_loop_threshold=3,
|
||||
),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doom_loop_emits_event(
|
||||
self,
|
||||
runtime,
|
||||
node_spec,
|
||||
memory,
|
||||
):
|
||||
"""Doom loop should emit NODE_TOOL_DOOM_LOOP event."""
|
||||
node_spec.output_keys = []
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
eval_count = 0
|
||||
|
||||
async def judge_eval(*args, **kwargs):
|
||||
nonlocal eval_count
|
||||
eval_count += 1
|
||||
if eval_count >= 4:
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
return JudgeVerdict(action="RETRY")
|
||||
|
||||
judge.evaluate = judge_eval
|
||||
|
||||
llm = ToolRepeatLLM("search", {"q": "hello"}, tool_turns=3)
|
||||
bus = EventBus()
|
||||
doom_events: list = []
|
||||
bus.subscribe(
|
||||
event_types=[EventType.NODE_TOOL_DOOM_LOOP],
|
||||
handler=lambda e: doom_events.append(e),
|
||||
)
|
||||
|
||||
def tool_exec(tool_use: ToolUse) -> ToolResult:
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
content="result",
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
ctx = build_ctx(
|
||||
runtime,
|
||||
node_spec,
|
||||
memory,
|
||||
llm,
|
||||
tools=[Tool(name="search", description="s", parameters={})],
|
||||
)
|
||||
node = EventLoopNode(
|
||||
judge=judge,
|
||||
tool_executor=tool_exec,
|
||||
event_bus=bus,
|
||||
config=LoopConfig(
|
||||
max_iterations=10,
|
||||
tool_doom_loop_threshold=3,
|
||||
),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
assert result.success is True
|
||||
assert len(doom_events) == 1
|
||||
assert "search" in doom_events[0].data["description"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doom_loop_disabled(
|
||||
self,
|
||||
runtime,
|
||||
node_spec,
|
||||
memory,
|
||||
):
|
||||
"""Disabled doom loop should not trigger with identical calls."""
|
||||
node_spec.output_keys = []
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
eval_count = 0
|
||||
|
||||
async def judge_eval(*args, **kwargs):
|
||||
nonlocal eval_count
|
||||
eval_count += 1
|
||||
if eval_count >= 4:
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
return JudgeVerdict(action="RETRY")
|
||||
|
||||
judge.evaluate = judge_eval
|
||||
|
||||
llm = ToolRepeatLLM("search", {"q": "hello"}, tool_turns=4)
|
||||
|
||||
def tool_exec(tool_use: ToolUse) -> ToolResult:
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
content="result",
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
ctx = build_ctx(
|
||||
runtime,
|
||||
node_spec,
|
||||
memory,
|
||||
llm,
|
||||
tools=[Tool(name="search", description="s", parameters={})],
|
||||
)
|
||||
node = EventLoopNode(
|
||||
judge=judge,
|
||||
tool_executor=tool_exec,
|
||||
config=LoopConfig(
|
||||
max_iterations=10,
|
||||
tool_doom_loop_enabled=False,
|
||||
),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_args_no_doom_loop(
|
||||
self,
|
||||
runtime,
|
||||
node_spec,
|
||||
memory,
|
||||
):
|
||||
"""Different tool args each turn should NOT trigger doom loop."""
|
||||
node_spec.output_keys = []
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
eval_count = 0
|
||||
|
||||
async def judge_eval(*args, **kwargs):
|
||||
nonlocal eval_count
|
||||
eval_count += 1
|
||||
if eval_count >= 4:
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
return JudgeVerdict(action="RETRY")
|
||||
|
||||
judge.evaluate = judge_eval
|
||||
|
||||
# LLM that returns different args each call
|
||||
call_idx = 0
|
||||
|
||||
class DiffArgsLLM(LLMProvider):
|
||||
async def stream(self, messages, **kwargs):
|
||||
nonlocal call_idx
|
||||
idx = call_idx
|
||||
call_idx += 1
|
||||
if idx < 3:
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=f"c{idx}",
|
||||
tool_name="search",
|
||||
tool_input={"q": f"query_{idx}"},
|
||||
)
|
||||
yield FinishEvent(
|
||||
stop_reason="tool_calls",
|
||||
input_tokens=10,
|
||||
output_tokens=5,
|
||||
model="mock",
|
||||
)
|
||||
else:
|
||||
text = f"done (call {idx})"
|
||||
yield TextDeltaEvent(
|
||||
content=text,
|
||||
snapshot=text,
|
||||
)
|
||||
yield FinishEvent(
|
||||
stop_reason="stop",
|
||||
input_tokens=10,
|
||||
output_tokens=5,
|
||||
model="mock",
|
||||
)
|
||||
|
||||
def complete(self, messages, **kwargs):
|
||||
return LLMResponse(
|
||||
content="ok",
|
||||
model="mock",
|
||||
stop_reason="stop",
|
||||
)
|
||||
|
||||
def complete_with_tools(
|
||||
self,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
tool_executor,
|
||||
**kw,
|
||||
):
|
||||
return LLMResponse(
|
||||
content="",
|
||||
model="mock",
|
||||
stop_reason="stop",
|
||||
)
|
||||
|
||||
llm = DiffArgsLLM()
|
||||
|
||||
def tool_exec(tool_use: ToolUse) -> ToolResult:
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
content="result",
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
ctx = build_ctx(
|
||||
runtime,
|
||||
node_spec,
|
||||
memory,
|
||||
llm,
|
||||
tools=[Tool(name="search", description="s", parameters={})],
|
||||
)
|
||||
node = EventLoopNode(
|
||||
judge=judge,
|
||||
tool_executor=tool_exec,
|
||||
config=LoopConfig(
|
||||
max_iterations=10,
|
||||
tool_doom_loop_threshold=3,
|
||||
),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
assert result.success is True
|
||||
|
||||
@@ -65,7 +65,7 @@ def test_client_facing_defaults_false():
|
||||
id="n1",
|
||||
name="Node 1",
|
||||
description="test",
|
||||
node_type="llm_generate",
|
||||
node_type="event_loop",
|
||||
)
|
||||
assert spec.client_facing is False
|
||||
|
||||
@@ -143,7 +143,7 @@ def test_registered_event_loop_returns_impl(runtime):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_loop_max_retries_forced_zero(runtime):
|
||||
"""An event_loop node with max_retries=3 should only execute once (no retry)."""
|
||||
"""Custom NodeProtocol impls with node_type=event_loop keep their max_retries."""
|
||||
node_spec = NodeSpec(
|
||||
id="el_fail",
|
||||
name="Failing Event Loop",
|
||||
@@ -171,9 +171,9 @@ async def test_event_loop_max_retries_forced_zero(runtime):
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
# Event loop nodes get max_retries overridden to 0, meaning execute once then fail
|
||||
# Custom nodes (not EventLoopNode instances) keep their max_retries
|
||||
assert not result.success
|
||||
assert failing_node.attempt_count == 1
|
||||
assert failing_node.attempt_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -246,21 +246,21 @@ async def test_event_loop_max_retries_positive_logs_warning(runtime, caplog):
|
||||
with caplog.at_level(logging.WARNING):
|
||||
await executor.execute(graph, goal, {})
|
||||
|
||||
assert "Overriding to 0" in caplog.text
|
||||
assert "el_warn" in caplog.text
|
||||
# Custom nodes (not EventLoopNode instances) don't get override warning
|
||||
assert "Overriding to 0" not in caplog.text
|
||||
|
||||
|
||||
# --- Existing node types unaffected ---
|
||||
|
||||
|
||||
def test_existing_node_types_unchanged():
|
||||
"""All pre-existing node types must still be in VALID_NODE_TYPES with defaults preserved."""
|
||||
expected = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
|
||||
assert expected.issubset(GraphExecutor.VALID_NODE_TYPES)
|
||||
"""Only event_loop is a valid node type."""
|
||||
expected = {"event_loop"}
|
||||
assert expected == GraphExecutor.VALID_NODE_TYPES
|
||||
|
||||
# Default node_type is still llm_tool_use
|
||||
# Default node_type is event_loop
|
||||
spec = NodeSpec(id="x", name="X", description="x")
|
||||
assert spec.node_type == "llm_tool_use"
|
||||
assert spec.node_type == "event_loop"
|
||||
|
||||
# Default max_retries is still 3
|
||||
assert spec.max_retries == 3
|
||||
|
||||
@@ -106,7 +106,7 @@ class TestExecutionQuality:
|
||||
id="node1",
|
||||
name="Always Succeeds",
|
||||
description="Never fails",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
),
|
||||
],
|
||||
@@ -151,6 +151,7 @@ class TestExecutionQuality:
|
||||
)
|
||||
|
||||
# Create graph with flaky node (fails 2 times before succeeding)
|
||||
# (actual impl from registry is FlakyNode)
|
||||
graph = GraphSpec(
|
||||
id="test-graph",
|
||||
goal_id=goal.id,
|
||||
@@ -159,7 +160,7 @@ class TestExecutionQuality:
|
||||
id="flaky",
|
||||
name="Flaky Node",
|
||||
description="Fails then succeeds",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
max_retries=3, # Allow retries
|
||||
),
|
||||
@@ -206,6 +207,7 @@ class TestExecutionQuality:
|
||||
)
|
||||
|
||||
# Create graph with always-failing node
|
||||
# (actual impl from registry is AlwaysFailsNode)
|
||||
graph = GraphSpec(
|
||||
id="test-graph",
|
||||
goal_id=goal.id,
|
||||
@@ -214,7 +216,7 @@ class TestExecutionQuality:
|
||||
id="fails",
|
||||
name="Always Fails",
|
||||
description="Never succeeds",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
max_retries=2, # Will retry twice then fail
|
||||
),
|
||||
@@ -261,6 +263,7 @@ class TestExecutionQuality:
|
||||
)
|
||||
|
||||
# Create graph with multiple flaky nodes
|
||||
# (actual impls from registry are FlakyNode instances)
|
||||
graph = GraphSpec(
|
||||
id="test-graph",
|
||||
goal_id=goal.id,
|
||||
@@ -269,7 +272,7 @@ class TestExecutionQuality:
|
||||
id="flaky1",
|
||||
name="Flaky Node 1",
|
||||
description="Fails once",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["result1"],
|
||||
max_retries=3,
|
||||
),
|
||||
@@ -277,7 +280,7 @@ class TestExecutionQuality:
|
||||
id="flaky2",
|
||||
name="Flaky Node 2",
|
||||
description="Fails twice",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
input_keys=["result1"],
|
||||
output_keys=["result2"],
|
||||
max_retries=3,
|
||||
@@ -286,7 +289,7 @@ class TestExecutionQuality:
|
||||
id="success",
|
||||
name="Success Node",
|
||||
description="Always succeeds",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
input_keys=["result2"],
|
||||
output_keys=["final"],
|
||||
),
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
"""Tests for ExecutionStream retention behavior."""
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph import Goal, NodeSpec, SuccessCriterion
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.stream_events import FinishEvent, StreamEvent, TextDeltaEvent, ToolCallEvent
|
||||
from framework.runtime.event_bus import EventBus
|
||||
from framework.runtime.execution_stream import EntryPointSpec, ExecutionStream
|
||||
from framework.runtime.outcome_aggregator import OutcomeAggregator
|
||||
@@ -16,7 +18,13 @@ from framework.storage.concurrent import ConcurrentStorage
|
||||
|
||||
|
||||
class DummyLLMProvider(LLMProvider):
|
||||
"""Deterministic LLM provider for execution stream tests."""
|
||||
"""Deterministic LLM provider for execution stream tests.
|
||||
|
||||
Uses set_output tool call to properly set outputs, avoiding stall detection.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._call_count = 0
|
||||
|
||||
def complete(
|
||||
self,
|
||||
@@ -26,8 +34,9 @@ class DummyLLMProvider(LLMProvider):
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, object] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
) -> LLMResponse:
|
||||
return LLMResponse(content=json.dumps({"result": "ok"}), model="dummy")
|
||||
return LLMResponse(content="Summary for compaction.", model="dummy")
|
||||
|
||||
def complete_with_tools(
|
||||
self,
|
||||
@@ -37,7 +46,29 @@ class DummyLLMProvider(LLMProvider):
|
||||
tool_executor: Callable,
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
return LLMResponse(content=json.dumps({"result": "ok"}), model="dummy")
|
||||
return LLMResponse(content="Summary for compaction.", model="dummy")
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
self._call_count += 1
|
||||
|
||||
if self._call_count == 1:
|
||||
# First call: set the output via tool call
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=f"tc_{self._call_count}",
|
||||
tool_name="set_output",
|
||||
tool_input={"key": "result", "value": "ok"},
|
||||
)
|
||||
yield FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=10)
|
||||
else:
|
||||
# Subsequent calls: just finish with text
|
||||
yield TextDeltaEvent(content="Done.", snapshot="Done.")
|
||||
yield FinishEvent(stop_reason="end_turn", input_tokens=5, output_tokens=5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -61,7 +92,7 @@ async def test_execution_stream_retention(tmp_path):
|
||||
id="hello",
|
||||
name="Hello",
|
||||
description="Return a result",
|
||||
node_type="llm_generate",
|
||||
node_type="event_loop",
|
||||
input_keys=["user_name"],
|
||||
output_keys=["result"],
|
||||
system_prompt='Return JSON: {"result": "ok"}',
|
||||
@@ -120,3 +151,146 @@ async def test_execution_stream_retention(tmp_path):
|
||||
|
||||
await stream.stop()
|
||||
await storage.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shared_session_reuses_directory_and_memory(tmp_path):
|
||||
"""When an async entry point uses resume_session_id, it should:
|
||||
1. Run in the same session directory as the primary execution
|
||||
2. Have access to the primary session's memory
|
||||
3. NOT overwrite the primary session's state.json
|
||||
"""
|
||||
goal = Goal(
|
||||
id="test-goal",
|
||||
name="Test",
|
||||
description="Shared session test",
|
||||
success_criteria=[
|
||||
SuccessCriterion(
|
||||
id="result",
|
||||
description="Result present",
|
||||
metric="output_contains",
|
||||
target="result",
|
||||
)
|
||||
],
|
||||
constraints=[],
|
||||
)
|
||||
|
||||
node = NodeSpec(
|
||||
id="hello",
|
||||
name="Hello",
|
||||
description="Return a result",
|
||||
node_type="event_loop",
|
||||
input_keys=["user_name"],
|
||||
output_keys=["result"],
|
||||
system_prompt='Return JSON: {"result": "ok"}',
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="test-graph",
|
||||
goal_id=goal.id,
|
||||
version="1.0.0",
|
||||
entry_node="hello",
|
||||
entry_points={"start": "hello"},
|
||||
terminal_nodes=["hello"],
|
||||
pause_nodes=[],
|
||||
nodes=[node],
|
||||
edges=[],
|
||||
default_model="dummy",
|
||||
max_tokens=10,
|
||||
)
|
||||
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
from framework.storage.session_store import SessionStore
|
||||
|
||||
session_store = SessionStore(tmp_path)
|
||||
|
||||
# Primary stream
|
||||
primary_stream = ExecutionStream(
|
||||
stream_id="primary",
|
||||
entry_spec=EntryPointSpec(
|
||||
id="primary",
|
||||
name="Primary",
|
||||
entry_node="hello",
|
||||
trigger_type="manual",
|
||||
isolation_level="shared",
|
||||
),
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
state_manager=SharedStateManager(),
|
||||
storage=storage,
|
||||
outcome_aggregator=OutcomeAggregator(goal, EventBus()),
|
||||
event_bus=None,
|
||||
llm=DummyLLMProvider(),
|
||||
tools=[],
|
||||
tool_executor=None,
|
||||
session_store=session_store,
|
||||
)
|
||||
|
||||
await primary_stream.start()
|
||||
|
||||
# Run primary execution — creates session directory and state.json
|
||||
primary_exec_id = await primary_stream.execute({"user_name": "alice"})
|
||||
primary_result = await primary_stream.wait_for_completion(primary_exec_id, timeout=5)
|
||||
assert primary_result is not None
|
||||
assert primary_result.success
|
||||
|
||||
# Verify primary session's state.json exists and has the primary entry_point
|
||||
primary_state_path = tmp_path / "sessions" / primary_exec_id / "state.json"
|
||||
assert primary_state_path.exists()
|
||||
primary_state = json.loads(primary_state_path.read_text())
|
||||
assert primary_state["entry_point"] == "primary"
|
||||
|
||||
# Async stream — simulates a webhook entry point sharing the session
|
||||
async_stream = ExecutionStream(
|
||||
stream_id="webhook",
|
||||
entry_spec=EntryPointSpec(
|
||||
id="webhook",
|
||||
name="Webhook",
|
||||
entry_node="hello",
|
||||
trigger_type="event",
|
||||
isolation_level="shared",
|
||||
),
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
state_manager=SharedStateManager(),
|
||||
storage=storage,
|
||||
outcome_aggregator=OutcomeAggregator(goal, EventBus()),
|
||||
event_bus=None,
|
||||
llm=DummyLLMProvider(),
|
||||
tools=[],
|
||||
tool_executor=None,
|
||||
session_store=session_store,
|
||||
)
|
||||
|
||||
await async_stream.start()
|
||||
|
||||
# Run async execution with resume_session_id pointing to primary session
|
||||
session_state = {
|
||||
"resume_session_id": primary_exec_id,
|
||||
"memory": {"rules": "star important emails"},
|
||||
}
|
||||
async_exec_id = await async_stream.execute({"event": "new_email"}, session_state=session_state)
|
||||
|
||||
# Should reuse the primary session ID
|
||||
assert async_exec_id == primary_exec_id
|
||||
|
||||
async_result = await async_stream.wait_for_completion(async_exec_id, timeout=5)
|
||||
assert async_result is not None
|
||||
assert async_result.success
|
||||
|
||||
# State.json should NOT have been overwritten by the async execution
|
||||
# (it should still show the primary entry point)
|
||||
final_state = json.loads(primary_state_path.read_text())
|
||||
assert final_state["entry_point"] == "primary"
|
||||
|
||||
# Verify only ONE session directory exists (not two)
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
session_dirs = [d for d in sessions_dir.iterdir() if d.is_dir()]
|
||||
assert len(session_dirs) == 1
|
||||
assert session_dirs[0].name == primary_exec_id
|
||||
|
||||
await primary_stream.stop()
|
||||
await async_stream.stop()
|
||||
await storage.stop()
|
||||
|
||||
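
# Sketch of the session-reuse branch the test above relies on (illustrative
# only; the helper names, including session_store.exists(), are assumptions
# and the real ExecutionStream.execute differs in details):
#
#   async def execute(self, input_data, session_state=None):
#       resume_id = (session_state or {}).get("resume_session_id")
#       if resume_id and self.session_store.exists(resume_id):
#           session_id = resume_id      # reuse the primary session directory
#           write_state_json = False    # do not clobber the primary state.json
#       else:
#           session_id = self._new_session_id()
#           write_state_json = True
#       ...
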
@@ -81,7 +81,9 @@ def goal():
|
||||
|
||||
def test_max_node_visits_default():
|
||||
"""NodeSpec.max_node_visits should default to 1."""
|
||||
spec = NodeSpec(id="n", name="N", description="test", node_type="function", output_keys=["out"])
|
||||
spec = NodeSpec(
|
||||
id="n", name="N", description="test", node_type="event_loop", output_keys=["out"]
|
||||
)
|
||||
assert spec.max_node_visits == 1
|
||||
|
||||
|
||||
@@ -101,7 +103,7 @@ async def test_visit_limit_skips_node(runtime, goal):
|
||||
id="a",
|
||||
name="A",
|
||||
description="entry with visit limit",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["a_out"],
|
||||
max_node_visits=1,
|
||||
)
|
||||
@@ -109,7 +111,7 @@ async def test_visit_limit_skips_node(runtime, goal):
|
||||
id="b",
|
||||
name="B",
|
||||
description="middle node",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["b_out"],
|
||||
max_node_visits=0, # unlimited — let max_steps guard
|
||||
)
|
||||
@@ -159,7 +161,7 @@ async def test_visit_limit_allows_multiple(runtime, goal):
|
||||
id="a",
|
||||
name="A",
|
||||
description="entry allows two visits",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["a_out"],
|
||||
max_node_visits=2,
|
||||
)
|
||||
@@ -167,7 +169,7 @@ async def test_visit_limit_allows_multiple(runtime, goal):
|
||||
id="b",
|
||||
name="B",
|
||||
description="middle node",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["b_out"],
|
||||
max_node_visits=0, # unlimited
|
||||
)
|
||||
@@ -215,7 +217,7 @@ async def test_visit_limit_zero_unlimited(runtime, goal):
|
||||
id="a",
|
||||
name="A",
|
||||
description="unlimited visits",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["a_out"],
|
||||
max_node_visits=0,
|
||||
)
|
||||
@@ -223,7 +225,7 @@ async def test_visit_limit_zero_unlimited(runtime, goal):
|
||||
id="b",
|
||||
name="B",
|
||||
description="middle node",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["b_out"],
|
||||
max_node_visits=0,
|
||||
)
|
||||
@@ -274,7 +276,7 @@ async def test_conditional_feedback_edge(runtime, goal):
|
||||
id="director",
|
||||
name="Director",
|
||||
description="plans work",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["plan"],
|
||||
max_node_visits=2,
|
||||
)
|
||||
@@ -282,7 +284,7 @@ async def test_conditional_feedback_edge(runtime, goal):
|
||||
id="writer",
|
||||
name="Writer",
|
||||
description="writes draft",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["draft", "needs_revision"],
|
||||
max_node_visits=2,
|
||||
)
|
||||
@@ -290,7 +292,7 @@ async def test_conditional_feedback_edge(runtime, goal):
|
||||
id="output",
|
||||
name="Output",
|
||||
description="final output",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["final"],
|
||||
)
|
||||
|
||||
@@ -370,7 +372,7 @@ async def test_conditional_feedback_false(runtime, goal):
|
||||
id="director",
|
||||
name="Director",
|
||||
description="plans work",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["plan"],
|
||||
max_node_visits=2,
|
||||
)
|
||||
@@ -378,14 +380,14 @@ async def test_conditional_feedback_false(runtime, goal):
|
||||
id="writer",
|
||||
name="Writer",
|
||||
description="writes draft",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["draft", "needs_revision"],
|
||||
)
|
||||
output_node = NodeSpec(
|
||||
id="output",
|
||||
name="Output",
|
||||
description="final output",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["final"],
|
||||
)
|
||||
|
||||
@@ -458,14 +460,14 @@ async def test_visit_counts_in_result(runtime, goal):
|
||||
id="a",
|
||||
name="A",
|
||||
description="entry",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["a_out"],
|
||||
)
|
||||
node_b = NodeSpec(
|
||||
id="b",
|
||||
name="B",
|
||||
description="terminal",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
input_keys=["a_out"],
|
||||
output_keys=["b_out"],
|
||||
)
|
||||
@@ -509,21 +511,21 @@ async def test_conditional_priority_prevents_fanout(runtime, goal):
|
||||
id="writer",
|
||||
name="Writer",
|
||||
description="produces output",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["draft", "needs_revision"],
|
||||
)
|
||||
output_node = NodeSpec(
|
||||
id="output",
|
||||
name="Output",
|
||||
description="forward target",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["final"],
|
||||
)
|
||||
director = NodeSpec(
|
||||
id="director",
|
||||
name="Director",
|
||||
description="feedback target",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["plan"],
|
||||
max_node_visits=2,
|
||||
)
|
||||
|
||||
@@ -79,7 +79,7 @@ async def test_executor_respects_custom_max_retries_high(runtime):
|
||||
name="Flaky Node",
|
||||
description="A node that fails multiple times before succeeding",
|
||||
max_retries=10, # Should allow 10 retries
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
)
|
||||
|
||||
@@ -123,7 +123,7 @@ async def test_executor_respects_custom_max_retries_low(runtime):
|
||||
name="Fragile Node",
|
||||
description="A node with low retry tolerance",
|
||||
max_retries=2, # max_retries=N means N total attempts allowed
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
)
|
||||
|
||||
@@ -166,7 +166,7 @@ async def test_executor_respects_default_max_retries(runtime):
|
||||
name="Default Node",
|
||||
description="A node using default retry settings",
|
||||
# max_retries not specified, should default to 3
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
)
|
||||
|
||||
@@ -211,7 +211,7 @@ async def test_executor_max_retries_two_succeeds_on_second(runtime):
|
||||
name="Two Retry Node",
|
||||
description="A node with two attempts allowed",
|
||||
max_retries=2, # max_retries=N means N total attempts allowed
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
)
|
||||
|
||||
@@ -253,7 +253,7 @@ async def test_executor_different_nodes_different_max_retries(runtime):
|
||||
name="Node 1",
|
||||
description="First node in multi-node test",
|
||||
max_retries=2,
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["result1"],
|
||||
)
|
||||
|
||||
@@ -262,7 +262,7 @@ async def test_executor_different_nodes_different_max_retries(runtime):
|
||||
name="Node 2",
|
||||
description="Second node in multi-node test",
|
||||
max_retries=5,
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
input_keys=["result1"],
|
||||
output_keys=["result2"],
|
||||
)
|
||||
|
||||
+29
-25
@@ -116,7 +116,7 @@ def _make_fanout_graph(
|
||||
id="source",
|
||||
name="Source",
|
||||
description="entry",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["data"],
|
||||
)
|
||||
|
||||
@@ -164,10 +164,10 @@ def _make_fanout_graph(
|
||||
async def test_fanout_triggers_on_multiple_success_edges(runtime, goal):
|
||||
"""Fan-out should activate when a node has >1 ON_SUCCESS outgoing edges."""
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="branch 1", node_type="function", output_keys=["b1_out"]
|
||||
id="b1", name="B1", description="branch 1", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2", name="B2", description="branch 2", node_type="function", output_keys=["b2_out"]
|
||||
id="b2", name="B2", description="branch 2", node_type="event_loop", output_keys=["b2_out"]
|
||||
)
|
||||
|
||||
graph = _make_fanout_graph([b1, b2])
|
||||
@@ -195,10 +195,10 @@ async def test_branches_execute_concurrently(runtime, goal):
|
||||
"""All fan-out branches should be launched via asyncio.gather (concurrent)."""
|
||||
order = []
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="branch 1", node_type="function", output_keys=["b1_done"]
|
||||
id="b1", name="B1", description="branch 1", node_type="event_loop", output_keys=["b1_done"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2", name="B2", description="branch 2", node_type="function", output_keys=["b2_done"]
|
||||
id="b2", name="B2", description="branch 2", node_type="event_loop", output_keys=["b2_done"]
|
||||
)
|
||||
|
||||
graph = _make_fanout_graph([b1, b2])
|
||||
@@ -223,13 +223,17 @@ async def test_branches_execute_concurrently(runtime, goal):
|
||||
async def test_convergence_at_fan_in_node(runtime, goal):
|
||||
"""After fan-out branches complete, execution should continue at convergence node."""
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="branch 1", node_type="function", output_keys=["b1_out"]
|
||||
id="b1", name="B1", description="branch 1", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2", name="B2", description="branch 2", node_type="function", output_keys=["b2_out"]
|
||||
id="b2", name="B2", description="branch 2", node_type="event_loop", output_keys=["b2_out"]
|
||||
)
|
||||
merge = NodeSpec(
|
||||
id="merge", name="Merge", description="fan-in", node_type="function", output_keys=["merged"]
|
||||
id="merge",
|
||||
name="Merge",
|
||||
description="fan-in",
|
||||
node_type="event_loop",
|
||||
output_keys=["merged"],
|
||||
)
|
||||
|
||||
graph = _make_fanout_graph([b1, b2], fan_in_node=merge)
|
||||
@@ -255,13 +259,13 @@ async def test_convergence_at_fan_in_node(runtime, goal):
|
||||
async def test_fail_all_strategy_raises_on_branch_failure(runtime, goal):
|
||||
"""fail_all should raise RuntimeError if any branch fails."""
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="ok branch", node_type="function", output_keys=["b1_out"]
|
||||
id="b1", name="B1", description="ok branch", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2",
|
||||
name="B2",
|
||||
description="bad branch",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["b2_out"],
|
||||
max_retries=1,
|
||||
)
|
||||
@@ -290,13 +294,13 @@ async def test_fail_all_strategy_raises_on_branch_failure(runtime, goal):
|
||||
async def test_continue_others_strategy_allows_partial_success(runtime, goal):
|
||||
"""continue_others should let successful branches complete even if one fails."""
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="ok", node_type="function", output_keys=["b1_out"]
|
||||
id="b1", name="B1", description="ok", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2",
|
||||
name="B2",
|
||||
description="fail",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["b2_out"],
|
||||
max_retries=1,
|
||||
)
|
||||
@@ -325,13 +329,13 @@ async def test_continue_others_strategy_allows_partial_success(runtime, goal):
|
||||
async def test_wait_all_strategy_collects_all_results(runtime, goal):
|
||||
"""wait_all should wait for all branches before proceeding."""
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="ok", node_type="function", output_keys=["b1_out"]
|
||||
id="b1", name="B1", description="ok", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2",
|
||||
name="B2",
|
||||
description="fail",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["b2_out"],
|
||||
max_retries=1,
|
||||
)
|
||||
@@ -365,12 +369,12 @@ async def test_per_branch_retry(runtime, goal):
|
||||
id="b1",
|
||||
name="B1",
|
||||
description="flaky",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
output_keys=["b1_out"],
|
||||
max_retries=5,
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2", name="B2", description="solid", node_type="function", output_keys=["b2_out"]
|
||||
id="b2", name="B2", description="solid", node_type="event_loop", output_keys=["b2_out"]
|
||||
)
|
||||
|
||||
graph = _make_fanout_graph([b1, b2])
|
||||
@@ -394,13 +398,13 @@ async def test_per_branch_retry(runtime, goal):
|
||||
async def test_single_edge_no_parallel_overhead(runtime, goal):
|
||||
"""A single outgoing edge should follow normal sequential path, not fan-out."""
|
||||
n1 = NodeSpec(
|
||||
id="n1", name="N1", description="entry", node_type="function", output_keys=["out1"]
|
||||
id="n1", name="N1", description="entry", node_type="event_loop", output_keys=["out1"]
|
||||
)
|
||||
n2 = NodeSpec(
|
||||
id="n2",
|
||||
name="N2",
|
||||
description="next",
|
||||
node_type="function",
|
||||
node_type="event_loop",
|
||||
input_keys=["out1"],
|
||||
output_keys=["out2"],
|
||||
)
|
||||
@@ -432,8 +436,8 @@ async def test_single_edge_no_parallel_overhead(runtime, goal):
|
||||
|
||||
def test_detect_fan_out_nodes():
|
||||
"""GraphSpec.detect_fan_out_nodes should identify fan-out topology."""
|
||||
b1 = NodeSpec(id="b1", name="B1", description="b", node_type="function", output_keys=["x"])
|
||||
b2 = NodeSpec(id="b2", name="B2", description="b", node_type="function", output_keys=["y"])
|
||||
b1 = NodeSpec(id="b1", name="B1", description="b", node_type="event_loop", output_keys=["x"])
|
||||
b2 = NodeSpec(id="b2", name="B2", description="b", node_type="event_loop", output_keys=["y"])
|
||||
graph = _make_fanout_graph([b1, b2])
|
||||
|
||||
fan_outs = graph.detect_fan_out_nodes()
|
||||
@@ -447,10 +451,10 @@ def test_detect_fan_out_nodes():
|
||||
|
||||
def test_detect_fan_in_nodes():
|
||||
"""GraphSpec.detect_fan_in_nodes should identify convergence topology."""
|
||||
b1 = NodeSpec(id="b1", name="B1", description="b", node_type="function", output_keys=["x"])
|
||||
b2 = NodeSpec(id="b2", name="B2", description="b", node_type="function", output_keys=["y"])
|
||||
b1 = NodeSpec(id="b1", name="B1", description="b", node_type="event_loop", output_keys=["x"])
|
||||
b2 = NodeSpec(id="b2", name="B2", description="b", node_type="event_loop", output_keys=["y"])
|
||||
merge = NodeSpec(
|
||||
id="merge", name="Merge", description="m", node_type="function", output_keys=["z"]
|
||||
id="merge", name="Merge", description="m", node_type="event_loop", output_keys=["z"]
|
||||
)
|
||||
graph = _make_fanout_graph([b1, b2], fan_in_node=merge)
|
||||
|
||||
@@ -467,10 +471,10 @@ def test_detect_fan_in_nodes():
|
||||
async def test_parallel_disabled_uses_sequential(runtime, goal):
|
||||
"""When enable_parallel_execution=False, multi-edge should follow first match only."""
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="b1", node_type="function", output_keys=["b1_out"]
|
||||
id="b1", name="B1", description="b1", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2", name="B2", description="b2", node_type="function", output_keys=["b2_out"]
|
||||
id="b2", name="B2", description="b2", node_type="event_loop", output_keys=["b2_out"]
|
||||
)
|
||||
|
||||
graph = _make_fanout_graph([b1, b2])
|
||||
|
||||
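The fan-out tests above all build their graphs through _make_fanout_graph, whose body is only partially visible in this diff. For orientation, here is a minimal sketch of how such a helper could wire a source node to the branches (and an optional fan-in node) with ON_SUCCESS edges. It uses only the NodeSpec, EdgeSpec and GraphSpec fields that appear in these hunks; the helper name, edge ids and graph ids are assumptions, not the repository's code.

# Illustrative sketch only; not the repository's actual _make_fanout_graph.
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.node import NodeSpec

def make_fanout_graph_sketch(branches, fan_in_node=None):
    # Single entry node that every branch hangs off.
    source = NodeSpec(
        id="source",
        name="Source",
        description="entry",
        node_type="event_loop",
        output_keys=["data"],
    )
    nodes = [source, *branches]
    # One ON_SUCCESS edge per branch creates the fan-out topology.
    edges = [
        EdgeSpec(
            id=f"e_source_{b.id}",
            source="source",
            target=b.id,
            condition=EdgeCondition.ON_SUCCESS,
        )
        for b in branches
    ]
    terminal = [b.id for b in branches]
    if fan_in_node is not None:
        # Converge every branch onto the fan-in node.
        nodes.append(fan_in_node)
        edges += [
            EdgeSpec(
                id=f"e_{b.id}_{fan_in_node.id}",
                source=b.id,
                target=fan_in_node.id,
                condition=EdgeCondition.ON_SUCCESS,
            )
            for b in branches
        ]
        terminal = [fan_in_node.id]
    return GraphSpec(
        id="graph-fanout",
        goal_id="goal-fanout",
        nodes=nodes,
        edges=edges,
        entry_node="source",
        terminal_nodes=terminal,
    )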
@@ -1,442 +0,0 @@
|
||||
"""
|
||||
Tests for the Worker-Judge flexible execution pattern.
|
||||
|
||||
Tests cover:
|
||||
- Plan and PlanStep data structures
|
||||
- Code sandbox security
|
||||
- HybridJudge rule evaluation
|
||||
- WorkerNode action dispatch
|
||||
- FlexibleGraphExecutor end-to-end
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.code_sandbox import (
|
||||
CodeSandbox,
|
||||
safe_eval,
|
||||
safe_exec,
|
||||
)
|
||||
from framework.graph.goal import Goal, SuccessCriterion
|
||||
from framework.graph.judge import HybridJudge, create_default_judge
|
||||
from framework.graph.plan import (
|
||||
ActionSpec,
|
||||
ActionType,
|
||||
EvaluationRule,
|
||||
ExecutionStatus,
|
||||
Judgment,
|
||||
JudgmentAction,
|
||||
Plan,
|
||||
PlanExecutionResult,
|
||||
PlanStep,
|
||||
StepStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestPlanDataStructures:
|
||||
"""Tests for Plan and PlanStep."""
|
||||
|
||||
def test_plan_step_creation(self):
|
||||
"""Test creating a PlanStep."""
|
||||
action = ActionSpec(
|
||||
action_type=ActionType.LLM_CALL,
|
||||
prompt="Hello, world!",
|
||||
)
|
||||
step = PlanStep(
|
||||
id="step_1",
|
||||
description="Say hello",
|
||||
action=action,
|
||||
expected_outputs=["greeting"],
|
||||
)
|
||||
|
||||
assert step.id == "step_1"
|
||||
assert step.status == StepStatus.PENDING
|
||||
assert step.action.action_type == ActionType.LLM_CALL
|
||||
|
||||
def test_plan_step_is_ready(self):
|
||||
"""Test PlanStep.is_ready() with dependencies."""
|
||||
step1 = PlanStep(
|
||||
id="step_1",
|
||||
description="First step",
|
||||
action=ActionSpec(action_type=ActionType.FUNCTION),
|
||||
dependencies=[],
|
||||
)
|
||||
step2 = PlanStep(
|
||||
id="step_2",
|
||||
description="Second step",
|
||||
action=ActionSpec(action_type=ActionType.FUNCTION),
|
||||
dependencies=["step_1"],
|
||||
)
|
||||
|
||||
# Step 1 is ready (no deps)
|
||||
assert step1.is_ready(set()) is True
|
||||
|
||||
# Step 2 is not ready (dep not met)
|
||||
assert step2.is_ready(set()) is False
|
||||
|
||||
# Step 2 is ready after step 1 completes
|
||||
assert step2.is_ready({"step_1"}) is True
|
||||
|
||||
def test_plan_get_ready_steps(self):
|
||||
"""Test Plan.get_ready_steps()."""
|
||||
plan = Plan(
|
||||
id="test_plan",
|
||||
goal_id="goal_1",
|
||||
description="Test plan",
|
||||
steps=[
|
||||
PlanStep(
|
||||
id="step_1",
|
||||
description="First",
|
||||
action=ActionSpec(action_type=ActionType.FUNCTION),
|
||||
dependencies=[],
|
||||
),
|
||||
PlanStep(
|
||||
id="step_2",
|
||||
description="Second",
|
||||
action=ActionSpec(action_type=ActionType.FUNCTION),
|
||||
dependencies=["step_1"],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
ready = plan.get_ready_steps()
|
||||
assert len(ready) == 1
|
||||
assert ready[0].id == "step_1"
|
||||
|
||||
def test_plan_is_complete(self):
|
||||
"""Test Plan.is_complete()."""
|
||||
plan = Plan(
|
||||
id="test_plan",
|
||||
goal_id="goal_1",
|
||||
description="Test plan",
|
||||
steps=[
|
||||
PlanStep(
|
||||
id="step_1",
|
||||
description="First",
|
||||
action=ActionSpec(action_type=ActionType.FUNCTION),
|
||||
status=StepStatus.COMPLETED,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
assert plan.is_complete() is True
|
||||
|
||||
def test_plan_to_feedback_context(self):
|
||||
"""Test Plan.to_feedback_context()."""
|
||||
plan = Plan(
|
||||
id="test_plan",
|
||||
goal_id="goal_1",
|
||||
description="Test plan",
|
||||
steps=[
|
||||
PlanStep(
|
||||
id="step_1",
|
||||
description="Completed step",
|
||||
action=ActionSpec(action_type=ActionType.FUNCTION),
|
||||
status=StepStatus.COMPLETED,
|
||||
result={"data": "value"},
|
||||
),
|
||||
PlanStep(
|
||||
id="step_2",
|
||||
description="Failed step",
|
||||
action=ActionSpec(action_type=ActionType.FUNCTION),
|
||||
status=StepStatus.FAILED,
|
||||
error="Something went wrong",
|
||||
attempts=3,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
context = plan.to_feedback_context()
|
||||
assert context["plan_id"] == "test_plan"
|
||||
assert len(context["completed_steps"]) == 1
|
||||
assert len(context["failed_steps"]) == 1
|
||||
assert context["failed_steps"][0]["error"] == "Something went wrong"
|
||||
|
||||
|
||||
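The dependency tests above pin down the readiness contract: a step is ready once every id in its dependencies appears in the completed set, get_ready_steps returns only pending steps whose dependencies are met, and is_complete requires every step to be completed. A stripped-down, stdlib-only sketch of that contract follows; the real PlanStep and Plan classes in framework.graph.plan carry more fields, so this models only the behavior the tests exercise.

# Minimal stand-in for the behavior asserted above; not the framework classes.
from dataclasses import dataclass, field

@dataclass
class StepSketch:
    id: str
    dependencies: list[str] = field(default_factory=list)
    status: str = "pending"

    def is_ready(self, completed: set[str]) -> bool:
        # Ready only while still pending and every dependency has completed.
        return self.status == "pending" and all(d in completed for d in self.dependencies)

@dataclass
class PlanSketch:
    steps: list[StepSketch]

    def get_ready_steps(self) -> list[StepSketch]:
        completed = {s.id for s in self.steps if s.status == "completed"}
        return [s for s in self.steps if s.is_ready(completed)]

    def is_complete(self) -> bool:
        return all(s.status == "completed" for s in self.steps)

# Mirrors test_plan_get_ready_steps: only the dependency-free step is ready.
plan = PlanSketch(steps=[StepSketch("step_1"), StepSketch("step_2", ["step_1"])])
assert [s.id for s in plan.get_ready_steps()] == ["step_1"]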
class TestCodeSandbox:
|
||||
"""Tests for code sandbox security."""
|
||||
|
||||
def test_simple_execution(self):
|
||||
"""Test simple code execution."""
|
||||
result = safe_exec("x = 1 + 2\nresult = x * 3")
|
||||
assert result.success is True
|
||||
assert result.variables.get("x") == 3
|
||||
assert result.result == 9
|
||||
|
||||
def test_input_injection(self):
|
||||
"""Test passing inputs to sandbox."""
|
||||
result = safe_exec(
|
||||
"result = x + y",
|
||||
inputs={"x": 10, "y": 20},
|
||||
)
|
||||
assert result.success is True
|
||||
assert result.result == 30
|
||||
|
||||
def test_blocked_import(self):
|
||||
"""Test that dangerous imports are blocked."""
|
||||
result = safe_exec("import os")
|
||||
assert result.success is False
|
||||
assert "blocked" in result.error.lower() or "import" in result.error.lower()
|
||||
|
||||
def test_blocked_private_access(self):
|
||||
"""Test that private attribute access is blocked."""
|
||||
result = safe_exec("x = [].__class__.__bases__")
|
||||
assert result.success is False
|
||||
|
||||
def test_blocked_exec_eval(self):
|
||||
"""Test that exec/eval are blocked."""
|
||||
result = safe_exec("exec('print(1)')")
|
||||
assert result.success is False
|
||||
|
||||
def test_safe_eval_expression(self):
|
||||
"""Test safe_eval for expressions."""
|
||||
result = safe_eval("x + y", inputs={"x": 5, "y": 3})
|
||||
assert result.success is True
|
||||
assert result.result == 8
|
||||
|
||||
def test_allowed_modules(self):
|
||||
"""Test that allowed modules work."""
|
||||
sandbox = CodeSandbox()
|
||||
# math is in ALLOWED_MODULES
|
||||
result = sandbox.execute(
|
||||
"""
|
||||
import math
|
||||
result = math.sqrt(16)
|
||||
""",
|
||||
inputs={},
|
||||
)
|
||||
# Note: imports are blocked by default in validation
|
||||
# This test documents current behavior
|
||||
assert result.success is False # imports blocked by validator
|
||||
|
||||
|
||||
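The sandbox tests above describe a blocklist posture: import statements, dunder attribute access, and exec/eval all fail validation before any code runs. Below is a small ast-based pre-check in that spirit. It is only an illustration of the checks the tests imply; the real CodeSandbox in framework.graph.code_sandbox may validate and execute quite differently.

# Illustrative validator only; the framework's CodeSandbox is the source of truth.
import ast

BLOCKED_CALLS = {"exec", "eval", "compile", "__import__"}

def validate_sketch(code: str) -> str | None:
    """Return an error message if the snippet breaks the sandbox rules, else None."""
    try:
        tree = ast.parse(code)
    except SyntaxError as exc:
        return f"syntax error: {exc}"
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            return "imports are blocked"
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            return "private/dunder attribute access is blocked"
        if isinstance(node, ast.Name) and node.id in BLOCKED_CALLS:
            return "exec/eval is blocked"
    return None

# Same shapes as the tests above: dangerous snippets rejected, plain math allowed.
assert validate_sketch("import os") is not None
assert validate_sketch("x = [].__class__.__bases__") is not None
assert validate_sketch("exec('print(1)')") is not None
assert validate_sketch("x = 1 + 2\nresult = x * 3") is None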
class TestHybridJudge:
|
||||
"""Tests for the HybridJudge."""
|
||||
|
||||
def test_rule_based_accept(self):
|
||||
"""Test rule-based accept judgment."""
|
||||
judge = HybridJudge()
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="success_check",
|
||||
description="Accept on success flag",
|
||||
condition="result.get('success') == True",
|
||||
action=JudgmentAction.ACCEPT,
|
||||
)
|
||||
)
|
||||
|
||||
step = PlanStep(
|
||||
id="test_step",
|
||||
description="Test",
|
||||
action=ActionSpec(action_type=ActionType.FUNCTION),
|
||||
)
|
||||
goal = Goal(
|
||||
id="goal_1",
|
||||
name="Test Goal",
|
||||
description="A test goal",
|
||||
success_criteria=[
|
||||
SuccessCriterion(
|
||||
id="sc1", description="Complete task", metric="completion", target="100%"
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Use sync version for testing
|
||||
judgment = asyncio.run(judge.evaluate(step, {"success": True}, goal))
|
||||
|
||||
assert judgment.action == JudgmentAction.ACCEPT
|
||||
assert judgment.rule_matched == "success_check"
|
||||
|
||||
def test_rule_based_retry(self):
|
||||
"""Test rule-based retry judgment."""
|
||||
judge = HybridJudge()
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="timeout_retry",
|
||||
description="Retry on timeout",
|
||||
condition="result.get('error_type') == 'timeout'",
|
||||
action=JudgmentAction.RETRY,
|
||||
feedback_template="Timeout occurred, please retry",
|
||||
)
|
||||
)
|
||||
|
||||
step = PlanStep(
|
||||
id="test_step",
|
||||
description="Test",
|
||||
action=ActionSpec(action_type=ActionType.FUNCTION),
|
||||
)
|
||||
goal = Goal(
|
||||
id="goal_1",
|
||||
name="Test Goal",
|
||||
description="A test goal",
|
||||
success_criteria=[
|
||||
SuccessCriterion(
|
||||
id="sc1", description="Complete task", metric="completion", target="100%"
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
judgment = asyncio.run(judge.evaluate(step, {"error_type": "timeout"}, goal))
|
||||
|
||||
assert judgment.action == JudgmentAction.RETRY
|
||||
|
||||
def test_rule_priority(self):
|
||||
"""Test that higher priority rules are checked first."""
|
||||
judge = HybridJudge()
|
||||
|
||||
# Lower priority - would match
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="low_priority",
|
||||
description="Low priority accept",
|
||||
condition="True",
|
||||
action=JudgmentAction.ACCEPT,
|
||||
priority=1,
|
||||
)
|
||||
)
|
||||
|
||||
# Higher priority - should match first
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="high_priority",
|
||||
description="High priority escalate",
|
||||
condition="True",
|
||||
action=JudgmentAction.ESCALATE,
|
||||
priority=100,
|
||||
)
|
||||
)
|
||||
|
||||
step = PlanStep(
|
||||
id="test_step",
|
||||
description="Test",
|
||||
action=ActionSpec(action_type=ActionType.FUNCTION),
|
||||
)
|
||||
goal = Goal(
|
||||
id="goal_1",
|
||||
name="Test Goal",
|
||||
description="A test goal",
|
||||
success_criteria=[
|
||||
SuccessCriterion(
|
||||
id="sc1", description="Complete task", metric="completion", target="100%"
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
judgment = asyncio.run(judge.evaluate(step, {}, goal))
|
||||
|
||||
assert judgment.rule_matched == "high_priority"
|
||||
assert judgment.action == JudgmentAction.ESCALATE
|
||||
|
||||
def test_default_judge_rules(self):
|
||||
"""Test that create_default_judge includes useful rules."""
|
||||
judge = create_default_judge()
|
||||
|
||||
# Should have rules for common cases
|
||||
rule_ids = {r.id for r in judge.rules}
|
||||
assert "explicit_success" in rule_ids
|
||||
assert "transient_error_retry" in rule_ids
|
||||
assert "security_escalate" in rule_ids
|
||||
|
||||
|
||||
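Taken together, the judge tests describe a simple dispatch: rules are checked in descending priority order, the first rule whose condition evaluates truthy against the step result decides the action, and the matching rule id is recorded on the judgment. The sketch below captures only that loop. The real HybridJudge evaluates conditions inside the code sandbox and can fall back to an LLM; both are omitted here, bare eval is used purely for illustration, and the no-match fallback is arbitrary.

# Behavioral sketch of rule matching only; not the framework's HybridJudge.
from dataclasses import dataclass

@dataclass
class RuleSketch:
    id: str
    condition: str   # e.g. "result.get('success') == True"
    action: str      # e.g. "accept", "retry", "escalate"
    priority: int = 0

def judge_sketch(rules: list[RuleSketch], result: dict) -> tuple[str, str | None]:
    """Return (action, rule_matched); the no-match fallback here is arbitrary."""
    for rule in sorted(rules, key=lambda r: r.priority, reverse=True):
        # The real judge evaluates this inside the code sandbox, not bare eval().
        if eval(rule.condition, {"__builtins__": {}}, {"result": result}):
            return rule.action, rule.id
    return "escalate", None

# Mirrors test_rule_priority: the priority-100 rule wins over the priority-1 rule.
rules = [
    RuleSketch("low_priority", "True", "accept", priority=1),
    RuleSketch("high_priority", "True", "escalate", priority=100),
]
assert judge_sketch(rules, {}) == ("escalate", "high_priority")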
class TestJudgment:
|
||||
"""Tests for Judgment data structure."""
|
||||
|
||||
def test_judgment_creation(self):
|
||||
"""Test creating a Judgment."""
|
||||
judgment = Judgment(
|
||||
action=JudgmentAction.ACCEPT,
|
||||
reasoning="Step completed successfully",
|
||||
confidence=0.95,
|
||||
)
|
||||
|
||||
assert judgment.action == JudgmentAction.ACCEPT
|
||||
assert judgment.confidence == 0.95
|
||||
assert judgment.llm_used is False
|
||||
|
||||
def test_judgment_with_feedback(self):
|
||||
"""Test Judgment with feedback for retry/replan."""
|
||||
judgment = Judgment(
|
||||
action=JudgmentAction.REPLAN,
|
||||
reasoning="Missing required data",
|
||||
feedback="Need to fetch user data first",
|
||||
context={"missing": ["user_id", "email"]},
|
||||
)
|
||||
|
||||
assert judgment.action == JudgmentAction.REPLAN
|
||||
assert judgment.feedback is not None
|
||||
assert "user_id" in judgment.context["missing"]
|
||||
|
||||
|
||||
class TestPlanExecutionResult:
|
||||
"""Tests for PlanExecutionResult."""
|
||||
|
||||
def test_completed_result(self):
|
||||
"""Test completed execution result."""
|
||||
result = PlanExecutionResult(
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
results={"output": "success"},
|
||||
steps_executed=5,
|
||||
total_tokens=1000,
|
||||
)
|
||||
|
||||
assert result.status == ExecutionStatus.COMPLETED
|
||||
assert result.steps_executed == 5
|
||||
|
||||
def test_needs_replan_result(self):
|
||||
"""Test needs_replan execution result."""
|
||||
result = PlanExecutionResult(
|
||||
status=ExecutionStatus.NEEDS_REPLAN,
|
||||
feedback="Step 3 failed: missing data",
|
||||
feedback_context={
|
||||
"completed_steps": ["step_1", "step_2"],
|
||||
"failed_step": "step_3",
|
||||
},
|
||||
completed_steps=["step_1", "step_2"],
|
||||
)
|
||||
|
||||
assert result.status == ExecutionStatus.NEEDS_REPLAN
|
||||
assert result.feedback is not None
|
||||
assert len(result.completed_steps) == 2
|
||||
|
||||
|
||||
# Integration tests would require mocking Runtime and LLM
|
||||
class TestFlexibleExecutorIntegration:
|
||||
"""Integration tests for FlexibleGraphExecutor."""
|
||||
|
||||
def test_executor_creation(self, tmp_path):
|
||||
"""Test creating a FlexibleGraphExecutor."""
|
||||
from framework.graph.flexible_executor import FlexibleGraphExecutor
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
runtime = Runtime(storage_path=tmp_path / "runtime")
|
||||
executor = FlexibleGraphExecutor(runtime=runtime)
|
||||
|
||||
assert executor.runtime == runtime
|
||||
assert executor.judge is not None
|
||||
assert executor.worker is not None
|
||||
|
||||
def test_executor_with_custom_judge(self, tmp_path):
|
||||
"""Test executor with custom judge."""
|
||||
from framework.graph.flexible_executor import FlexibleGraphExecutor
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
runtime = Runtime(storage_path=tmp_path / "runtime")
|
||||
custom_judge = HybridJudge()
|
||||
custom_judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="custom_rule",
|
||||
description="Custom rule",
|
||||
condition="True",
|
||||
action=JudgmentAction.ACCEPT,
|
||||
)
|
||||
)
|
||||
|
||||
executor = FlexibleGraphExecutor(runtime=runtime, judge=custom_judge)
|
||||
|
||||
assert len(executor.judge.rules) == 1
|
||||
assert executor.judge.rules[0].id == "custom_rule"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -5,7 +5,7 @@ Focused on minimal success and failure scenarios.
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeResult, NodeSpec
|
||||
@@ -49,7 +49,7 @@ async def test_executor_single_node_success():
|
||||
id="n1",
|
||||
name="node1",
|
||||
description="test node",
|
||||
node_type="llm_generate",
|
||||
node_type="event_loop",
|
||||
input_keys=[],
|
||||
output_keys=["result"],
|
||||
max_retries=0,
|
||||
@@ -104,7 +104,7 @@ async def test_executor_single_node_failure():
|
||||
id="n1",
|
||||
name="node1",
|
||||
description="failing node",
|
||||
node_type="llm_generate",
|
||||
node_type="event_loop",
|
||||
input_keys=[],
|
||||
output_keys=["result"],
|
||||
max_retries=0,
|
||||
@@ -143,77 +143,20 @@ class FakeEventBus:
|
||||
async def emit_node_loop_completed(self, **kwargs):
|
||||
self.events.append(("completed", kwargs))
|
||||
|
||||
async def emit_edge_traversed(self, **kwargs):
|
||||
self.events.append(("edge_traversed", kwargs))
|
||||
|
||||
async def emit_execution_paused(self, **kwargs):
|
||||
self.events.append(("execution_paused", kwargs))
|
||||
|
||||
async def emit_execution_resumed(self, **kwargs):
|
||||
self.events.append(("execution_resumed", kwargs))
|
||||
|
||||
async def emit_node_retry(self, **kwargs):
|
||||
self.events.append(("node_retry", kwargs))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_executor_emits_node_events():
|
||||
"""Executor should emit NODE_LOOP_STARTED/COMPLETED for each non-event_loop node."""
|
||||
runtime = DummyRuntime()
|
||||
event_bus = FakeEventBus()
|
||||
|
||||
graph = GraphSpec(
|
||||
id="graph-ev",
|
||||
goal_id="g-ev",
|
||||
nodes=[
|
||||
NodeSpec(
|
||||
id="n1",
|
||||
name="first",
|
||||
description="first node",
|
||||
node_type="llm_generate",
|
||||
input_keys=[],
|
||||
output_keys=["result"],
|
||||
max_retries=0,
|
||||
),
|
||||
NodeSpec(
|
||||
id="n2",
|
||||
name="second",
|
||||
description="second node",
|
||||
node_type="llm_generate",
|
||||
input_keys=["result"],
|
||||
output_keys=["result"],
|
||||
max_retries=0,
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(
|
||||
id="e1",
|
||||
source="n1",
|
||||
target="n2",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
),
|
||||
],
|
||||
entry_node="n1",
|
||||
terminal_nodes=["n2"],
|
||||
)
|
||||
|
||||
executor = GraphExecutor(
|
||||
runtime=runtime,
|
||||
node_registry={
|
||||
"n1": SuccessNode(),
|
||||
"n2": SuccessNode(),
|
||||
},
|
||||
event_bus=event_bus,
|
||||
stream_id="test-stream",
|
||||
)
|
||||
|
||||
goal = Goal(id="g-ev", name="event-test", description="test events")
|
||||
result = await executor.execute(graph=graph, goal=goal)
|
||||
|
||||
assert result.success is True
|
||||
assert result.path == ["n1", "n2"]
|
||||
|
||||
# Should have 4 events: started/completed for n1, then started/completed for n2
|
||||
assert len(event_bus.events) == 4
|
||||
assert event_bus.events[0] == ("started", {"stream_id": "test-stream", "node_id": "n1"})
|
||||
assert event_bus.events[1] == (
|
||||
"completed",
|
||||
{"stream_id": "test-stream", "node_id": "n1", "iterations": 1},
|
||||
)
|
||||
assert event_bus.events[2] == ("started", {"stream_id": "test-stream", "node_id": "n2"})
|
||||
assert event_bus.events[3] == (
|
||||
"completed",
|
||||
{"stream_id": "test-stream", "node_id": "n2", "iterations": 1},
|
||||
)
|
||||
|
||||
|
||||
# ---- Fake event_loop node (registered, so executor won't emit for it) ----
|
||||
class FakeEventLoopNode:
|
||||
@@ -276,7 +219,7 @@ async def test_executor_no_events_without_event_bus():
|
||||
id="n1",
|
||||
name="node1",
|
||||
description="test node",
|
||||
node_type="llm_generate",
|
||||
node_type="event_loop",
|
||||
input_keys=[],
|
||||
output_keys=["result"],
|
||||
max_retries=0,
|
||||
|
||||
@@ -9,12 +9,18 @@ For live tests (requires API keys):
|
||||
OPENAI_API_KEY=sk-... pytest tests/test_litellm_provider.py -v -m live
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
from framework.llm.provider import LLMProvider, Tool, ToolResult, ToolUse
|
||||
from framework.llm.litellm import LiteLLMProvider, _compute_retry_delay
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
|
||||
|
||||
class TestLiteLLMProviderInit:
|
||||
@@ -532,3 +538,291 @@ class TestJsonMode:
|
||||
messages = call_kwargs["messages"]
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "Please respond with a valid JSON object" in messages[0]["content"]
|
||||
|
||||
|
||||
class TestComputeRetryDelay:
|
||||
"""Test _compute_retry_delay() header parsing and fallback logic."""
|
||||
|
||||
def test_fallback_exponential_backoff(self):
|
||||
"""No exception -> exponential backoff."""
|
||||
assert _compute_retry_delay(0) == 2 # 2 * 2^0
|
||||
assert _compute_retry_delay(1) == 4 # 2 * 2^1
|
||||
assert _compute_retry_delay(2) == 8 # 2 * 2^2
|
||||
assert _compute_retry_delay(3) == 16 # 2 * 2^3
|
||||
|
||||
def test_max_delay_cap(self):
|
||||
"""Backoff should be capped at RATE_LIMIT_MAX_DELAY."""
|
||||
# 2 * 2^10 = 2048, should be capped at 120
|
||||
assert _compute_retry_delay(10) == 120
|
||||
|
||||
def test_custom_max_delay(self):
|
||||
"""Custom max_delay should be respected."""
|
||||
assert _compute_retry_delay(5, max_delay=10) == 10
|
||||
|
||||
def test_retry_after_ms_header(self):
|
||||
"""retry-after-ms header should be parsed as milliseconds."""
|
||||
exc = _make_exception_with_headers({"retry-after-ms": "5000"})
|
||||
assert _compute_retry_delay(0, exception=exc) == 5.0
|
||||
|
||||
def test_retry_after_ms_fractional(self):
|
||||
"""retry-after-ms should handle fractional values."""
|
||||
exc = _make_exception_with_headers({"retry-after-ms": "1500"})
|
||||
assert _compute_retry_delay(0, exception=exc) == 1.5
|
||||
|
||||
def test_retry_after_seconds_header(self):
|
||||
"""retry-after header as seconds should be parsed."""
|
||||
exc = _make_exception_with_headers({"retry-after": "3"})
|
||||
assert _compute_retry_delay(0, exception=exc) == 3.0
|
||||
|
||||
def test_retry_after_seconds_fractional(self):
|
||||
"""retry-after header should handle fractional seconds."""
|
||||
exc = _make_exception_with_headers({"retry-after": "2.5"})
|
||||
assert _compute_retry_delay(0, exception=exc) == 2.5
|
||||
|
||||
def test_retry_after_ms_takes_priority(self):
|
||||
"""retry-after-ms should take priority over retry-after."""
|
||||
exc = _make_exception_with_headers(
|
||||
{
|
||||
"retry-after-ms": "2000",
|
||||
"retry-after": "10",
|
||||
}
|
||||
)
|
||||
assert _compute_retry_delay(0, exception=exc) == 2.0
|
||||
|
||||
def test_retry_after_http_date(self):
|
||||
"""retry-after as HTTP-date should be parsed."""
|
||||
from email.utils import format_datetime
|
||||
|
||||
future = datetime.now(UTC) + timedelta(seconds=5)
|
||||
date_str = format_datetime(future, usegmt=True)
|
||||
exc = _make_exception_with_headers({"retry-after": date_str})
|
||||
delay = _compute_retry_delay(0, exception=exc)
|
||||
assert 3.0 <= delay <= 6.0 # within tolerance
|
||||
|
||||
def test_exception_without_response(self):
|
||||
"""Exception with response=None should fall back to exponential."""
|
||||
exc = Exception("test")
|
||||
exc.response = None # type: ignore[attr-defined]
|
||||
assert _compute_retry_delay(0, exception=exc) == 2 # exponential fallback
|
||||
|
||||
def test_exception_without_response_attr(self):
|
||||
"""Exception without .response attr should fall back to exponential."""
|
||||
exc = ValueError("no response attr")
|
||||
assert _compute_retry_delay(0, exception=exc) == 2
|
||||
|
||||
def test_negative_retry_after_clamped_to_zero(self):
|
||||
"""Negative retry-after should be clamped to 0."""
|
||||
exc = _make_exception_with_headers({"retry-after": "-5"})
|
||||
assert _compute_retry_delay(0, exception=exc) == 0
|
||||
|
||||
def test_negative_retry_after_ms_clamped_to_zero(self):
|
||||
"""Negative retry-after-ms should be clamped to 0."""
|
||||
exc = _make_exception_with_headers({"retry-after-ms": "-1000"})
|
||||
assert _compute_retry_delay(0, exception=exc) == 0
|
||||
|
||||
def test_invalid_retry_after_falls_back(self):
|
||||
"""Non-numeric, non-date retry-after should fall back to exponential."""
|
||||
exc = _make_exception_with_headers({"retry-after": "not-a-number-or-date"})
|
||||
assert _compute_retry_delay(0, exception=exc) == 2 # exponential fallback
|
||||
|
||||
def test_invalid_retry_after_ms_falls_back_to_retry_after(self):
|
||||
"""Invalid retry-after-ms should fall through to retry-after."""
|
||||
exc = _make_exception_with_headers(
|
||||
{
|
||||
"retry-after-ms": "garbage",
|
||||
"retry-after": "7",
|
||||
}
|
||||
)
|
||||
assert _compute_retry_delay(0, exception=exc) == 7.0
|
||||
|
||||
def test_retry_after_capped_at_max_delay(self):
|
||||
"""Server-provided delay should be capped at max_delay."""
|
||||
exc = _make_exception_with_headers({"retry-after": "3600"})
|
||||
assert _compute_retry_delay(0, exception=exc) == 120 # capped
|
||||
|
||||
def test_retry_after_ms_capped_at_max_delay(self):
|
||||
"""Server-provided ms delay should be capped at max_delay."""
|
||||
exc = _make_exception_with_headers({"retry-after-ms": "300000"}) # 300s
|
||||
assert _compute_retry_delay(0, exception=exc) == 120 # capped
|
||||
|
||||
|
||||
def _make_exception_with_headers(headers: dict[str, str]) -> BaseException:
|
||||
"""Create a mock exception with response headers for testing."""
|
||||
exc = Exception("rate limited")
|
||||
response = MagicMock()
|
||||
response.headers = headers
|
||||
exc.response = response # type: ignore[attr-defined]
|
||||
return exc
|
||||
|
||||
|
||||
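These tests fully specify the delay policy: prefer a server-provided retry-after-ms header (milliseconds), then retry-after (plain seconds or an HTTP-date), clamp negative values to zero, cap everything at max_delay (120 seconds by default), and otherwise fall back to 2 * 2^attempt exponential backoff. Here is a stdlib-only sketch that satisfies the assertions above; the shipped implementation lives in framework.llm.litellm and may differ in detail.

# Sketch of the delay policy the tests pin down; not the shipped implementation.
from datetime import UTC, datetime
from email.utils import parsedate_to_datetime

def compute_retry_delay_sketch(attempt: int, exception=None, max_delay: float = 120.0) -> float:
    headers = getattr(getattr(exception, "response", None), "headers", None) or {}

    # 1. retry-after-ms takes priority and is expressed in milliseconds.
    raw_ms = headers.get("retry-after-ms")
    if raw_ms is not None:
        try:
            return min(max(float(raw_ms) / 1000.0, 0.0), max_delay)
        except ValueError:
            pass  # unparseable; fall through to retry-after

    # 2. retry-after: either plain seconds or an HTTP-date.
    raw = headers.get("retry-after")
    if raw is not None:
        try:
            return min(max(float(raw), 0.0), max_delay)
        except ValueError:
            try:
                target = parsedate_to_datetime(raw)
                return min(max((target - datetime.now(UTC)).total_seconds(), 0.0), max_delay)
            except (TypeError, ValueError):
                pass  # unparseable header; use the exponential fallback

    # 3. Exponential backoff fallback: 2 * 2^attempt, capped at max_delay.
    return min(2.0 * (2 ** attempt), max_delay)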
# ---------------------------------------------------------------------------
|
||||
# Async LLM methods — non-blocking event loop tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAsyncComplete:
|
||||
"""Test that acomplete/acomplete_with_tools don't block the event loop."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_acomplete_uses_acompletion(self, mock_acompletion):
|
||||
"""acomplete() should call litellm.acompletion (async), not litellm.completion."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "async hello"
|
||||
mock_response.choices[0].message.tool_calls = None
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "gpt-4o-mini"
|
||||
mock_response.usage.prompt_tokens = 10
|
||||
mock_response.usage.completion_tokens = 5
|
||||
|
||||
# acompletion is async, so mock must return a coroutine
|
||||
async def async_return(*args, **kwargs):
|
||||
return mock_response
|
||||
|
||||
mock_acompletion.side_effect = async_return
|
||||
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
result = await provider.acomplete(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
system="You are helpful.",
|
||||
)
|
||||
|
||||
assert result.content == "async hello"
|
||||
assert result.model == "gpt-4o-mini"
|
||||
assert result.input_tokens == 10
|
||||
assert result.output_tokens == 5
|
||||
mock_acompletion.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_acomplete_does_not_block_event_loop(self, mock_acompletion):
|
||||
"""Verify event loop stays responsive during acomplete()."""
|
||||
heartbeat_ticks = []
|
||||
|
||||
async def heartbeat():
|
||||
start = time.monotonic()
|
||||
for _ in range(10):
|
||||
heartbeat_ticks.append(time.monotonic() - start)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
async def slow_acompletion(*args, **kwargs):
|
||||
# Simulate a 300ms LLM call — async, so event loop should stay free
|
||||
await asyncio.sleep(0.3)
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = "done"
|
||||
resp.choices[0].message.tool_calls = None
|
||||
resp.choices[0].finish_reason = "stop"
|
||||
resp.model = "gpt-4o-mini"
|
||||
resp.usage.prompt_tokens = 5
|
||||
resp.usage.completion_tokens = 3
|
||||
return resp
|
||||
|
||||
mock_acompletion.side_effect = slow_acompletion
|
||||
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
|
||||
# Run heartbeat + acomplete concurrently
|
||||
_, result = await asyncio.gather(
|
||||
heartbeat(),
|
||||
provider.acomplete(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
),
|
||||
)
|
||||
|
||||
assert result.content == "done"
|
||||
# Heartbeat should have ticked multiple times during the 300ms LLM call
|
||||
# (if the event loop were blocked, we'd see 0-1 ticks)
|
||||
assert len(heartbeat_ticks) >= 3, (
|
||||
f"Event loop was blocked — only {len(heartbeat_ticks)} heartbeat ticks"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_acomplete_with_tools_uses_acompletion(self, mock_acompletion):
|
||||
"""acomplete_with_tools() should use litellm.acompletion."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "tool result"
|
||||
mock_response.choices[0].message.tool_calls = None
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "gpt-4o-mini"
|
||||
mock_response.usage.prompt_tokens = 10
|
||||
mock_response.usage.completion_tokens = 5
|
||||
|
||||
async def async_return(*args, **kwargs):
|
||||
return mock_response
|
||||
|
||||
mock_acompletion.side_effect = async_return
|
||||
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
tools = [
|
||||
Tool(
|
||||
name="search",
|
||||
description="Search the web",
|
||||
parameters={"properties": {"q": {"type": "string"}}, "required": ["q"]},
|
||||
)
|
||||
]
|
||||
|
||||
result = await provider.acomplete_with_tools(
|
||||
messages=[{"role": "user", "content": "Search for cats"}],
|
||||
system="You are helpful.",
|
||||
tools=tools,
|
||||
tool_executor=lambda tu: ToolResult(tool_use_id=tu.id, content="cats"),
|
||||
)
|
||||
|
||||
assert result.content == "tool result"
|
||||
mock_acompletion.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_provider_acomplete(self):
|
||||
"""MockLLMProvider.acomplete() should work without blocking."""
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
|
||||
provider = MockLLMProvider()
|
||||
result = await provider.acomplete(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
system="Be helpful.",
|
||||
)
|
||||
|
||||
assert result.content # Should have some mock content
|
||||
assert result.model == "mock-model"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_provider_acomplete_offloads_to_executor(self):
|
||||
"""Base LLMProvider.acomplete() should offload sync complete() to thread pool."""
|
||||
call_thread_ids = []
|
||||
|
||||
class SlowSyncProvider(LLMProvider):
|
||||
def complete(
|
||||
self,
|
||||
messages,
|
||||
system="",
|
||||
tools=None,
|
||||
max_tokens=1024,
|
||||
response_format=None,
|
||||
json_mode=False,
|
||||
max_retries=None,
|
||||
):
|
||||
call_thread_ids.append(threading.current_thread().ident)
|
||||
time.sleep(0.1) # Sync blocking
|
||||
return LLMResponse(content="sync done", model="slow")
|
||||
|
||||
def complete_with_tools(
|
||||
self, messages, system, tools, tool_executor, max_iterations=10
|
||||
):
|
||||
return LLMResponse(content="sync tools done", model="slow")
|
||||
|
||||
provider = SlowSyncProvider()
|
||||
main_thread_id = threading.current_thread().ident
|
||||
|
||||
result = await provider.acomplete(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
assert result.content == "sync done"
|
||||
# The sync complete() should have run on a different thread
|
||||
assert call_thread_ids[0] != main_thread_id, (
|
||||
"Base acomplete() should offload sync complete() to a thread pool"
|
||||
)
|
||||
|
||||
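The last test above documents the expected base-class behavior: when a provider only implements the synchronous complete(), acomplete() should push that call onto a worker thread so the running event loop stays responsive. A one-method sketch of that default follows; asyncio.to_thread is used here for brevity, and the actual base class in framework.llm.provider may use run_in_executor or differ otherwise.

# Sketch of a default async wrapper around a sync complete(); see assumptions above.
import asyncio

class AsyncOffloadSketch:
    def complete(self, messages, system="", **kwargs):
        raise NotImplementedError  # supplied by the concrete provider

    async def acomplete(self, messages, system="", **kwargs):
        # Run the blocking sync call in the default thread pool so the event
        # loop (and any heartbeat tasks) keeps making progress meanwhile.
        return await asyncio.to_thread(self.complete, messages, system=system, **kwargs)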
@@ -143,6 +143,7 @@ def _has_api_key(env_var: str) -> bool:
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real API tests — text streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.skip(reason="Requires valid live API keys — run manually")
|
||||
class TestRealAPITextStreaming:
|
||||
"""Stream a simple text response from each provider and dump events."""
|
||||
|
||||
@@ -204,6 +205,7 @@ class TestRealAPITextStreaming:
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real API tests — tool call streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.skip(reason="Requires valid live API keys — run manually")
|
||||
class TestRealAPIToolCallStreaming:
|
||||
"""Stream a tool call response from each provider and dump events."""
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ class MockLLMProvider(LLMProvider):
|
||||
max_tokens=1024,
|
||||
response_format=None,
|
||||
json_mode=False,
|
||||
max_retries=None,
|
||||
):
|
||||
self.complete_calls.append(
|
||||
{
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
"""Tests for LLMNode JSON extraction logic.
|
||||
|
||||
Run with:
|
||||
cd core
|
||||
pytest tests/test_node_json_extraction.py -v
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.node import LLMNode
|
||||
|
||||
|
||||
class TestJsonExtraction:
|
||||
"""Test _extract_json JSON extraction without LLM calls."""
|
||||
|
||||
@pytest.fixture
|
||||
def node(self):
|
||||
"""Create an LLMNode instance for testing."""
|
||||
return LLMNode()
|
||||
|
||||
def test_clean_json(self, node):
|
||||
"""Test parsing clean JSON directly."""
|
||||
result = node._extract_json('{"key": "value"}', ["key"])
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_json_with_whitespace(self, node):
|
||||
"""Test parsing JSON with surrounding whitespace."""
|
||||
result = node._extract_json(' {"key": "value"} \n', ["key"])
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_markdown_code_block_at_start(self, node):
|
||||
"""Test extracting JSON from markdown code block at start."""
|
||||
input_text = '```json\n{"key": "value"}\n```'
|
||||
result = node._extract_json(input_text, ["key"])
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_markdown_code_block_without_json_label(self, node):
|
||||
"""Test extracting JSON from markdown code block without 'json' label."""
|
||||
input_text = '```\n{"key": "value"}\n```'
|
||||
result = node._extract_json(input_text, ["key"])
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_prose_around_markdown_block(self, node):
|
||||
"""Test extracting JSON when prose surrounds the markdown block."""
|
||||
input_text = 'Here is the result:\n```json\n{"key": "value"}\n```\nHope this helps!'
|
||||
result = node._extract_json(input_text, ["key"])
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_json_embedded_in_prose(self, node):
|
||||
"""Test extracting JSON embedded in prose text."""
|
||||
input_text = 'The answer is {"key": "value"} as requested.'
|
||||
result = node._extract_json(input_text, ["key"])
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_nested_json(self, node):
|
||||
"""Test parsing nested JSON objects."""
|
||||
input_text = '{"outer": {"inner": "value"}}'
|
||||
result = node._extract_json(input_text, ["outer"])
|
||||
assert result == {"outer": {"inner": "value"}}
|
||||
|
||||
def test_deeply_nested_json(self, node):
|
||||
"""Test parsing deeply nested JSON objects."""
|
||||
input_text = '{"a": {"b": {"c": {"d": "deep"}}}}'
|
||||
result = node._extract_json(input_text, ["a"])
|
||||
assert result == {"a": {"b": {"c": {"d": "deep"}}}}
|
||||
|
||||
def test_json_with_array(self, node):
|
||||
"""Test parsing JSON with array values."""
|
||||
input_text = '{"items": [1, 2, 3]}'
|
||||
result = node._extract_json(input_text, ["items"])
|
||||
assert result == {"items": [1, 2, 3]}
|
||||
|
||||
def test_json_with_string_containing_braces(self, node):
|
||||
"""Test parsing JSON where string values contain braces."""
|
||||
input_text = '{"code": "function() { return 1; }"}'
|
||||
result = node._extract_json(input_text, ["code"])
|
||||
assert result == {"code": "function() { return 1; }"}
|
||||
|
||||
def test_json_with_escaped_quotes(self, node):
|
||||
"""Test parsing JSON with escaped quotes in strings."""
|
||||
input_text = '{"message": "He said \\"hello\\""}'
|
||||
result = node._extract_json(input_text, ["message"])
|
||||
assert result == {"message": 'He said "hello"'}
|
||||
|
||||
def test_multiple_json_objects_takes_first(self, node):
|
||||
"""Test that when multiple JSON objects exist, first is taken."""
|
||||
input_text = '{"first": 1} and then {"second": 2}'
|
||||
result = node._extract_json(input_text, ["first"])
|
||||
assert result == {"first": 1}
|
||||
|
||||
def test_json_with_boolean_and_null(self, node):
|
||||
"""Test parsing JSON with boolean and null values."""
|
||||
input_text = '{"active": true, "deleted": false, "data": null}'
|
||||
result = node._extract_json(input_text, ["active", "deleted", "data"])
|
||||
assert result == {"active": True, "deleted": False, "data": None}
|
||||
|
||||
def test_json_with_numbers(self, node):
|
||||
"""Test parsing JSON with integer and float values."""
|
||||
input_text = '{"count": 42, "price": 19.99}'
|
||||
result = node._extract_json(input_text, ["count", "price"])
|
||||
assert result == {"count": 42, "price": 19.99}
|
||||
|
||||
def test_invalid_json_raises_error(self, node, monkeypatch):
|
||||
"""Test that completely invalid JSON raises an error when no LLM fallback available."""
|
||||
# Remove API keys so LLM fallback is not attempted
|
||||
monkeypatch.delenv("CEREBRAS_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="Cannot parse JSON"):
|
||||
node._extract_json("This is not JSON at all", ["key"])
|
||||
|
||||
def test_empty_string_raises_error(self, node, monkeypatch):
|
||||
"""Test that empty string raises an error when no LLM fallback available."""
|
||||
# Remove API keys so LLM fallback is not attempted
|
||||
monkeypatch.delenv("CEREBRAS_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="Cannot parse JSON"):
|
||||
node._extract_json("", ["key"])
|
||||
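Read together, the deleted extraction tests describe a layered strategy: try the raw text as JSON, then strip a markdown code fence, then scan for the first balanced top-level object embedded in prose, and only raise ValueError("Cannot parse JSON ...") when no LLM fallback is available. Below is a stdlib-only sketch of those layers; LLMNode._extract_json in framework.graph.node is the real implementation, also takes the expected keys, and adds the LLM fallback omitted here.

# Parsing sketch only; mirrors the behaviors the tests assert, minus the LLM fallback.
import json
import re

def extract_json_sketch(text: str) -> dict:
    candidate = text.strip()

    # Layer 1: the whole response is already valid JSON.
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        pass

    # Layer 2: JSON wrapped in a ```json ... ``` (or bare ```) fence.
    fence = re.search(r"```(?:json)?\s*(.*?)```", candidate, re.DOTALL)
    if fence:
        try:
            return json.loads(fence.group(1).strip())
        except json.JSONDecodeError:
            pass

    # Layer 3: first balanced {...} object embedded in prose, respecting strings.
    start = candidate.find("{")
    while start != -1:
        depth, in_string, escaped = 0, False, False
        for i, ch in enumerate(candidate[start:], start):
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = not in_string
            elif not in_string and ch == "{":
                depth += 1
            elif not in_string and ch == "}":
                depth -= 1
                if depth == 0:
                    try:
                        return json.loads(candidate[start : i + 1])
                    except json.JSONDecodeError:
                        break  # malformed candidate; try the next "{"
        start = candidate.find("{", start + 1)

    raise ValueError(f"Cannot parse JSON from response: {text[:80]!r}")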