feat: concurrent framework entrypoints
This commit is contained in:
@@ -288,13 +288,56 @@ Respond with ONLY a JSON object:
|
||||
return result
|
||||
|
||||
|
||||
class AsyncEntryPointSpec(BaseModel):
|
||||
"""
|
||||
Specification for an asynchronous entry point.
|
||||
|
||||
Used with AgentRuntime for multi-entry-point agents that handle
|
||||
concurrent execution streams (e.g., webhook + API handlers).
|
||||
|
||||
Example:
|
||||
AsyncEntryPointSpec(
|
||||
id="webhook",
|
||||
name="Zendesk Webhook Handler",
|
||||
entry_node="process-webhook",
|
||||
trigger_type="webhook",
|
||||
isolation_level="shared",
|
||||
)
|
||||
"""
|
||||
id: str = Field(description="Unique identifier for this entry point")
|
||||
name: str = Field(description="Human-readable name")
|
||||
entry_node: str = Field(description="Node ID to start execution from")
|
||||
trigger_type: str = Field(
|
||||
default="manual",
|
||||
description="How this entry point is triggered: webhook, api, timer, event, manual"
|
||||
)
|
||||
trigger_config: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Trigger-specific configuration (e.g., webhook URL, timer interval)"
|
||||
)
|
||||
isolation_level: str = Field(
|
||||
default="shared",
|
||||
description="State isolation: isolated, shared, or synchronized"
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description="Execution priority (higher = more priority)"
|
||||
)
|
||||
max_concurrent: int = Field(
|
||||
default=10,
|
||||
description="Maximum concurrent executions for this entry point"
|
||||
)
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class GraphSpec(BaseModel):
|
||||
"""
|
||||
Complete specification of an agent graph.
|
||||
|
||||
Contains all nodes, edges, and metadata needed to execute.
|
||||
|
||||
Example:
|
||||
For single-entry-point agents (traditional pattern):
|
||||
GraphSpec(
|
||||
id="calculator-graph",
|
||||
goal_id="calc-001",
|
||||
@@ -303,6 +346,29 @@ class GraphSpec(BaseModel):
|
||||
nodes=[...],
|
||||
edges=[...],
|
||||
)
|
||||
|
||||
For multi-entry-point agents (concurrent streams):
|
||||
GraphSpec(
|
||||
id="support-agent-graph",
|
||||
goal_id="support-001",
|
||||
entry_node="process-webhook", # Default entry
|
||||
async_entry_points=[
|
||||
AsyncEntryPointSpec(
|
||||
id="webhook",
|
||||
name="Zendesk Webhook",
|
||||
entry_node="process-webhook",
|
||||
trigger_type="webhook",
|
||||
),
|
||||
AsyncEntryPointSpec(
|
||||
id="api",
|
||||
name="API Handler",
|
||||
entry_node="process-request",
|
||||
trigger_type="api",
|
||||
),
|
||||
],
|
||||
nodes=[...],
|
||||
edges=[...],
|
||||
)
|
||||
"""
|
||||
id: str
|
||||
goal_id: str
|
||||
@@ -314,6 +380,10 @@ class GraphSpec(BaseModel):
|
||||
default_factory=dict,
|
||||
description="Named entry points for resuming execution. Format: {name: node_id}"
|
||||
)
|
||||
async_entry_points: list[AsyncEntryPointSpec] = Field(
|
||||
default_factory=list,
|
||||
description="Asynchronous entry points for concurrent execution streams (used with AgentRuntime)"
|
||||
)
|
||||
terminal_nodes: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="IDs of nodes that end execution"
|
||||
@@ -363,6 +433,17 @@ class GraphSpec(BaseModel):
|
||||
return node
|
||||
return None
|
||||
|
||||
def has_async_entry_points(self) -> bool:
|
||||
"""Check if this graph uses async entry points (multi-stream execution)."""
|
||||
return len(self.async_entry_points) > 0
|
||||
|
||||
def get_async_entry_point(self, entry_point_id: str) -> AsyncEntryPointSpec | None:
|
||||
"""Get an async entry point by ID."""
|
||||
for ep in self.async_entry_points:
|
||||
if ep.id == entry_point_id:
|
||||
return ep
|
||||
return None
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> list[EdgeSpec]:
|
||||
"""Get all edges leaving a node, sorted by priority."""
|
||||
edges = [e for e in self.edges if e.source == node_id]
|
||||
@@ -412,6 +493,36 @@ class GraphSpec(BaseModel):
|
||||
if not self.get_node(self.entry_node):
|
||||
errors.append(f"Entry node '{self.entry_node}' not found")
|
||||
|
||||
# Check async entry points
|
||||
seen_entry_ids = set()
|
||||
for entry_point in self.async_entry_points:
|
||||
# Check for duplicate IDs
|
||||
if entry_point.id in seen_entry_ids:
|
||||
errors.append(f"Duplicate async entry point ID: '{entry_point.id}'")
|
||||
seen_entry_ids.add(entry_point.id)
|
||||
|
||||
# Check entry node exists
|
||||
if not self.get_node(entry_point.entry_node):
|
||||
errors.append(
|
||||
f"Async entry point '{entry_point.id}' references missing node '{entry_point.entry_node}'"
|
||||
)
|
||||
|
||||
# Validate isolation level
|
||||
valid_isolation = {"isolated", "shared", "synchronized"}
|
||||
if entry_point.isolation_level not in valid_isolation:
|
||||
errors.append(
|
||||
f"Async entry point '{entry_point.id}' has invalid isolation_level "
|
||||
f"'{entry_point.isolation_level}'. Valid: {valid_isolation}"
|
||||
)
|
||||
|
||||
# Validate trigger type
|
||||
valid_triggers = {"webhook", "api", "timer", "event", "manual"}
|
||||
if entry_point.trigger_type not in valid_triggers:
|
||||
errors.append(
|
||||
f"Async entry point '{entry_point.id}' has invalid trigger_type "
|
||||
f"'{entry_point.trigger_type}'. Valid: {valid_triggers}"
|
||||
)
|
||||
|
||||
# Check terminal nodes exist
|
||||
for term in self.terminal_nodes:
|
||||
if not self.get_node(term):
|
||||
@@ -433,6 +544,10 @@ class GraphSpec(BaseModel):
|
||||
for entry_point_node in self.entry_points.values():
|
||||
to_visit.append(entry_point_node)
|
||||
|
||||
# Add all async entry points as valid starting points
|
||||
for async_entry in self.async_entry_points:
|
||||
to_visit.append(async_entry.entry_node)
|
||||
|
||||
# Traverse from all entry points
|
||||
while to_visit:
|
||||
current = to_visit.pop()
|
||||
@@ -442,11 +557,16 @@ class GraphSpec(BaseModel):
|
||||
for edge in self.get_outgoing_edges(current):
|
||||
to_visit.append(edge.target)
|
||||
|
||||
# Build set of async entry point nodes for quick lookup
|
||||
async_entry_nodes = {ep.entry_node for ep in self.async_entry_points}
|
||||
|
||||
for node in self.nodes:
|
||||
if node.id not in reachable:
|
||||
# Skip this error if the node is a pause node or an entry point target
|
||||
# (pause/resume architecture makes these reachable via session state)
|
||||
if node.id in self.pause_nodes or node.id in self.entry_points.values():
|
||||
# Skip this error if the node is a pause node, entry point target, or async entry point
|
||||
# (pause/resume architecture and async entry points make these reachable)
|
||||
if (node.id in self.pause_nodes or
|
||||
node.id in self.entry_points.values() or
|
||||
node.id in async_entry_nodes):
|
||||
continue
|
||||
errors.append(f"Node '{node.id}' is unreachable from entry")
|
||||
|
||||
|
||||
@@ -506,11 +506,19 @@ class LLMNode(NodeProtocol):
|
||||
# Try direct JSON parse first (fast path)
|
||||
try:
|
||||
content = raw_response.strip()
|
||||
# Remove markdown code blocks if present
|
||||
|
||||
# Remove markdown code blocks if present - more robust extraction
|
||||
if content.startswith("```"):
|
||||
match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', content, re.DOTALL)
|
||||
# Try multiple patterns for markdown code blocks
|
||||
# Pattern 1: ```json\n...\n``` or ```\n...\n```
|
||||
match = re.search(r'^```(?:json)?\s*\n([\s\S]*?)\n```\s*$', content)
|
||||
if match:
|
||||
content = match.group(1).strip()
|
||||
else:
|
||||
# Pattern 2: Just strip the first and last lines if they're ```
|
||||
lines = content.split('\n')
|
||||
if lines[0].startswith('```') and lines[-1].strip() == '```':
|
||||
content = '\n'.join(lines[1:-1]).strip()
|
||||
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict):
|
||||
@@ -560,9 +568,14 @@ IMPORTANT:
|
||||
cleaned = result.content.strip()
|
||||
# Remove markdown if Haiku added it
|
||||
if cleaned.startswith("```"):
|
||||
match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', cleaned, re.DOTALL)
|
||||
match = re.search(r'^```(?:json)?\s*\n([\s\S]*?)\n```\s*$', cleaned)
|
||||
if match:
|
||||
cleaned = match.group(1).strip()
|
||||
else:
|
||||
# Fallback: strip first/last lines
|
||||
lines = cleaned.split('\n')
|
||||
if lines[0].startswith('```') and lines[-1].strip() == '```':
|
||||
cleaned = '\n'.join(lines[1:-1]).strip()
|
||||
|
||||
parsed = json.loads(cleaned)
|
||||
logger.info(" ✓ Haiku cleaned JSON output")
|
||||
|
||||
+325
-17
@@ -4,16 +4,20 @@ import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import TYPE_CHECKING, Callable, Any
|
||||
|
||||
from framework.graph import Goal
|
||||
from framework.graph.edge import GraphSpec, EdgeSpec, EdgeCondition
|
||||
from framework.graph.edge import GraphSpec, EdgeSpec, EdgeCondition, AsyncEntryPointSpec
|
||||
from framework.graph.node import NodeSpec
|
||||
from framework.graph.executor import GraphExecutor, ExecutionResult
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
# Multi-entry-point runtime imports
|
||||
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runner.protocol import CapabilityResponse, AgentMessage
|
||||
|
||||
@@ -36,6 +40,9 @@ class AgentInfo:
|
||||
constraints: list[dict]
|
||||
required_tools: list[str]
|
||||
has_tools_module: bool
|
||||
# Multi-entry-point support
|
||||
async_entry_points: list[dict] = field(default_factory=list)
|
||||
is_multi_entry_point: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -92,6 +99,20 @@ def load_agent_export(data: str | dict) -> tuple[GraphSpec, Goal]:
|
||||
)
|
||||
edges.append(edge)
|
||||
|
||||
# Build AsyncEntryPointSpec objects for multi-entry-point support
|
||||
async_entry_points = []
|
||||
for aep_data in graph_data.get("async_entry_points", []):
|
||||
async_entry_points.append(AsyncEntryPointSpec(
|
||||
id=aep_data["id"],
|
||||
name=aep_data.get("name", aep_data["id"]),
|
||||
entry_node=aep_data["entry_node"],
|
||||
trigger_type=aep_data.get("trigger_type", "manual"),
|
||||
trigger_config=aep_data.get("trigger_config", {}),
|
||||
isolation_level=aep_data.get("isolation_level", "shared"),
|
||||
priority=aep_data.get("priority", 0),
|
||||
max_concurrent=aep_data.get("max_concurrent", 10),
|
||||
))
|
||||
|
||||
# Build GraphSpec
|
||||
graph = GraphSpec(
|
||||
id=graph_data.get("id", "agent-graph"),
|
||||
@@ -99,6 +120,7 @@ def load_agent_export(data: str | dict) -> tuple[GraphSpec, Goal]:
|
||||
version=graph_data.get("version", "1.0.0"),
|
||||
entry_node=graph_data.get("entry_node", ""),
|
||||
entry_points=graph_data.get("entry_points", {}), # Support pause/resume architecture
|
||||
async_entry_points=async_entry_points, # Support multi-entry-point agents
|
||||
terminal_nodes=graph_data.get("terminal_nodes", []),
|
||||
pause_nodes=graph_data.get("pause_nodes", []), # Support pause/resume architecture
|
||||
nodes=nodes,
|
||||
@@ -174,7 +196,7 @@ class AgentRunner:
|
||||
goal: Goal,
|
||||
mock_mode: bool = False,
|
||||
storage_path: Path | None = None,
|
||||
model: str = "claude-haiku-4-5-20251001",
|
||||
model: str = "cerebras/zai-glm-4.7",
|
||||
):
|
||||
"""
|
||||
Initialize the runner (use AgentRunner.load() instead).
|
||||
@@ -213,6 +235,10 @@ class AgentRunner:
|
||||
self._executor: GraphExecutor | None = None
|
||||
self._approval_callback: Callable | None = None
|
||||
|
||||
# Multi-entry-point support (AgentRuntime)
|
||||
self._agent_runtime: AgentRuntime | None = None
|
||||
self._uses_async_entry_points = self.graph.has_async_entry_points()
|
||||
|
||||
# Auto-discover tools from tools.py
|
||||
tools_path = agent_path / "tools.py"
|
||||
if tools_path.exists():
|
||||
@@ -229,7 +255,7 @@ class AgentRunner:
|
||||
agent_path: str | Path,
|
||||
mock_mode: bool = False,
|
||||
storage_path: Path | None = None,
|
||||
model: str = "claude-haiku-4-5-20251001",
|
||||
model: str = "cerebras/zai-glm-4.7",
|
||||
) -> "AgentRunner":
|
||||
"""
|
||||
Load an agent from an export folder.
|
||||
@@ -238,7 +264,7 @@ class AgentRunner:
|
||||
agent_path: Path to agent folder (containing agent.json)
|
||||
mock_mode: If True, use mock LLM responses
|
||||
storage_path: Path for runtime storage (defaults to temp)
|
||||
model: Anthropic model to use
|
||||
model: LLM model to use (any LiteLLM-compatible model name)
|
||||
|
||||
Returns:
|
||||
AgentRunner instance ready to run
|
||||
@@ -371,9 +397,6 @@ class AgentRunner:
|
||||
|
||||
def _setup(self) -> None:
|
||||
"""Set up runtime, LLM, and executor."""
|
||||
# Create runtime
|
||||
self._runtime = Runtime(storage_path=self._storage_path)
|
||||
|
||||
# Set up session context for tools (workspace_id, agent_id, session_id)
|
||||
workspace_id = "default" # Could be derived from storage path
|
||||
agent_id = self.graph.id or "unknown"
|
||||
@@ -387,41 +410,299 @@ class AgentRunner:
|
||||
)
|
||||
|
||||
# Create LLM provider (if not mock mode and API key available)
|
||||
if not self.mock_mode and os.environ.get("ANTHROPIC_API_KEY"):
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
# Uses LiteLLM which auto-detects the provider from model name
|
||||
if not self.mock_mode:
|
||||
# Detect required API key from model name
|
||||
api_key_env = self._get_api_key_env_var(self.model)
|
||||
if api_key_env and os.environ.get(api_key_env):
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
self._llm = LiteLLMProvider(model=self.model)
|
||||
elif api_key_env:
|
||||
print(f"Warning: {api_key_env} not set. LLM calls will fail.")
|
||||
print(f"Set it with: export {api_key_env}=your-api-key")
|
||||
|
||||
self._llm = AnthropicProvider(model=self.model)
|
||||
# Get tools for executor/runtime
|
||||
tools = list(self._tool_registry.get_tools().values())
|
||||
tool_executor = self._tool_registry.get_executor()
|
||||
|
||||
if self._uses_async_entry_points:
|
||||
# Multi-entry-point mode: use AgentRuntime
|
||||
self._setup_agent_runtime(tools, tool_executor)
|
||||
else:
|
||||
# Single-entry-point mode: use legacy GraphExecutor
|
||||
self._setup_legacy_executor(tools, tool_executor)
|
||||
|
||||
def _get_api_key_env_var(self, model: str) -> str | None:
|
||||
"""Get the environment variable name for the API key based on model name."""
|
||||
model_lower = model.lower()
|
||||
|
||||
# Map model prefixes to API key environment variables
|
||||
# LiteLLM uses these conventions
|
||||
if model_lower.startswith("cerebras/"):
|
||||
return "CEREBRAS_API_KEY"
|
||||
elif model_lower.startswith("openai/") or model_lower.startswith("gpt-"):
|
||||
return "OPENAI_API_KEY"
|
||||
elif model_lower.startswith("anthropic/") or model_lower.startswith("claude"):
|
||||
return "ANTHROPIC_API_KEY"
|
||||
elif model_lower.startswith("gemini/") or model_lower.startswith("google/"):
|
||||
return "GOOGLE_API_KEY"
|
||||
elif model_lower.startswith("mistral/"):
|
||||
return "MISTRAL_API_KEY"
|
||||
elif model_lower.startswith("groq/"):
|
||||
return "GROQ_API_KEY"
|
||||
elif model_lower.startswith("ollama/"):
|
||||
return None # Ollama doesn't need an API key (local)
|
||||
elif model_lower.startswith("azure/"):
|
||||
return "AZURE_API_KEY"
|
||||
elif model_lower.startswith("cohere/"):
|
||||
return "COHERE_API_KEY"
|
||||
elif model_lower.startswith("replicate/"):
|
||||
return "REPLICATE_API_KEY"
|
||||
elif model_lower.startswith("together/"):
|
||||
return "TOGETHER_API_KEY"
|
||||
else:
|
||||
# Default: assume OpenAI-compatible
|
||||
return "OPENAI_API_KEY"
|
||||
|
||||
def _setup_legacy_executor(self, tools: list, tool_executor: Callable | None) -> None:
|
||||
"""Set up legacy single-entry-point execution using GraphExecutor."""
|
||||
# Create runtime
|
||||
self._runtime = Runtime(storage_path=self._storage_path)
|
||||
|
||||
# Create executor
|
||||
self._executor = GraphExecutor(
|
||||
runtime=self._runtime,
|
||||
llm=self._llm,
|
||||
tools=list(self._tool_registry.get_tools().values()),
|
||||
tool_executor=self._tool_registry.get_executor(),
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
approval_callback=self._approval_callback,
|
||||
)
|
||||
|
||||
async def run(self, input_data: dict | None = None, session_state: dict | None = None) -> ExecutionResult:
|
||||
def _setup_agent_runtime(self, tools: list, tool_executor: Callable | None) -> None:
|
||||
"""Set up multi-entry-point execution using AgentRuntime."""
|
||||
# Convert AsyncEntryPointSpec to EntryPointSpec for AgentRuntime
|
||||
entry_points = []
|
||||
for async_ep in self.graph.async_entry_points:
|
||||
ep = EntryPointSpec(
|
||||
id=async_ep.id,
|
||||
name=async_ep.name,
|
||||
entry_node=async_ep.entry_node,
|
||||
trigger_type=async_ep.trigger_type,
|
||||
trigger_config=async_ep.trigger_config,
|
||||
isolation_level=async_ep.isolation_level,
|
||||
priority=async_ep.priority,
|
||||
max_concurrent=async_ep.max_concurrent,
|
||||
)
|
||||
entry_points.append(ep)
|
||||
|
||||
# Create AgentRuntime with all entry points
|
||||
self._agent_runtime = create_agent_runtime(
|
||||
graph=self.graph,
|
||||
goal=self.goal,
|
||||
storage_path=self._storage_path,
|
||||
entry_points=entry_points,
|
||||
llm=self._llm,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: dict | None = None,
|
||||
session_state: dict | None = None,
|
||||
entry_point_id: str | None = None,
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Execute the agent with given input data.
|
||||
|
||||
For single-entry-point agents, this is the standard execution path.
|
||||
For multi-entry-point agents, you can optionally specify which entry point to use.
|
||||
|
||||
Args:
|
||||
input_data: Input data for the agent (e.g., {"lead_id": "123"})
|
||||
session_state: Optional session state to resume from
|
||||
entry_point_id: For multi-entry-point agents, which entry point to trigger
|
||||
(defaults to first entry point or "default")
|
||||
|
||||
Returns:
|
||||
ExecutionResult with output, path, and metrics
|
||||
"""
|
||||
if self._uses_async_entry_points:
|
||||
# Multi-entry-point mode: use AgentRuntime
|
||||
return await self._run_with_agent_runtime(
|
||||
input_data=input_data or {},
|
||||
entry_point_id=entry_point_id,
|
||||
)
|
||||
else:
|
||||
# Legacy single-entry-point mode
|
||||
return await self._run_with_executor(
|
||||
input_data=input_data or {},
|
||||
session_state=session_state,
|
||||
)
|
||||
|
||||
async def _run_with_executor(
|
||||
self,
|
||||
input_data: dict,
|
||||
session_state: dict | None = None,
|
||||
) -> ExecutionResult:
|
||||
"""Run using legacy GraphExecutor (single entry point)."""
|
||||
if self._executor is None:
|
||||
self._setup()
|
||||
|
||||
return await self._executor.execute(
|
||||
graph=self.graph,
|
||||
goal=self.goal,
|
||||
input_data=input_data or {},
|
||||
input_data=input_data,
|
||||
session_state=session_state,
|
||||
)
|
||||
|
||||
async def _run_with_agent_runtime(
|
||||
self,
|
||||
input_data: dict,
|
||||
entry_point_id: str | None = None,
|
||||
) -> ExecutionResult:
|
||||
"""Run using AgentRuntime (multi-entry-point)."""
|
||||
if self._agent_runtime is None:
|
||||
self._setup()
|
||||
|
||||
# Start runtime if not running
|
||||
if not self._agent_runtime.is_running:
|
||||
await self._agent_runtime.start()
|
||||
|
||||
# Determine entry point
|
||||
if entry_point_id is None:
|
||||
# Use first entry point or "default" if no entry points defined
|
||||
entry_points = self._agent_runtime.get_entry_points()
|
||||
if entry_points:
|
||||
entry_point_id = entry_points[0].id
|
||||
else:
|
||||
entry_point_id = "default"
|
||||
|
||||
# Trigger and wait for result
|
||||
result = await self._agent_runtime.trigger_and_wait(
|
||||
entry_point_id=entry_point_id,
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
# Return result or create error result
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error="Execution timed out or failed to complete",
|
||||
)
|
||||
|
||||
# === Multi-Entry-Point API (for agents with async_entry_points) ===
|
||||
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the agent runtime (for multi-entry-point agents).
|
||||
|
||||
This starts all registered entry points and allows concurrent execution.
|
||||
For single-entry-point agents, this is a no-op.
|
||||
"""
|
||||
if not self._uses_async_entry_points:
|
||||
return
|
||||
|
||||
if self._agent_runtime is None:
|
||||
self._setup()
|
||||
|
||||
await self._agent_runtime.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
Stop the agent runtime (for multi-entry-point agents).
|
||||
|
||||
For single-entry-point agents, this is a no-op.
|
||||
"""
|
||||
if self._agent_runtime is not None:
|
||||
await self._agent_runtime.stop()
|
||||
|
||||
async def trigger(
|
||||
self,
|
||||
entry_point_id: str,
|
||||
input_data: dict[str, Any],
|
||||
correlation_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Trigger execution at a specific entry point (non-blocking).
|
||||
|
||||
For multi-entry-point agents only. Returns execution ID for tracking.
|
||||
|
||||
Args:
|
||||
entry_point_id: Which entry point to trigger
|
||||
input_data: Input data for the execution
|
||||
correlation_id: Optional ID to correlate related executions
|
||||
|
||||
Returns:
|
||||
Execution ID for tracking
|
||||
|
||||
Raises:
|
||||
RuntimeError: If agent doesn't use async entry points
|
||||
"""
|
||||
if not self._uses_async_entry_points:
|
||||
raise RuntimeError(
|
||||
"trigger() is only available for multi-entry-point agents. "
|
||||
"Use run() for single-entry-point agents."
|
||||
)
|
||||
|
||||
if self._agent_runtime is None:
|
||||
self._setup()
|
||||
|
||||
if not self._agent_runtime.is_running:
|
||||
await self._agent_runtime.start()
|
||||
|
||||
return await self._agent_runtime.trigger(
|
||||
entry_point_id=entry_point_id,
|
||||
input_data=input_data,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
async def get_goal_progress(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get goal progress across all execution streams.
|
||||
|
||||
For multi-entry-point agents only.
|
||||
|
||||
Returns:
|
||||
Dict with overall_progress, criteria_status, constraint_violations, etc.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If agent doesn't use async entry points
|
||||
"""
|
||||
if not self._uses_async_entry_points:
|
||||
raise RuntimeError(
|
||||
"get_goal_progress() is only available for multi-entry-point agents."
|
||||
)
|
||||
|
||||
if self._agent_runtime is None:
|
||||
self._setup()
|
||||
|
||||
return await self._agent_runtime.get_goal_progress()
|
||||
|
||||
def get_entry_points(self) -> list[EntryPointSpec]:
|
||||
"""
|
||||
Get all registered entry points (for multi-entry-point agents).
|
||||
|
||||
Returns:
|
||||
List of EntryPointSpec objects
|
||||
"""
|
||||
if not self._uses_async_entry_points:
|
||||
return []
|
||||
|
||||
if self._agent_runtime is None:
|
||||
self._setup()
|
||||
|
||||
return self._agent_runtime.get_entry_points()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the agent runtime is running (for multi-entry-point agents)."""
|
||||
if self._agent_runtime is None:
|
||||
return False
|
||||
return self._agent_runtime.is_running
|
||||
|
||||
def info(self) -> AgentInfo:
|
||||
"""Return agent metadata (nodes, edges, goal, required tools)."""
|
||||
# Extract required tools from nodes
|
||||
@@ -454,6 +735,19 @@ class AgentRunner:
|
||||
for edge in self.graph.edges
|
||||
]
|
||||
|
||||
# Build async entry points info
|
||||
async_entry_points_info = [
|
||||
{
|
||||
"id": ep.id,
|
||||
"name": ep.name,
|
||||
"entry_node": ep.entry_node,
|
||||
"trigger_type": ep.trigger_type,
|
||||
"isolation_level": ep.isolation_level,
|
||||
"max_concurrent": ep.max_concurrent,
|
||||
}
|
||||
for ep in self.graph.async_entry_points
|
||||
]
|
||||
|
||||
return AgentInfo(
|
||||
name=self.graph.id,
|
||||
description=self.graph.description,
|
||||
@@ -475,6 +769,8 @@ class AgentRunner:
|
||||
],
|
||||
required_tools=sorted(required_tools),
|
||||
has_tools_module=(self.agent_path / "tools.py").exists(),
|
||||
async_entry_points=async_entry_points_info,
|
||||
is_multi_entry_point=self._uses_async_entry_points,
|
||||
)
|
||||
|
||||
def validate(self) -> ValidationResult:
|
||||
@@ -748,7 +1044,7 @@ Respond with JSON only:
|
||||
)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Clean up resources."""
|
||||
"""Clean up resources (synchronous)."""
|
||||
# Clean up MCP client connections
|
||||
self._tool_registry.cleanup()
|
||||
|
||||
@@ -756,14 +1052,26 @@ Respond with JSON only:
|
||||
self._temp_dir.cleanup()
|
||||
self._temp_dir = None
|
||||
|
||||
async def cleanup_async(self) -> None:
|
||||
"""Clean up resources (asynchronous - for multi-entry-point agents)."""
|
||||
# Stop agent runtime if running
|
||||
if self._agent_runtime is not None and self._agent_runtime.is_running:
|
||||
await self._agent_runtime.stop()
|
||||
|
||||
# Run synchronous cleanup
|
||||
self.cleanup()
|
||||
|
||||
async def __aenter__(self) -> "AgentRunner":
|
||||
"""Context manager entry."""
|
||||
self._setup()
|
||||
# Start runtime for multi-entry-point agents
|
||||
if self._uses_async_entry_points and self._agent_runtime is not None:
|
||||
await self._agent_runtime.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args) -> None:
|
||||
"""Context manager exit."""
|
||||
self.cleanup()
|
||||
await self.cleanup_async()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Destructor - cleanup temp dir."""
|
||||
|
||||
@@ -0,0 +1,451 @@
|
||||
"""
|
||||
Agent Runtime - Top-level orchestrator for multi-entry-point agents.
|
||||
|
||||
Manages agent lifecycle and coordinates multiple execution streams
|
||||
while preserving the goal-driven approach.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
|
||||
from framework.graph.executor import ExecutionResult
|
||||
from framework.runtime.shared_state import SharedStateManager
|
||||
from framework.runtime.outcome_aggregator import OutcomeAggregator
|
||||
from framework.runtime.event_bus import EventBus
|
||||
from framework.runtime.execution_stream import ExecutionStream, EntryPointSpec
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentRuntimeConfig:
|
||||
"""Configuration for AgentRuntime."""
|
||||
max_concurrent_executions: int = 100
|
||||
cache_ttl: float = 60.0
|
||||
batch_interval: float = 0.1
|
||||
max_history: int = 1000
|
||||
|
||||
|
||||
class AgentRuntime:
|
||||
"""
|
||||
Top-level runtime that manages agent lifecycle and concurrent executions.
|
||||
|
||||
Responsibilities:
|
||||
- Register and manage multiple entry points
|
||||
- Coordinate execution streams
|
||||
- Manage shared state across streams
|
||||
- Aggregate decisions/outcomes for goal evaluation
|
||||
- Handle lifecycle events (start, pause, shutdown)
|
||||
|
||||
Example:
|
||||
# Create runtime
|
||||
runtime = AgentRuntime(
|
||||
graph=support_agent_graph,
|
||||
goal=support_agent_goal,
|
||||
storage_path=Path("./storage"),
|
||||
llm=llm_provider,
|
||||
)
|
||||
|
||||
# Register entry points
|
||||
runtime.register_entry_point(EntryPointSpec(
|
||||
id="webhook",
|
||||
name="Zendesk Webhook",
|
||||
entry_node="process-webhook",
|
||||
trigger_type="webhook",
|
||||
isolation_level="shared",
|
||||
))
|
||||
|
||||
runtime.register_entry_point(EntryPointSpec(
|
||||
id="api",
|
||||
name="API Handler",
|
||||
entry_node="process-request",
|
||||
trigger_type="api",
|
||||
isolation_level="shared",
|
||||
))
|
||||
|
||||
# Start runtime
|
||||
await runtime.start()
|
||||
|
||||
# Trigger executions (non-blocking)
|
||||
exec_1 = await runtime.trigger("webhook", {"ticket_id": "123"})
|
||||
exec_2 = await runtime.trigger("api", {"query": "help"})
|
||||
|
||||
# Check goal progress
|
||||
progress = await runtime.get_goal_progress()
|
||||
print(f"Progress: {progress['overall_progress']:.1%}")
|
||||
|
||||
# Stop runtime
|
||||
await runtime.stop()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: "GraphSpec",
|
||||
goal: "Goal",
|
||||
storage_path: str | Path,
|
||||
llm: "LLMProvider | None" = None,
|
||||
tools: list["Tool"] | None = None,
|
||||
tool_executor: Callable | None = None,
|
||||
config: AgentRuntimeConfig | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize agent runtime.
|
||||
|
||||
Args:
|
||||
graph: Graph specification for this agent
|
||||
goal: Goal driving execution
|
||||
storage_path: Path for persistent storage
|
||||
llm: LLM provider for nodes
|
||||
tools: Available tools
|
||||
tool_executor: Function to execute tools
|
||||
config: Optional runtime configuration
|
||||
"""
|
||||
self.graph = graph
|
||||
self.goal = goal
|
||||
self._config = config or AgentRuntimeConfig()
|
||||
|
||||
# Initialize storage
|
||||
self._storage = ConcurrentStorage(
|
||||
base_path=storage_path,
|
||||
cache_ttl=self._config.cache_ttl,
|
||||
batch_interval=self._config.batch_interval,
|
||||
)
|
||||
|
||||
# Initialize shared components
|
||||
self._state_manager = SharedStateManager()
|
||||
self._event_bus = EventBus(max_history=self._config.max_history)
|
||||
self._outcome_aggregator = OutcomeAggregator(goal, self._event_bus)
|
||||
|
||||
# LLM and tools
|
||||
self._llm = llm
|
||||
self._tools = tools or []
|
||||
self._tool_executor = tool_executor
|
||||
|
||||
# Entry points and streams
|
||||
self._entry_points: dict[str, EntryPointSpec] = {}
|
||||
self._streams: dict[str, ExecutionStream] = {}
|
||||
|
||||
# State
|
||||
self._running = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def register_entry_point(self, spec: EntryPointSpec) -> None:
|
||||
"""
|
||||
Register a named entry point for the agent.
|
||||
|
||||
Args:
|
||||
spec: Entry point specification
|
||||
|
||||
Raises:
|
||||
ValueError: If entry point ID already registered
|
||||
RuntimeError: If runtime is already running
|
||||
"""
|
||||
if self._running:
|
||||
raise RuntimeError("Cannot register entry points while runtime is running")
|
||||
|
||||
if spec.id in self._entry_points:
|
||||
raise ValueError(f"Entry point '{spec.id}' already registered")
|
||||
|
||||
# Validate entry node exists in graph
|
||||
if self.graph.get_node(spec.entry_node) is None:
|
||||
raise ValueError(f"Entry node '{spec.entry_node}' not found in graph")
|
||||
|
||||
self._entry_points[spec.id] = spec
|
||||
logger.info(f"Registered entry point: {spec.id} -> {spec.entry_node}")
|
||||
|
||||
def unregister_entry_point(self, entry_point_id: str) -> bool:
|
||||
"""
|
||||
Unregister an entry point.
|
||||
|
||||
Args:
|
||||
entry_point_id: Entry point to remove
|
||||
|
||||
Returns:
|
||||
True if removed, False if not found
|
||||
|
||||
Raises:
|
||||
RuntimeError: If runtime is running
|
||||
"""
|
||||
if self._running:
|
||||
raise RuntimeError("Cannot unregister entry points while runtime is running")
|
||||
|
||||
if entry_point_id in self._entry_points:
|
||||
del self._entry_points[entry_point_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
async def start(self) -> None:
    """Start the runtime: bring up storage, then one stream per entry point.

    Idempotent — a second call while running is a no-op.
    """
    if self._running:
        return

    async with self._lock:
        # Storage must be ready before any stream tries to persist state.
        await self._storage.start()

        # Every ExecutionStream shares the same graph/goal/infrastructure;
        # only the stream ID and entry spec differ per entry point.
        shared = dict(
            graph=self.graph,
            goal=self.goal,
            state_manager=self._state_manager,
            storage=self._storage,
            outcome_aggregator=self._outcome_aggregator,
            event_bus=self._event_bus,
            llm=self._llm,
            tools=self._tools,
            tool_executor=self._tool_executor,
        )
        for ep_id, spec in self._entry_points.items():
            stream = ExecutionStream(stream_id=ep_id, entry_spec=spec, **shared)
            await stream.start()
            self._streams[ep_id] = stream

        self._running = True
        logger.info(f"AgentRuntime started with {len(self._streams)} streams")
||||
|
||||
async def stop(self) -> None:
    """Stop the runtime: wind down every stream, then the storage backend.

    Idempotent — calling while already stopped is a no-op.
    """
    if not self._running:
        return

    async with self._lock:
        # Streams first, so nothing writes to storage after it closes.
        for stream in self._streams.values():
            await stream.stop()
        self._streams.clear()

        await self._storage.stop()

        self._running = False
        logger.info("AgentRuntime stopped")
|
||||
async def trigger(
|
||||
self,
|
||||
entry_point_id: str,
|
||||
input_data: dict[str, Any],
|
||||
correlation_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Trigger execution at a specific entry point.
|
||||
|
||||
Non-blocking - returns immediately with execution ID.
|
||||
|
||||
Args:
|
||||
entry_point_id: Which entry point to trigger
|
||||
input_data: Input data for the execution
|
||||
correlation_id: Optional ID to correlate related executions
|
||||
|
||||
Returns:
|
||||
Execution ID for tracking
|
||||
|
||||
Raises:
|
||||
ValueError: If entry point not found
|
||||
RuntimeError: If runtime not running
|
||||
"""
|
||||
if not self._running:
|
||||
raise RuntimeError("AgentRuntime is not running")
|
||||
|
||||
stream = self._streams.get(entry_point_id)
|
||||
if stream is None:
|
||||
raise ValueError(f"Entry point '{entry_point_id}' not found")
|
||||
|
||||
return await stream.execute(input_data, correlation_id)
|
||||
|
||||
async def trigger_and_wait(
    self,
    entry_point_id: str,
    input_data: dict[str, Any],
    timeout: float | None = None,
) -> ExecutionResult | None:
    """Trigger an execution and block until it finishes.

    Args:
        entry_point_id: Which entry point to trigger.
        input_data: Input payload for the execution.
        timeout: Maximum seconds to wait for completion.

    Returns:
        The ExecutionResult, or None if the wait timed out.
    """
    exec_id = await self.trigger(entry_point_id, input_data)
    # trigger() has already validated the entry point, so direct indexing is safe.
    return await self._streams[entry_point_id].wait_for_completion(exec_id, timeout)
||||
|
||||
async def get_goal_progress(self) -> dict[str, Any]:
    """Evaluate goal progress aggregated across every stream.

    Returns:
        Progress report with overall progress, per-criterion status,
        constraint violations, and metrics.
    """
    return await self._outcome_aggregator.evaluate_goal_progress()
||||
|
||||
async def cancel_execution(
    self,
    entry_point_id: str,
    execution_id: str,
) -> bool:
    """Cancel a running execution on one of the streams.

    Args:
        entry_point_id: Stream that owns the execution.
        execution_id: Execution to cancel.

    Returns:
        True if the execution was cancelled; False when either the
        stream or the execution is unknown.
    """
    stream = self._streams.get(entry_point_id)
    # Unknown stream -> nothing to cancel; otherwise delegate to the stream.
    return False if stream is None else await stream.cancel_execution(execution_id)
||||
|
||||
# === QUERY OPERATIONS ===
|
||||
|
||||
def get_entry_points(self) -> list[EntryPointSpec]:
    """Return a snapshot list of all registered entry point specs."""
    return [*self._entry_points.values()]
||||
|
||||
def get_stream(self, entry_point_id: str) -> ExecutionStream | None:
    """Look up the execution stream for an entry point, or None if unknown."""
    return self._streams.get(entry_point_id)
||||
|
||||
def get_execution_result(
    self,
    entry_point_id: str,
    execution_id: str,
) -> ExecutionResult | None:
    """Fetch the result of a completed execution, if it is available.

    Returns None when either the stream or the execution is unknown.
    """
    stream = self._streams.get(entry_point_id)
    return stream.get_result(execution_id) if stream else None
||||
|
||||
# === EVENT SUBSCRIPTIONS ===
|
||||
|
||||
def subscribe_to_events(
|
||||
self,
|
||||
event_types: list,
|
||||
handler: Callable,
|
||||
filter_stream: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Subscribe to agent events.
|
||||
|
||||
Args:
|
||||
event_types: Types of events to receive
|
||||
handler: Async function to call when event occurs
|
||||
filter_stream: Only receive events from this stream
|
||||
|
||||
Returns:
|
||||
Subscription ID (use to unsubscribe)
|
||||
"""
|
||||
return self._event_bus.subscribe(
|
||||
event_types=event_types,
|
||||
handler=handler,
|
||||
filter_stream=filter_stream,
|
||||
)
|
||||
|
||||
def unsubscribe_from_events(self, subscription_id: str) -> bool:
    """Remove an event subscription; True if it existed."""
    return self._event_bus.unsubscribe(subscription_id)
||||
|
||||
# === STATS AND MONITORING ===
|
||||
|
||||
def get_stats(self) -> dict:
    """Snapshot runtime statistics from every component.

    Returns:
        Dict with the running flag, entry-point count, per-stream stats,
        the goal ID, and stats from the aggregator, event bus, and
        shared state manager.
    """
    return {
        "running": self._running,
        "entry_points": len(self._entry_points),
        "streams": {ep_id: s.get_stats() for ep_id, s in self._streams.items()},
        "goal_id": self.goal.id,
        "outcome_aggregator": self._outcome_aggregator.get_stats(),
        "event_bus": self._event_bus.get_stats(),
        "state_manager": self._state_manager.get_stats(),
    }
||||
|
||||
# === PROPERTIES ===
|
||||
|
||||
@property
def state_manager(self) -> SharedStateManager:
    """The shared state manager used by all streams."""
    return self._state_manager
||||
|
||||
@property
def event_bus(self) -> EventBus:
    """The runtime's pub/sub event bus."""
    return self._event_bus
||||
|
||||
@property
def outcome_aggregator(self) -> OutcomeAggregator:
    """The cross-stream outcome aggregator."""
    return self._outcome_aggregator
||||
|
||||
@property
def is_running(self) -> bool:
    """Whether the runtime has been started and not yet stopped."""
    return self._running
||||
|
||||
|
||||
# === CONVENIENCE FACTORY ===
|
||||
|
||||
def create_agent_runtime(
    graph: "GraphSpec",
    goal: "Goal",
    storage_path: str | Path,
    entry_points: list[EntryPointSpec],
    llm: "LLMProvider | None" = None,
    tools: list["Tool"] | None = None,
    tool_executor: Callable | None = None,
    config: AgentRuntimeConfig | None = None,
) -> AgentRuntime:
    """Build an AgentRuntime and register the given entry points.

    Convenience factory: the returned runtime is fully configured but
    NOT started — call ``await runtime.start()`` before triggering.

    Args:
        graph: Graph specification.
        goal: Goal driving execution.
        storage_path: Path for persistent storage.
        entry_points: Entry point specifications to register.
        llm: LLM provider.
        tools: Available tools.
        tool_executor: Tool executor function.
        config: Runtime configuration.

    Returns:
        Configured AgentRuntime (not yet started).

    Raises:
        ValueError: Propagated from register_entry_point() for duplicate
            IDs or entry nodes missing from the graph.
    """
    runtime = AgentRuntime(
        graph=graph,
        goal=goal,
        storage_path=storage_path,
        llm=llm,
        tools=tools,
        tool_executor=tool_executor,
        config=config,
    )

    for spec in entry_points:
        runtime.register_entry_point(spec)

    return runtime
||||
@@ -0,0 +1,442 @@
|
||||
"""
|
||||
Event Bus - Pub/sub event system for inter-stream communication.
|
||||
|
||||
Allows streams to:
|
||||
- Publish events about their execution
|
||||
- Subscribe to events from other streams
|
||||
- Coordinate based on shared state changes
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventType(str, Enum):
    """Event categories that can be published on the bus.

    Inherits from str so members compare equal to their wire values and
    serialize cleanly.
    """

    # Execution lifecycle
    EXECUTION_STARTED = "execution_started"
    EXECUTION_COMPLETED = "execution_completed"
    EXECUTION_FAILED = "execution_failed"
    EXECUTION_PAUSED = "execution_paused"
    EXECUTION_RESUMED = "execution_resumed"

    # Shared-state changes
    STATE_CHANGED = "state_changed"
    STATE_CONFLICT = "state_conflict"

    # Goal tracking
    GOAL_PROGRESS = "goal_progress"
    GOAL_ACHIEVED = "goal_achieved"
    CONSTRAINT_VIOLATION = "constraint_violation"

    # Stream lifecycle
    STREAM_STARTED = "stream_started"
    STREAM_STOPPED = "stream_stopped"

    # Application-defined events
    CUSTOM = "custom"
||||
|
||||
|
||||
@dataclass
class AgentEvent:
    """A single event flowing through the agent system."""

    type: EventType
    stream_id: str
    execution_id: str | None = None
    data: dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)
    correlation_id: str | None = None  # Links related events together

    def to_dict(self) -> dict:
        """Serialize to a plain, JSON-friendly dictionary.

        The enum collapses to its string value and the timestamp to an
        ISO-8601 string; everything else passes through unchanged.
        """
        return {
            "type": self.type.value,
            "stream_id": self.stream_id,
            "execution_id": self.execution_id,
            "data": self.data,
            "timestamp": self.timestamp.isoformat(),
            "correlation_id": self.correlation_id,
        }
||||
|
||||
|
||||
# Contract for event handlers: an async callable that receives the
# AgentEvent and returns nothing. Handler exceptions are caught and
# logged by the bus rather than propagated to the publisher.
EventHandler = Callable[[AgentEvent], Awaitable[None]]
||||
|
||||
|
||||
@dataclass
class Subscription:
    """A registered listener on the event bus.

    Created by EventBus.subscribe(); the id is returned to the caller
    and is the handle for unsubscribing.
    """
    id: str
    event_types: set[EventType]  # Event types this handler receives
    handler: EventHandler
    filter_stream: str | None = None  # Only receive events from this stream
    filter_execution: str | None = None  # Only receive events from this execution
||||
|
||||
|
||||
class EventBus:
    """
    Pub/sub event bus for inter-stream communication.

    Features:
    - Async event handling
    - Type-based subscriptions
    - Stream/execution filtering
    - Event history for debugging

    Example:
        bus = EventBus()

        # Subscribe to execution events
        async def on_execution_complete(event: AgentEvent):
            print(f"Execution {event.execution_id} completed")

        bus.subscribe(
            event_types=[EventType.EXECUTION_COMPLETED],
            handler=on_execution_complete,
        )

        # Publish an event
        await bus.publish(AgentEvent(
            type=EventType.EXECUTION_COMPLETED,
            stream_id="webhook",
            execution_id="exec_123",
            data={"result": "success"},
        ))
    """

    def __init__(
        self,
        max_history: int = 1000,
        max_concurrent_handlers: int = 10,
    ):
        """
        Initialize event bus.

        Args:
            max_history: Maximum events to keep in history
            max_concurrent_handlers: Maximum concurrent handler executions
        """
        self._subscriptions: dict[str, Subscription] = {}
        self._event_history: list[AgentEvent] = []
        self._max_history = max_history
        # Bounds how many handlers run at once in _execute_handlers().
        self._semaphore = asyncio.Semaphore(max_concurrent_handlers)
        self._subscription_counter = 0
        # Guards the history list during publish().
        self._lock = asyncio.Lock()

    def subscribe(
        self,
        event_types: list[EventType],
        handler: EventHandler,
        filter_stream: str | None = None,
        filter_execution: str | None = None,
    ) -> str:
        """
        Subscribe to events.

        Args:
            event_types: Types of events to receive
            handler: Async function to call when event occurs
            filter_stream: Only receive events from this stream
            filter_execution: Only receive events from this execution

        Returns:
            Subscription ID (use to unsubscribe)
        """
        self._subscription_counter += 1
        sub_id = f"sub_{self._subscription_counter}"

        subscription = Subscription(
            id=sub_id,
            event_types=set(event_types),
            handler=handler,
            filter_stream=filter_stream,
            filter_execution=filter_execution,
        )

        self._subscriptions[sub_id] = subscription
        logger.debug(f"Subscription {sub_id} registered for {event_types}")

        return sub_id

    def unsubscribe(self, subscription_id: str) -> bool:
        """
        Unsubscribe from events.

        Args:
            subscription_id: ID returned from subscribe()

        Returns:
            True if subscription was found and removed
        """
        if subscription_id in self._subscriptions:
            del self._subscriptions[subscription_id]
            logger.debug(f"Subscription {subscription_id} removed")
            return True
        return False

    async def publish(self, event: AgentEvent) -> None:
        """
        Publish an event to all matching subscribers.

        The event is recorded in the bounded history, then every
        matching handler runs concurrently (rate-limited by the
        semaphore). Handler exceptions are logged, never propagated.

        Args:
            event: Event to publish
        """
        # Record in the bounded history.
        async with self._lock:
            self._event_history.append(event)
            if len(self._event_history) > self._max_history:
                self._event_history = self._event_history[-self._max_history:]

        # Collect handlers whose subscription matches this event.
        matching_handlers: list[EventHandler] = [
            sub.handler for sub in self._subscriptions.values()
            if self._matches(sub, event)
        ]

        if matching_handlers:
            await self._execute_handlers(event, matching_handlers)

    def _matches(self, subscription: Subscription, event: AgentEvent) -> bool:
        """Check whether a subscription's type and filters accept an event."""
        if event.type not in subscription.event_types:
            return False

        # A set filter must match exactly; an unset filter accepts anything.
        if subscription.filter_stream and subscription.filter_stream != event.stream_id:
            return False

        if subscription.filter_execution and subscription.filter_execution != event.execution_id:
            return False

        return True

    async def _execute_handlers(
        self,
        event: AgentEvent,
        handlers: list[EventHandler],
    ) -> None:
        """Execute handlers concurrently with rate limiting."""

        async def run_handler(handler: EventHandler) -> None:
            async with self._semaphore:
                try:
                    await handler(event)
                except Exception as e:
                    # A failing handler must not break the publisher or
                    # its sibling handlers.
                    logger.error(f"Handler error for {event.type}: {e}")

        await asyncio.gather(*[run_handler(h) for h in handlers], return_exceptions=True)

    # === CONVENIENCE PUBLISHERS ===

    async def emit_execution_started(
        self,
        stream_id: str,
        execution_id: str,
        input_data: dict[str, Any] | None = None,
        correlation_id: str | None = None,
    ) -> None:
        """Emit execution started event."""
        await self.publish(AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id=stream_id,
            execution_id=execution_id,
            data={"input": input_data or {}},
            correlation_id=correlation_id,
        ))

    async def emit_execution_completed(
        self,
        stream_id: str,
        execution_id: str,
        output: dict[str, Any] | None = None,
        correlation_id: str | None = None,
    ) -> None:
        """Emit execution completed event."""
        await self.publish(AgentEvent(
            type=EventType.EXECUTION_COMPLETED,
            stream_id=stream_id,
            execution_id=execution_id,
            data={"output": output or {}},
            correlation_id=correlation_id,
        ))

    async def emit_execution_failed(
        self,
        stream_id: str,
        execution_id: str,
        error: str,
        correlation_id: str | None = None,
    ) -> None:
        """Emit execution failed event."""
        await self.publish(AgentEvent(
            type=EventType.EXECUTION_FAILED,
            stream_id=stream_id,
            execution_id=execution_id,
            data={"error": error},
            correlation_id=correlation_id,
        ))

    async def emit_goal_progress(
        self,
        stream_id: str,
        progress: float,
        criteria_status: dict[str, Any],
    ) -> None:
        """Emit goal progress event."""
        await self.publish(AgentEvent(
            type=EventType.GOAL_PROGRESS,
            stream_id=stream_id,
            data={
                "progress": progress,
                "criteria_status": criteria_status,
            },
        ))

    async def emit_constraint_violation(
        self,
        stream_id: str,
        execution_id: str,
        constraint_id: str,
        description: str,
    ) -> None:
        """Emit constraint violation event."""
        await self.publish(AgentEvent(
            type=EventType.CONSTRAINT_VIOLATION,
            stream_id=stream_id,
            execution_id=execution_id,
            data={
                "constraint_id": constraint_id,
                "description": description,
            },
        ))

    async def emit_state_changed(
        self,
        stream_id: str,
        execution_id: str,
        key: str,
        old_value: Any,
        new_value: Any,
        scope: str,
    ) -> None:
        """Emit state changed event."""
        await self.publish(AgentEvent(
            type=EventType.STATE_CHANGED,
            stream_id=stream_id,
            execution_id=execution_id,
            data={
                "key": key,
                "old_value": old_value,
                "new_value": new_value,
                "scope": scope,
            },
        ))

    # === QUERY OPERATIONS ===

    def get_history(
        self,
        event_type: EventType | None = None,
        stream_id: str | None = None,
        execution_id: str | None = None,
        limit: int = 100,
    ) -> list[AgentEvent]:
        """
        Get event history with optional filtering.

        Args:
            event_type: Filter by event type
            stream_id: Filter by stream
            execution_id: Filter by execution
            limit: Maximum events to return

        Returns:
            List of matching events (most recent first)
        """
        events = self._event_history[::-1]  # Reverse for most recent first

        if event_type:
            events = [e for e in events if e.type == event_type]
        if stream_id:
            events = [e for e in events if e.stream_id == stream_id]
        if execution_id:
            events = [e for e in events if e.execution_id == execution_id]

        return events[:limit]

    def get_stats(self) -> dict:
        """Get event bus statistics: totals, subscriptions, per-type counts."""
        type_counts = {}
        for event in self._event_history:
            type_counts[event.type.value] = type_counts.get(event.type.value, 0) + 1

        return {
            "total_events": len(self._event_history),
            "subscriptions": len(self._subscriptions),
            "events_by_type": type_counts,
        }

    # === WAITING OPERATIONS ===

    async def wait_for(
        self,
        event_type: EventType,
        stream_id: str | None = None,
        execution_id: str | None = None,
        timeout: float | None = None,
    ) -> AgentEvent | None:
        """
        Wait for a specific event to occur.

        Args:
            event_type: Type of event to wait for
            stream_id: Filter by stream
            execution_id: Filter by execution
            timeout: Maximum time to wait (seconds); None waits forever

        Returns:
            The event if received, None if timeout
        """
        result: AgentEvent | None = None
        event_received = asyncio.Event()

        async def handler(event: AgentEvent) -> None:
            nonlocal result
            result = event
            event_received.set()

        sub_id = self.subscribe(
            event_types=[event_type],
            handler=handler,
            filter_stream=stream_id,
            filter_execution=execution_id,
        )

        try:
            # BUG FIX: was `if timeout:` — a caller passing timeout=0
            # (meaning "don't wait") would fall into the infinite wait.
            if timeout is not None:
                try:
                    await asyncio.wait_for(event_received.wait(), timeout=timeout)
                except asyncio.TimeoutError:
                    return None
            else:
                await event_received.wait()

            return result
        finally:
            # Always drop the temporary subscription, even on timeout.
            self.unsubscribe(sub_id)
||||
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
Execution Stream - Manages concurrent executions for a single entry point.
|
||||
|
||||
Each stream has:
|
||||
- Its own StreamRuntime for decision tracking
|
||||
- Access to shared state (read/write based on isolation)
|
||||
- Connection to the outcome aggregator
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
|
||||
from framework.graph.executor import GraphExecutor, ExecutionResult
|
||||
from framework.runtime.stream_runtime import StreamRuntime, StreamRuntimeAdapter
|
||||
from framework.runtime.shared_state import SharedStateManager, IsolationLevel, StreamMemory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
from framework.runtime.outcome_aggregator import OutcomeAggregator
|
||||
from framework.runtime.event_bus import EventBus
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class EntryPointSpec:
    """Specification for an entry point.

    Runtime-side counterpart of the pydantic AsyncEntryPointSpec model
    (same field names and defaults); consumed by ExecutionStream.
    """
    id: str
    name: str
    entry_node: str  # Node ID to start from
    trigger_type: str  # "webhook", "api", "timer", "event", "manual"
    trigger_config: dict[str, Any] = field(default_factory=dict)
    isolation_level: str = "shared"  # "isolated" | "shared" | "synchronized"
    priority: int = 0
    max_concurrent: int = 10  # Max concurrent executions for this entry point

    def get_isolation_level(self) -> IsolationLevel:
        """Convert the string isolation level to the IsolationLevel enum.

        Raises:
            ValueError: If isolation_level is not a valid enum value.
        """
        return IsolationLevel(self.isolation_level)
||||
|
||||
|
||||
@dataclass
class ExecutionContext:
    """Bookkeeping for a single execution running inside a stream."""
    id: str
    correlation_id: str  # Defaults to the execution id when not supplied
    stream_id: str
    entry_point: str
    input_data: dict[str, Any]
    isolation_level: IsolationLevel
    started_at: datetime = field(default_factory=datetime.now)
    completed_at: datetime | None = None
    # Lifecycle: pending -> running -> completed | failed | paused
    # (_run_execution also sets "cancelled" on task cancellation).
    status: str = "pending"  # pending, running, completed, failed, paused
||||
|
||||
|
||||
class ExecutionStream:
|
||||
"""
|
||||
Manages concurrent executions for a single entry point.
|
||||
|
||||
Each stream:
|
||||
- Has its own StreamRuntime for thread-safe decision tracking
|
||||
- Creates GraphExecutor instances per execution
|
||||
- Manages execution lifecycle with proper isolation
|
||||
|
||||
Example:
|
||||
stream = ExecutionStream(
|
||||
stream_id="webhook",
|
||||
entry_spec=webhook_entry,
|
||||
graph=graph_spec,
|
||||
goal=goal,
|
||||
state_manager=shared_state,
|
||||
storage=concurrent_storage,
|
||||
outcome_aggregator=aggregator,
|
||||
event_bus=event_bus,
|
||||
llm=llm_provider,
|
||||
)
|
||||
|
||||
await stream.start()
|
||||
|
||||
# Trigger execution
|
||||
exec_id = await stream.execute({"ticket_id": "123"})
|
||||
|
||||
# Wait for result
|
||||
result = await stream.wait_for_completion(exec_id)
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    stream_id: str,
    entry_spec: EntryPointSpec,
    graph: "GraphSpec",
    goal: "Goal",
    state_manager: SharedStateManager,
    storage: "ConcurrentStorage",
    outcome_aggregator: "OutcomeAggregator",
    event_bus: "EventBus | None" = None,
    llm: "LLMProvider | None" = None,
    tools: list["Tool"] | None = None,
    tool_executor: Callable | None = None,
):
    """
    Initialize execution stream.

    Args:
        stream_id: Unique identifier for this stream
        entry_spec: Entry point specification
        graph: Graph specification for this agent
        goal: Goal driving execution
        state_manager: Shared state manager
        storage: Concurrent storage backend
        outcome_aggregator: For cross-stream evaluation
        event_bus: Optional event bus for publishing events
        llm: LLM provider for nodes
        tools: Available tools
        tool_executor: Function to execute tools
    """
    self.stream_id = stream_id
    self.entry_spec = entry_spec
    self.graph = graph
    self.goal = goal
    self._state_manager = state_manager
    self._storage = storage
    self._outcome_aggregator = outcome_aggregator
    self._event_bus = event_bus
    self._llm = llm
    self._tools = tools or []
    self._tool_executor = tool_executor

    # Create stream-scoped runtime (one per stream, shared by all of the
    # stream's executions via per-execution adapters).
    self._runtime = StreamRuntime(
        stream_id=stream_id,
        storage=storage,
        outcome_aggregator=outcome_aggregator,
    )

    # Execution tracking, all keyed by execution ID.
    self._active_executions: dict[str, ExecutionContext] = {}
    self._execution_tasks: dict[str, asyncio.Task] = {}
    self._execution_results: dict[str, ExecutionResult] = {}
    self._completion_events: dict[str, asyncio.Event] = {}

    # Concurrency control: semaphore enforces the entry point's
    # max_concurrent; lock guards the tracking dicts in execute().
    self._semaphore = asyncio.Semaphore(entry_spec.max_concurrent)
    self._lock = asyncio.Lock()

    # Lifecycle flag toggled by start()/stop().
    self._running = False
||||
async def start(self) -> None:
    """Mark the stream as running and announce STREAM_STARTED on the bus.

    Idempotent — a second call while running is a no-op.
    """
    if self._running:
        return

    self._running = True
    logger.info(f"ExecutionStream '{self.stream_id}' started")

    if self._event_bus:
        # Imported here rather than at module scope (matches the module's
        # pattern for event-bus types).
        from framework.runtime.event_bus import EventType, AgentEvent

        await self._event_bus.publish(AgentEvent(
            type=EventType.STREAM_STARTED,
            stream_id=self.stream_id,
            data={"entry_point": self.entry_spec.id},
        ))
|
||||
async def stop(self) -> None:
    """Stop the stream, cancelling any in-flight executions.

    Idempotent — calling while already stopped is a no-op. Emits
    STREAM_STOPPED on the event bus (if one is attached) after cleanup.
    """
    if not self._running:
        return

    self._running = False

    # Snapshot the task list: awaiting a cancelled task yields control to
    # other coroutines, and the dict must not change size mid-iteration.
    # (Also fixes iterating .items() when only the values were used.)
    for task in list(self._execution_tasks.values()):
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    self._execution_tasks.clear()
    self._active_executions.clear()

    logger.info(f"ExecutionStream '{self.stream_id}' stopped")

    if self._event_bus:
        from framework.runtime.event_bus import EventType, AgentEvent

        await self._event_bus.publish(AgentEvent(
            type=EventType.STREAM_STOPPED,
            stream_id=self.stream_id,
        ))
|
||||
async def execute(
    self,
    input_data: dict[str, Any],
    correlation_id: str | None = None,
) -> str:
    """Queue a background execution and return its tracking ID.

    Non-blocking: the work runs in an asyncio task created here.

    Args:
        input_data: Input data for this execution.
        correlation_id: Optional ID to correlate related executions;
            defaults to the new execution's own ID.

    Returns:
        Execution ID for tracking.

    Raises:
        RuntimeError: If the stream has not been started.
    """
    if not self._running:
        raise RuntimeError(f"ExecutionStream '{self.stream_id}' is not running")

    execution_id = f"exec_{self.stream_id}_{uuid.uuid4().hex[:8]}"
    # Explicit `is None` check: an empty-string correlation ID is respected.
    corr = execution_id if correlation_id is None else correlation_id

    ctx = ExecutionContext(
        id=execution_id,
        correlation_id=corr,
        stream_id=self.stream_id,
        entry_point=self.entry_spec.id,
        input_data=input_data,
        isolation_level=self.entry_spec.get_isolation_level(),
    )

    # Register the context, completion event, and background task atomically.
    async with self._lock:
        self._active_executions[execution_id] = ctx
        self._completion_events[execution_id] = asyncio.Event()
        task = asyncio.create_task(self._run_execution(ctx))
        self._execution_tasks[execution_id] = task

    logger.debug(f"Queued execution {execution_id} for stream {self.stream_id}")
    return execution_id
|
||||
async def _run_execution(self, ctx: ExecutionContext) -> None:
    """Run a single execution within the stream.

    Drives one GraphExecutor run end-to-end: emits lifecycle events,
    records the result in self._execution_results, updates ctx.status,
    and always cleans up state and signals the completion event.
    """
    execution_id = ctx.id

    # Acquire semaphore so at most entry_spec.max_concurrent runs at once.
    async with self._semaphore:
        ctx.status = "running"

        try:
            # Emit started event
            if self._event_bus:
                await self._event_bus.emit_execution_started(
                    stream_id=self.stream_id,
                    execution_id=execution_id,
                    input_data=ctx.input_data,
                    correlation_id=ctx.correlation_id,
                )

            # Create execution-scoped memory.
            # NOTE(review): `memory` is never passed to the executor below —
            # presumably create_memory() registers state as a side effect;
            # confirm the executor really doesn't need this handle.
            memory = self._state_manager.create_memory(
                execution_id=execution_id,
                stream_id=self.stream_id,
                isolation=ctx.isolation_level,
            )

            # Per-execution adapter over the shared stream runtime
            runtime_adapter = StreamRuntimeAdapter(self._runtime, execution_id)

            # Fresh executor per execution (no state shared between runs)
            executor = GraphExecutor(
                runtime=runtime_adapter,
                llm=self._llm,
                tools=self._tools,
                tool_executor=self._tool_executor,
            )

            # Graph copy whose entry_node is this stream's entry point
            modified_graph = self._create_modified_graph()

            # Execute
            result = await executor.execute(
                graph=modified_graph,
                goal=self.goal,
                input_data=ctx.input_data,
            )

            # Store result for get_result()/wait_for_completion()
            self._execution_results[execution_id] = result

            # Update context; a paused run overrides completed/failed
            ctx.completed_at = datetime.now()
            ctx.status = "completed" if result.success else "failed"
            if result.paused_at:
                ctx.status = "paused"

            # Emit completion/failure event
            if self._event_bus:
                if result.success:
                    await self._event_bus.emit_execution_completed(
                        stream_id=self.stream_id,
                        execution_id=execution_id,
                        output=result.output,
                        correlation_id=ctx.correlation_id,
                    )
                else:
                    await self._event_bus.emit_execution_failed(
                        stream_id=self.stream_id,
                        execution_id=execution_id,
                        error=result.error or "Unknown error",
                        correlation_id=ctx.correlation_id,
                    )

            logger.debug(f"Execution {execution_id} completed: success={result.success}")

        except asyncio.CancelledError:
            # Stream shutdown / explicit cancel: mark and propagate so the
            # task is recorded as cancelled.
            ctx.status = "cancelled"
            raise

        except Exception as e:
            ctx.status = "failed"
            logger.error(f"Execution {execution_id} failed: {e}")

            # Store error result so waiters still receive something
            self._execution_results[execution_id] = ExecutionResult(
                success=False,
                error=str(e),
            )

            # Emit failure event
            if self._event_bus:
                await self._event_bus.emit_execution_failed(
                    stream_id=self.stream_id,
                    execution_id=execution_id,
                    error=str(e),
                    correlation_id=ctx.correlation_id,
                )

        finally:
            # Clean up execution-scoped state
            self._state_manager.cleanup_execution(execution_id)

            # Signal completion to wait_for_completion() callers
            if execution_id in self._completion_events:
                self._completion_events[execution_id].set()
||||
def _create_modified_graph(self) -> "GraphSpec":
|
||||
"""Create a graph with the entry point overridden."""
|
||||
# Use the existing graph but override entry_node
|
||||
from framework.graph.edge import GraphSpec
|
||||
|
||||
# Create a copy with modified entry node
|
||||
return GraphSpec(
|
||||
id=self.graph.id,
|
||||
goal_id=self.graph.goal_id,
|
||||
version=self.graph.version,
|
||||
entry_node=self.entry_spec.entry_node, # Use our entry point
|
||||
entry_points={
|
||||
"start": self.entry_spec.entry_node,
|
||||
**self.graph.entry_points,
|
||||
},
|
||||
terminal_nodes=self.graph.terminal_nodes,
|
||||
pause_nodes=self.graph.pause_nodes,
|
||||
nodes=self.graph.nodes,
|
||||
edges=self.graph.edges,
|
||||
default_model=self.graph.default_model,
|
||||
max_tokens=self.graph.max_tokens,
|
||||
max_steps=self.graph.max_steps,
|
||||
)
|
||||
|
||||
async def wait_for_completion(
|
||||
self,
|
||||
execution_id: str,
|
||||
timeout: float | None = None,
|
||||
) -> ExecutionResult | None:
|
||||
"""
|
||||
Wait for an execution to complete.
|
||||
|
||||
Args:
|
||||
execution_id: Execution to wait for
|
||||
timeout: Maximum time to wait (seconds)
|
||||
|
||||
Returns:
|
||||
ExecutionResult or None if timeout
|
||||
"""
|
||||
event = self._completion_events.get(execution_id)
|
||||
if event is None:
|
||||
# Execution not found or already cleaned up
|
||||
return self._execution_results.get(execution_id)
|
||||
|
||||
try:
|
||||
if timeout:
|
||||
await asyncio.wait_for(event.wait(), timeout=timeout)
|
||||
else:
|
||||
await event.wait()
|
||||
|
||||
return self._execution_results.get(execution_id)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
|
||||
def get_result(self, execution_id: str) -> ExecutionResult | None:
|
||||
"""Get result of a completed execution."""
|
||||
return self._execution_results.get(execution_id)
|
||||
|
||||
def get_context(self, execution_id: str) -> ExecutionContext | None:
|
||||
"""Get execution context."""
|
||||
return self._active_executions.get(execution_id)
|
||||
|
||||
async def cancel_execution(self, execution_id: str) -> bool:
|
||||
"""
|
||||
Cancel a running execution.
|
||||
|
||||
Args:
|
||||
execution_id: Execution to cancel
|
||||
|
||||
Returns:
|
||||
True if cancelled, False if not found
|
||||
"""
|
||||
task = self._execution_tasks.get(execution_id)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
return True
|
||||
return False
|
||||
|
||||
# === STATS AND MONITORING ===
|
||||
|
||||
def get_active_count(self) -> int:
|
||||
"""Get count of active executions."""
|
||||
return len([
|
||||
ctx for ctx in self._active_executions.values()
|
||||
if ctx.status == "running"
|
||||
])
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get stream statistics."""
|
||||
statuses = {}
|
||||
for ctx in self._active_executions.values():
|
||||
statuses[ctx.status] = statuses.get(ctx.status, 0) + 1
|
||||
|
||||
return {
|
||||
"stream_id": self.stream_id,
|
||||
"entry_point": self.entry_spec.id,
|
||||
"running": self._running,
|
||||
"total_executions": len(self._active_executions),
|
||||
"completed_executions": len(self._execution_results),
|
||||
"status_counts": statuses,
|
||||
"max_concurrent": self.entry_spec.max_concurrent,
|
||||
"available_slots": self._semaphore._value,
|
||||
}
|
||||
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
Outcome Aggregator - Aggregates outcomes across streams for goal evaluation.
|
||||
|
||||
The goal-driven nature of Hive means we need to track whether
|
||||
concurrent executions collectively achieve the goal.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from framework.schemas.decision import Decision, Outcome
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph.goal import Goal
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class CriterionStatus:
    """Status of a success criterion."""
    criterion_id: str  # ID of the goal success criterion being tracked
    description: str  # Human-readable description (copied from the goal's criterion)
    met: bool  # True once the criterion is considered satisfied
    evidence: list[str] = field(default_factory=list)  # Sample outcome summaries backing the status
    progress: float = 0.0  # 0.0 to 1.0
    last_updated: datetime = field(default_factory=datetime.now)  # When this status was created/refreshed
|
||||
|
||||
|
||||
@dataclass
class ConstraintCheck:
    """Result of a constraint check."""
    constraint_id: str  # ID of the goal constraint that was checked
    description: str  # Constraint description (human-readable)
    violated: bool  # True when the constraint was breached
    violation_details: str | None = None  # What actually happened, if violated
    stream_id: str | None = None  # Stream in which the violation occurred, if known
    execution_id: str | None = None  # Execution in which the violation occurred, if known
    timestamp: datetime = field(default_factory=datetime.now)  # When the check was recorded
|
||||
|
||||
|
||||
@dataclass
class DecisionRecord:
    """Record of a decision for aggregation."""
    stream_id: str  # Stream that produced the decision
    execution_id: str  # Execution within that stream
    decision: Decision  # The decision itself
    outcome: Outcome | None = None  # Attached later by record_outcome(); None until then
    timestamp: datetime = field(default_factory=datetime.now)  # When the decision was recorded
|
||||
|
||||
|
||||
class OutcomeAggregator:
|
||||
"""
|
||||
Aggregates outcomes across all execution streams for goal evaluation.
|
||||
|
||||
Responsibilities:
|
||||
- Track all decisions across streams
|
||||
- Evaluate success criteria progress
|
||||
- Detect constraint violations
|
||||
- Provide unified goal progress metrics
|
||||
|
||||
Example:
|
||||
aggregator = OutcomeAggregator(goal, event_bus)
|
||||
|
||||
# Decisions are automatically recorded by StreamRuntime
|
||||
aggregator.record_decision(stream_id, execution_id, decision)
|
||||
aggregator.record_outcome(stream_id, execution_id, decision_id, outcome)
|
||||
|
||||
# Evaluate goal progress
|
||||
progress = await aggregator.evaluate_goal_progress()
|
||||
print(f"Goal progress: {progress['overall_progress']:.1%}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
goal: "Goal",
|
||||
event_bus: "EventBus | None" = None,
|
||||
):
|
||||
"""
|
||||
Initialize outcome aggregator.
|
||||
|
||||
Args:
|
||||
goal: The goal to evaluate progress against
|
||||
event_bus: Optional event bus for publishing progress events
|
||||
"""
|
||||
self.goal = goal
|
||||
self._event_bus = event_bus
|
||||
|
||||
# Decision tracking
|
||||
self._decisions: list[DecisionRecord] = []
|
||||
self._decisions_by_id: dict[str, DecisionRecord] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Criterion tracking
|
||||
self._criterion_status: dict[str, CriterionStatus] = {}
|
||||
self._initialize_criteria()
|
||||
|
||||
# Constraint tracking
|
||||
self._constraint_violations: list[ConstraintCheck] = []
|
||||
|
||||
# Metrics
|
||||
self._total_decisions = 0
|
||||
self._successful_outcomes = 0
|
||||
self._failed_outcomes = 0
|
||||
|
||||
def _initialize_criteria(self) -> None:
|
||||
"""Initialize criterion status from goal."""
|
||||
for criterion in self.goal.success_criteria:
|
||||
self._criterion_status[criterion.id] = CriterionStatus(
|
||||
criterion_id=criterion.id,
|
||||
description=criterion.description,
|
||||
met=False,
|
||||
progress=0.0,
|
||||
)
|
||||
|
||||
# === DECISION RECORDING ===
|
||||
|
||||
def record_decision(
|
||||
self,
|
||||
stream_id: str,
|
||||
execution_id: str,
|
||||
decision: Decision,
|
||||
) -> None:
|
||||
"""
|
||||
Record a decision from any stream.
|
||||
|
||||
Args:
|
||||
stream_id: Which stream made the decision
|
||||
execution_id: Which execution
|
||||
decision: The decision made
|
||||
"""
|
||||
record = DecisionRecord(
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
decision=decision,
|
||||
)
|
||||
|
||||
# Create unique key for lookup
|
||||
key = f"{stream_id}:{execution_id}:{decision.id}"
|
||||
self._decisions.append(record)
|
||||
self._decisions_by_id[key] = record
|
||||
self._total_decisions += 1
|
||||
|
||||
logger.debug(f"Recorded decision {decision.id} from {stream_id}/{execution_id}")
|
||||
|
||||
def record_outcome(
|
||||
self,
|
||||
stream_id: str,
|
||||
execution_id: str,
|
||||
decision_id: str,
|
||||
outcome: Outcome,
|
||||
) -> None:
|
||||
"""
|
||||
Record the outcome of a decision.
|
||||
|
||||
Args:
|
||||
stream_id: Which stream
|
||||
execution_id: Which execution
|
||||
decision_id: Which decision
|
||||
outcome: The outcome
|
||||
"""
|
||||
key = f"{stream_id}:{execution_id}:{decision_id}"
|
||||
record = self._decisions_by_id.get(key)
|
||||
|
||||
if record:
|
||||
record.outcome = outcome
|
||||
|
||||
if outcome.success:
|
||||
self._successful_outcomes += 1
|
||||
else:
|
||||
self._failed_outcomes += 1
|
||||
|
||||
logger.debug(f"Recorded outcome for {decision_id}: success={outcome.success}")
|
||||
|
||||
def record_constraint_violation(
|
||||
self,
|
||||
constraint_id: str,
|
||||
description: str,
|
||||
violation_details: str,
|
||||
stream_id: str | None = None,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Record a constraint violation.
|
||||
|
||||
Args:
|
||||
constraint_id: Which constraint was violated
|
||||
description: Constraint description
|
||||
violation_details: What happened
|
||||
stream_id: Which stream
|
||||
execution_id: Which execution
|
||||
"""
|
||||
check = ConstraintCheck(
|
||||
constraint_id=constraint_id,
|
||||
description=description,
|
||||
violated=True,
|
||||
violation_details=violation_details,
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
self._constraint_violations.append(check)
|
||||
logger.warning(f"Constraint violation: {constraint_id} - {violation_details}")
|
||||
|
||||
# Publish event if event bus available
|
||||
if self._event_bus and stream_id:
|
||||
asyncio.create_task(
|
||||
self._event_bus.emit_constraint_violation(
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id or "",
|
||||
constraint_id=constraint_id,
|
||||
description=violation_details,
|
||||
)
|
||||
)
|
||||
|
||||
# === GOAL EVALUATION ===
|
||||
|
||||
    async def evaluate_goal_progress(self) -> dict[str, Any]:
        """
        Evaluate progress toward goal across all streams.

        Overall progress is the weight-normalized sum over success criteria,
        with partial credit for unmet criteria based on their progress.

        Returns:
            {
                "overall_progress": 0.0-1.0,
                "criteria_status": {criterion_id: {...}},
                "constraint_violations": [...],
                "metrics": {...},
                "recommendation": "continue" | "adjust" | "complete"
            }
        """
        # NOTE(review): only this method takes the lock; the record_* methods
        # mutate the decision/violation collections without it, so evaluation
        # may observe a mid-update snapshot — confirm this is acceptable.
        async with self._lock:
            result = {
                "overall_progress": 0.0,
                "criteria_status": {},
                "constraint_violations": [],
                "metrics": {},
                "recommendation": "continue",
            }

            # Evaluate each success criterion
            total_weight = 0.0
            met_weight = 0.0

            for criterion in self.goal.success_criteria:
                status = await self._evaluate_criterion(criterion)
                # Cache the freshly computed status for get_criterion_status().
                self._criterion_status[criterion.id] = status
                result["criteria_status"][criterion.id] = {
                    "description": status.description,
                    "met": status.met,
                    "progress": status.progress,
                    "evidence": status.evidence,
                }

                total_weight += criterion.weight
                if status.met:
                    met_weight += criterion.weight
                else:
                    # Partial credit based on progress
                    met_weight += criterion.weight * status.progress

            # Calculate overall progress (guard against a goal with no/zero weights)
            if total_weight > 0:
                result["overall_progress"] = met_weight / total_weight

            # Include constraint violations
            result["constraint_violations"] = [
                {
                    "constraint_id": v.constraint_id,
                    "description": v.description,
                    "details": v.violation_details,
                    "stream_id": v.stream_id,
                    "timestamp": v.timestamp.isoformat(),
                }
                for v in self._constraint_violations
            ]

            # Add metrics; max(1, ...) avoids division by zero before any outcomes.
            result["metrics"] = {
                "total_decisions": self._total_decisions,
                "successful_outcomes": self._successful_outcomes,
                "failed_outcomes": self._failed_outcomes,
                "success_rate": (
                    self._successful_outcomes / max(1, self._successful_outcomes + self._failed_outcomes)
                ),
                "streams_active": len(set(d.stream_id for d in self._decisions)),
                "executions_total": len(set((d.stream_id, d.execution_id) for d in self._decisions)),
            }

            # Determine recommendation
            result["recommendation"] = self._get_recommendation(result)

            # Publish progress event
            if self._event_bus:
                # Get any stream ID for the event (the event API requires one;
                # the choice among streams is arbitrary since sets are unordered)
                stream_ids = set(d.stream_id for d in self._decisions)
                if stream_ids:
                    await self._event_bus.emit_goal_progress(
                        stream_id=list(stream_ids)[0],
                        progress=result["overall_progress"],
                        criteria_status=result["criteria_status"],
                    )

            return result
|
||||
|
||||
    async def _evaluate_criterion(self, criterion: Any) -> CriterionStatus:
        """
        Evaluate a single success criterion.

        This is a heuristic evaluation based on decision outcomes.
        More sophisticated evaluation can be added per criterion type.

        Progress is the success rate of outcomes on decisions deemed relevant
        to the criterion. The criterion is "met" when progress reaches the
        parsed percentage target, or 0.8 for non-percentage targets.

        Assumes *criterion* exposes ``id``, ``description``, and ``target``
        (target presumably a string like "95%" or another form — the
        except clause below absorbs anything unparseable).
        """
        status = CriterionStatus(
            criterion_id=criterion.id,
            description=criterion.description,
            met=False,
            progress=0.0,
            evidence=[],
        )

        # Get relevant decisions (those mentioning this criterion or related intents)
        relevant_decisions = [
            d for d in self._decisions
            if criterion.id in str(d.decision.active_constraints)
            or self._is_related_to_criterion(d.decision, criterion)
        ]

        if not relevant_decisions:
            # No evidence yet
            return status

        # Calculate success rate for relevant decisions
        outcomes = [d.outcome for d in relevant_decisions if d.outcome is not None]
        if outcomes:
            success_count = sum(1 for o in outcomes if o.success)
            status.progress = success_count / len(outcomes)

        # Add evidence
        for d in relevant_decisions[:5]:  # Limit evidence
            if d.outcome:
                evidence = f"{d.decision.intent}: {'success' if d.outcome.success else 'failed'}"
                status.evidence.append(evidence)

        # Check if criterion is met based on target
        try:
            target = criterion.target
            if isinstance(target, str) and target.endswith("%"):
                # Percentage target, e.g. "95%" -> 0.95
                target_value = float(target.rstrip("%")) / 100
                status.met = status.progress >= target_value
            else:
                # For non-percentage targets, consider met if progress >= 0.8
                status.met = status.progress >= 0.8
        except (ValueError, AttributeError):
            # Unparseable or missing target: fall back to the 0.8 threshold
            status.met = status.progress >= 0.8

        return status
|
||||
|
||||
def _is_related_to_criterion(self, decision: Decision, criterion: Any) -> bool:
|
||||
"""Check if a decision is related to a criterion."""
|
||||
# Simple keyword matching
|
||||
criterion_keywords = criterion.description.lower().split()
|
||||
decision_text = f"{decision.intent} {decision.reasoning}".lower()
|
||||
|
||||
matches = sum(1 for kw in criterion_keywords if kw in decision_text)
|
||||
return matches >= 2 # At least 2 keyword matches
|
||||
|
||||
def _get_recommendation(self, result: dict) -> str:
|
||||
"""Get recommendation based on current progress."""
|
||||
progress = result["overall_progress"]
|
||||
violations = result["constraint_violations"]
|
||||
|
||||
# Check for hard constraint violations
|
||||
hard_violations = [
|
||||
v for v in violations
|
||||
if self._is_hard_constraint(v["constraint_id"])
|
||||
]
|
||||
|
||||
if hard_violations:
|
||||
return "adjust" # Must address violations
|
||||
|
||||
if progress >= 0.95:
|
||||
return "complete" # Goal essentially achieved
|
||||
|
||||
if progress < 0.3 and result["metrics"]["total_decisions"] > 10:
|
||||
return "adjust" # Low progress despite many decisions
|
||||
|
||||
return "continue"
|
||||
|
||||
def _is_hard_constraint(self, constraint_id: str) -> bool:
|
||||
"""Check if a constraint is a hard constraint."""
|
||||
for constraint in self.goal.constraints:
|
||||
if constraint.id == constraint_id:
|
||||
return constraint.constraint_type == "hard"
|
||||
return False
|
||||
|
||||
# === QUERY OPERATIONS ===
|
||||
|
||||
def get_decisions_by_stream(self, stream_id: str) -> list[DecisionRecord]:
|
||||
"""Get all decisions from a specific stream."""
|
||||
return [d for d in self._decisions if d.stream_id == stream_id]
|
||||
|
||||
def get_decisions_by_execution(
|
||||
self,
|
||||
stream_id: str,
|
||||
execution_id: str,
|
||||
) -> list[DecisionRecord]:
|
||||
"""Get all decisions from a specific execution."""
|
||||
return [
|
||||
d for d in self._decisions
|
||||
if d.stream_id == stream_id and d.execution_id == execution_id
|
||||
]
|
||||
|
||||
def get_recent_decisions(self, limit: int = 10) -> list[DecisionRecord]:
|
||||
"""Get most recent decisions."""
|
||||
return self._decisions[-limit:]
|
||||
|
||||
def get_criterion_status(self, criterion_id: str) -> CriterionStatus | None:
|
||||
"""Get status of a specific criterion."""
|
||||
return self._criterion_status.get(criterion_id)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get aggregator statistics."""
|
||||
return {
|
||||
"total_decisions": self._total_decisions,
|
||||
"successful_outcomes": self._successful_outcomes,
|
||||
"failed_outcomes": self._failed_outcomes,
|
||||
"constraint_violations": len(self._constraint_violations),
|
||||
"criteria_tracked": len(self._criterion_status),
|
||||
"streams_seen": len(set(d.stream_id for d in self._decisions)),
|
||||
}
|
||||
|
||||
# === RESET OPERATIONS ===
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all aggregated data."""
|
||||
self._decisions.clear()
|
||||
self._decisions_by_id.clear()
|
||||
self._constraint_violations.clear()
|
||||
self._total_decisions = 0
|
||||
self._successful_outcomes = 0
|
||||
self._failed_outcomes = 0
|
||||
self._initialize_criteria()
|
||||
logger.info("OutcomeAggregator reset")
|
||||
@@ -0,0 +1,494 @@
|
||||
"""
|
||||
Shared State Manager - Manages state across concurrent executions.
|
||||
|
||||
Provides different isolation levels:
|
||||
- ISOLATED: Each execution has its own memory copy
|
||||
- SHARED: All executions read/write same memory (eventual consistency)
|
||||
- SYNCHRONIZED: Shared memory with write locks (strong consistency)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IsolationLevel(str, Enum):
    """State isolation level for concurrent executions.

    Controls which state layers an execution can see and write; see
    SharedStateManager.read()/write() for the exact semantics.
    """
    ISOLATED = "isolated"  # Private state per execution
    SHARED = "shared"  # Shared state (eventual consistency)
    SYNCHRONIZED = "synchronized"  # Shared with write locks (strong consistency)
|
||||
|
||||
|
||||
class StateScope(str, Enum):
    """Scope for state operations (which layer a write lands in)."""
    EXECUTION = "execution"  # Local to a single execution
    STREAM = "stream"  # Shared within a stream
    GLOBAL = "global"  # Shared across all streams
|
||||
|
||||
|
||||
@dataclass
class StateChange:
    """Record of a state change (audit-trail entry kept by SharedStateManager)."""
    key: str  # State key that was written
    old_value: Any  # Value visible to the writer before the write (None if absent)
    new_value: Any  # Value written
    scope: StateScope  # Layer the write landed in
    execution_id: str  # Execution that performed the write
    stream_id: str  # Stream that execution belongs to
    timestamp: float = field(default_factory=time.time)  # Unix timestamp of the write
|
||||
|
||||
|
||||
class SharedStateManager:
|
||||
"""
|
||||
Manages shared state across concurrent executions.
|
||||
|
||||
State hierarchy:
|
||||
- Global state: Shared across all streams and executions
|
||||
- Stream state: Shared within a stream (across executions)
|
||||
- Execution state: Private to a single execution
|
||||
|
||||
Isolation levels control visibility:
|
||||
- ISOLATED: Only sees execution state
|
||||
- SHARED: Sees all levels, writes propagate up based on scope
|
||||
- SYNCHRONIZED: Like SHARED but with write locks
|
||||
|
||||
Example:
|
||||
manager = SharedStateManager()
|
||||
|
||||
# Create memory for an execution
|
||||
memory = manager.create_memory(
|
||||
execution_id="exec_123",
|
||||
stream_id="webhook",
|
||||
isolation=IsolationLevel.SHARED,
|
||||
)
|
||||
|
||||
# Read/write through the memory
|
||||
await memory.write("customer_id", "cust_456", scope=StateScope.STREAM)
|
||||
value = await memory.read("customer_id")
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# State storage at each level
|
||||
self._global_state: dict[str, Any] = {}
|
||||
self._stream_state: dict[str, dict[str, Any]] = {} # stream_id -> {key: value}
|
||||
self._execution_state: dict[str, dict[str, Any]] = {} # execution_id -> {key: value}
|
||||
|
||||
# Locks for synchronized access
|
||||
self._global_lock = asyncio.Lock()
|
||||
self._stream_locks: dict[str, asyncio.Lock] = {}
|
||||
self._key_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
# Change history for debugging/auditing
|
||||
self._change_history: list[StateChange] = []
|
||||
self._max_history = 1000
|
||||
|
||||
# Version tracking
|
||||
self._version = 0
|
||||
|
||||
def create_memory(
|
||||
self,
|
||||
execution_id: str,
|
||||
stream_id: str,
|
||||
isolation: IsolationLevel,
|
||||
) -> "StreamMemory":
|
||||
"""
|
||||
Create a memory instance for an execution.
|
||||
|
||||
Args:
|
||||
execution_id: Unique execution identifier
|
||||
stream_id: Stream this execution belongs to
|
||||
isolation: Isolation level for this execution
|
||||
|
||||
Returns:
|
||||
StreamMemory instance for reading/writing state
|
||||
"""
|
||||
# Initialize execution state
|
||||
if execution_id not in self._execution_state:
|
||||
self._execution_state[execution_id] = {}
|
||||
|
||||
# Initialize stream state
|
||||
if stream_id not in self._stream_state:
|
||||
self._stream_state[stream_id] = {}
|
||||
self._stream_locks[stream_id] = asyncio.Lock()
|
||||
|
||||
return StreamMemory(
|
||||
manager=self,
|
||||
execution_id=execution_id,
|
||||
stream_id=stream_id,
|
||||
isolation=isolation,
|
||||
)
|
||||
|
||||
def cleanup_execution(self, execution_id: str) -> None:
|
||||
"""
|
||||
Clean up state for a completed execution.
|
||||
|
||||
Args:
|
||||
execution_id: Execution to clean up
|
||||
"""
|
||||
self._execution_state.pop(execution_id, None)
|
||||
logger.debug(f"Cleaned up state for execution: {execution_id}")
|
||||
|
||||
def cleanup_stream(self, stream_id: str) -> None:
|
||||
"""
|
||||
Clean up state for a closed stream.
|
||||
|
||||
Args:
|
||||
stream_id: Stream to clean up
|
||||
"""
|
||||
self._stream_state.pop(stream_id, None)
|
||||
self._stream_locks.pop(stream_id, None)
|
||||
logger.debug(f"Cleaned up state for stream: {stream_id}")
|
||||
|
||||
# === LOW-LEVEL STATE OPERATIONS ===
|
||||
|
||||
async def read(
|
||||
self,
|
||||
key: str,
|
||||
execution_id: str,
|
||||
stream_id: str,
|
||||
isolation: IsolationLevel,
|
||||
) -> Any:
|
||||
"""
|
||||
Read a value respecting isolation level.
|
||||
|
||||
Resolution order (stops at first match):
|
||||
1. Execution state (always checked)
|
||||
2. Stream state (if isolation != ISOLATED)
|
||||
3. Global state (if isolation != ISOLATED)
|
||||
"""
|
||||
# Always check execution-local first
|
||||
if execution_id in self._execution_state:
|
||||
if key in self._execution_state[execution_id]:
|
||||
return self._execution_state[execution_id][key]
|
||||
|
||||
# Check stream-level (unless isolated)
|
||||
if isolation != IsolationLevel.ISOLATED:
|
||||
if stream_id in self._stream_state:
|
||||
if key in self._stream_state[stream_id]:
|
||||
return self._stream_state[stream_id][key]
|
||||
|
||||
# Check global
|
||||
if key in self._global_state:
|
||||
return self._global_state[key]
|
||||
|
||||
return None
|
||||
|
||||
async def write(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
execution_id: str,
|
||||
stream_id: str,
|
||||
isolation: IsolationLevel,
|
||||
scope: StateScope = StateScope.EXECUTION,
|
||||
) -> None:
|
||||
"""
|
||||
Write a value respecting isolation level.
|
||||
|
||||
Args:
|
||||
key: State key
|
||||
value: Value to write
|
||||
execution_id: Current execution
|
||||
stream_id: Current stream
|
||||
isolation: Isolation level
|
||||
scope: Where to write (execution, stream, or global)
|
||||
"""
|
||||
# Get old value for change tracking
|
||||
old_value = await self.read(key, execution_id, stream_id, isolation)
|
||||
|
||||
# ISOLATED can only write to execution scope
|
||||
if isolation == IsolationLevel.ISOLATED:
|
||||
scope = StateScope.EXECUTION
|
||||
|
||||
# SYNCHRONIZED requires locks for stream/global writes
|
||||
if isolation == IsolationLevel.SYNCHRONIZED and scope != StateScope.EXECUTION:
|
||||
await self._write_with_lock(key, value, execution_id, stream_id, scope)
|
||||
else:
|
||||
await self._write_direct(key, value, execution_id, stream_id, scope)
|
||||
|
||||
# Record change
|
||||
self._record_change(StateChange(
|
||||
key=key,
|
||||
old_value=old_value,
|
||||
new_value=value,
|
||||
scope=scope,
|
||||
execution_id=execution_id,
|
||||
stream_id=stream_id,
|
||||
))
|
||||
|
||||
async def _write_direct(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
execution_id: str,
|
||||
stream_id: str,
|
||||
scope: StateScope,
|
||||
) -> None:
|
||||
"""Write without locking (for ISOLATED and SHARED)."""
|
||||
if scope == StateScope.EXECUTION:
|
||||
if execution_id not in self._execution_state:
|
||||
self._execution_state[execution_id] = {}
|
||||
self._execution_state[execution_id][key] = value
|
||||
|
||||
elif scope == StateScope.STREAM:
|
||||
if stream_id not in self._stream_state:
|
||||
self._stream_state[stream_id] = {}
|
||||
self._stream_state[stream_id][key] = value
|
||||
|
||||
elif scope == StateScope.GLOBAL:
|
||||
self._global_state[key] = value
|
||||
|
||||
self._version += 1
|
||||
|
||||
async def _write_with_lock(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
execution_id: str,
|
||||
stream_id: str,
|
||||
scope: StateScope,
|
||||
) -> None:
|
||||
"""Write with locking (for SYNCHRONIZED)."""
|
||||
lock = self._get_lock(scope, key, stream_id)
|
||||
async with lock:
|
||||
await self._write_direct(key, value, execution_id, stream_id, scope)
|
||||
|
||||
def _get_lock(self, scope: StateScope, key: str, stream_id: str) -> asyncio.Lock:
|
||||
"""Get appropriate lock for scope and key."""
|
||||
if scope == StateScope.GLOBAL:
|
||||
lock_key = f"global:{key}"
|
||||
elif scope == StateScope.STREAM:
|
||||
lock_key = f"stream:{stream_id}:{key}"
|
||||
else:
|
||||
lock_key = f"exec:{key}"
|
||||
|
||||
if lock_key not in self._key_locks:
|
||||
self._key_locks[lock_key] = asyncio.Lock()
|
||||
|
||||
return self._key_locks[lock_key]
|
||||
|
||||
def _record_change(self, change: StateChange) -> None:
|
||||
"""Record a state change for auditing."""
|
||||
self._change_history.append(change)
|
||||
|
||||
# Trim history if too long
|
||||
if len(self._change_history) > self._max_history:
|
||||
self._change_history = self._change_history[-self._max_history:]
|
||||
|
||||
# === BULK OPERATIONS ===
|
||||
|
||||
async def read_all(
|
||||
self,
|
||||
execution_id: str,
|
||||
stream_id: str,
|
||||
isolation: IsolationLevel,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Read all visible state for an execution.
|
||||
|
||||
Returns merged state from all visible levels.
|
||||
"""
|
||||
result = {}
|
||||
|
||||
# Start with global (if visible)
|
||||
if isolation != IsolationLevel.ISOLATED:
|
||||
result.update(self._global_state)
|
||||
|
||||
# Add stream state (overwrites global)
|
||||
if stream_id in self._stream_state:
|
||||
result.update(self._stream_state[stream_id])
|
||||
|
||||
# Add execution state (overwrites all)
|
||||
if execution_id in self._execution_state:
|
||||
result.update(self._execution_state[execution_id])
|
||||
|
||||
return result
|
||||
|
||||
    async def write_batch(
        self,
        updates: dict[str, Any],
        execution_id: str,
        stream_id: str,
        isolation: IsolationLevel,
        scope: StateScope = StateScope.EXECUTION,
    ) -> None:
        """Write multiple values sequentially via write().

        NOTE(review): despite the original "atomically" wording, writes are
        applied one at a time with awaits in between — a concurrent reader may
        observe a partially-applied batch, and each key gets its own audit
        record and version bump.
        """
        for key, value in updates.items():
            await self.write(key, value, execution_id, stream_id, isolation, scope)
|
||||
|
||||
# === UTILITY ===
|
||||
|
||||
    def get_stats(self) -> dict:
        """Get state manager statistics.

        Note: "total_changes" counts only the retained audit history, which
        _record_change() trims to _max_history entries; "version" is the
        monotonic write counter and is not trimmed.
        """
        return {
            "global_keys": len(self._global_state),
            "stream_count": len(self._stream_state),
            "execution_count": len(self._execution_state),
            "total_changes": len(self._change_history),
            "version": self._version,
        }
|
||||
|
||||
def get_recent_changes(self, limit: int = 10) -> list[StateChange]:
|
||||
"""Get recent state changes."""
|
||||
return self._change_history[-limit:]
|
||||
|
||||
|
||||
class StreamMemory:
|
||||
"""
|
||||
Memory interface for a single execution.
|
||||
|
||||
Provides scoped access to shared state with proper isolation.
|
||||
Compatible with the existing SharedMemory interface where possible.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        manager: SharedStateManager,
        execution_id: str,
        stream_id: str,
        isolation: IsolationLevel,
    ):
        """Bind this memory view to one execution within one stream.

        Args:
            manager: Backing SharedStateManager holding the actual state
            execution_id: Execution this view belongs to
            stream_id: Stream that execution runs in
            isolation: Isolation level applied to every read/write
        """
        self._manager = manager
        self._execution_id = execution_id
        self._stream_id = stream_id
        self._isolation = isolation

        # Permission model (optional, for node-level scoping).
        # None means "unrestricted"; with_permissions() populates these.
        self._allowed_read: set[str] | None = None
        self._allowed_write: set[str] | None = None
|
||||
|
||||
def with_permissions(
|
||||
self,
|
||||
read_keys: list[str],
|
||||
write_keys: list[str],
|
||||
) -> "StreamMemory":
|
||||
"""
|
||||
Create a scoped view with read/write permissions.
|
||||
|
||||
Compatible with existing SharedMemory.with_permissions().
|
||||
"""
|
||||
scoped = StreamMemory(
|
||||
manager=self._manager,
|
||||
execution_id=self._execution_id,
|
||||
stream_id=self._stream_id,
|
||||
isolation=self._isolation,
|
||||
)
|
||||
scoped._allowed_read = set(read_keys)
|
||||
scoped._allowed_write = set(write_keys)
|
||||
return scoped
|
||||
|
||||
async def read(self, key: str) -> Any:
|
||||
"""Read a value from state."""
|
||||
# Check permissions
|
||||
if self._allowed_read is not None and key not in self._allowed_read:
|
||||
raise PermissionError(f"Not allowed to read key: {key}")
|
||||
|
||||
return await self._manager.read(
|
||||
key=key,
|
||||
execution_id=self._execution_id,
|
||||
stream_id=self._stream_id,
|
||||
isolation=self._isolation,
|
||||
)
|
||||
|
||||
async def write(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
scope: StateScope = StateScope.EXECUTION,
|
||||
) -> None:
|
||||
"""Write a value to state."""
|
||||
# Check permissions
|
||||
if self._allowed_write is not None and key not in self._allowed_write:
|
||||
raise PermissionError(f"Not allowed to write key: {key}")
|
||||
|
||||
await self._manager.write(
|
||||
key=key,
|
||||
value=value,
|
||||
execution_id=self._execution_id,
|
||||
stream_id=self._stream_id,
|
||||
isolation=self._isolation,
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
async def read_all(self) -> dict[str, Any]:
|
||||
"""Read all visible state."""
|
||||
all_state = await self._manager.read_all(
|
||||
execution_id=self._execution_id,
|
||||
stream_id=self._stream_id,
|
||||
isolation=self._isolation,
|
||||
)
|
||||
|
||||
# Filter by permissions if set
|
||||
if self._allowed_read is not None:
|
||||
return {k: v for k, v in all_state.items() if k in self._allowed_read}
|
||||
|
||||
return all_state
|
||||
|
||||
# === SYNC API (for backward compatibility with SharedMemory) ===
|
||||
|
||||
def read_sync(self, key: str) -> Any:
|
||||
"""
|
||||
Synchronous read (for compatibility with existing code).
|
||||
|
||||
Note: This runs the async operation in a new event loop
|
||||
or uses direct access if no loop is running.
|
||||
"""
|
||||
# Direct access for sync usage
|
||||
if self._allowed_read is not None and key not in self._allowed_read:
|
||||
raise PermissionError(f"Not allowed to read key: {key}")
|
||||
|
||||
# Check execution state
|
||||
exec_state = self._manager._execution_state.get(self._execution_id, {})
|
||||
if key in exec_state:
|
||||
return exec_state[key]
|
||||
|
||||
# Check stream/global if not isolated
|
||||
if self._isolation != IsolationLevel.ISOLATED:
|
||||
stream_state = self._manager._stream_state.get(self._stream_id, {})
|
||||
if key in stream_state:
|
||||
return stream_state[key]
|
||||
|
||||
if key in self._manager._global_state:
|
||||
return self._manager._global_state[key]
|
||||
|
||||
return None
|
||||
|
||||
def write_sync(self, key: str, value: Any) -> None:
|
||||
"""
|
||||
Synchronous write (for compatibility with existing code).
|
||||
|
||||
Always writes to execution scope for simplicity.
|
||||
"""
|
||||
if self._allowed_write is not None and key not in self._allowed_write:
|
||||
raise PermissionError(f"Not allowed to write key: {key}")
|
||||
|
||||
if self._execution_id not in self._manager._execution_state:
|
||||
self._manager._execution_state[self._execution_id] = {}
|
||||
|
||||
self._manager._execution_state[self._execution_id][key] = value
|
||||
self._manager._version += 1
|
||||
|
||||
def read_all_sync(self) -> dict[str, Any]:
|
||||
"""Synchronous read all."""
|
||||
result = {}
|
||||
|
||||
# Global (if visible)
|
||||
if self._isolation != IsolationLevel.ISOLATED:
|
||||
result.update(self._manager._global_state)
|
||||
if self._stream_id in self._manager._stream_state:
|
||||
result.update(self._manager._stream_state[self._stream_id])
|
||||
|
||||
# Execution
|
||||
if self._execution_id in self._manager._execution_state:
|
||||
result.update(self._manager._execution_state[self._execution_id])
|
||||
|
||||
# Filter by permissions
|
||||
if self._allowed_read is not None:
|
||||
result = {k: v for k, v in result.items() if k in self._allowed_read}
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,540 @@
|
||||
"""
|
||||
Stream Runtime - Thread-safe runtime for concurrent executions.
|
||||
|
||||
Unlike the original Runtime which has a single _current_run,
|
||||
StreamRuntime tracks runs by execution_id, allowing concurrent
|
||||
executions within the same stream without collision.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from framework.schemas.decision import Decision, Option, Outcome, DecisionType
|
||||
from framework.schemas.run import Run, RunStatus
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runtime.outcome_aggregator import OutcomeAggregator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamRuntime:
    """
    Thread-safe runtime for a single execution stream.

    Key differences from Runtime:
    - Tracks multiple runs concurrently via execution_id
    - Uses ConcurrentStorage for thread-safe persistence
    - Reports decisions to OutcomeAggregator for cross-stream evaluation

    Example:
        runtime = StreamRuntime(
            stream_id="webhook",
            storage=concurrent_storage,
            outcome_aggregator=aggregator,
        )

        run_id = runtime.start_run(
            execution_id="exec_123",
            goal_id="support-goal",
            goal_description="Handle support tickets",
        )

        decision_id = runtime.decide(
            execution_id="exec_123",
            intent="Classify ticket",
            options=[...],
            chosen="howto",
            reasoning="Question matches how-to pattern",
        )

        runtime.record_outcome(
            execution_id="exec_123",
            decision_id=decision_id,
            success=True,
            result={"category": "howto"},
        )

        runtime.end_run(
            execution_id="exec_123",
            success=True,
            narrative="Ticket resolved",
        )
    """

    def __init__(
        self,
        stream_id: str,
        storage: ConcurrentStorage,
        outcome_aggregator: "OutcomeAggregator | None" = None,
    ):
        """
        Initialize stream runtime.

        Args:
            stream_id: Unique identifier for this stream
            storage: Concurrent storage backend
            outcome_aggregator: Optional aggregator for cross-stream evaluation
        """
        self.stream_id = stream_id
        self._storage = storage
        self._outcome_aggregator = outcome_aggregator

        # Track runs by execution_id.
        # NOTE(review): the per-run and global locks below are allocated but
        # never acquired in this class yet — confirm whether callers use them.
        self._runs: dict[str, Run] = {}
        self._run_locks: dict[str, asyncio.Lock] = {}
        self._global_lock = asyncio.Lock()

        # Track current node per execution (for decision context)
        self._current_nodes: dict[str, str] = {}

        # Strong references to in-flight background save tasks. The event
        # loop keeps only weak references to tasks, so without this a pending
        # save could be garbage-collected before it completes.
        self._save_tasks: set[asyncio.Task] = set()

    # === RUN LIFECYCLE ===

    def start_run(
        self,
        execution_id: str,
        goal_id: str,
        goal_description: str = "",
        input_data: dict[str, Any] | None = None,
    ) -> str:
        """
        Start a new run for an execution.

        Args:
            execution_id: Unique execution identifier
            goal_id: The ID of the goal being pursued
            goal_description: Human-readable description of the goal
            input_data: Initial input to the run

        Returns:
            The run ID
        """
        # Timestamp plus a random suffix keeps IDs unique even for runs
        # started within the same second.
        run_id = f"run_{self.stream_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"

        run = Run(
            id=run_id,
            goal_id=goal_id,
            goal_description=goal_description,
            input_data=input_data or {},
        )

        self._runs[execution_id] = run
        self._run_locks[execution_id] = asyncio.Lock()
        self._current_nodes[execution_id] = "unknown"

        logger.debug(f"Started run {run_id} for execution {execution_id} in stream {self.stream_id}")
        return run_id

    def end_run(
        self,
        execution_id: str,
        success: bool,
        narrative: str = "",
        output_data: dict[str, Any] | None = None,
    ) -> None:
        """
        End a run for an execution.

        Args:
            execution_id: Execution identifier
            success: Whether the run achieved its goal
            narrative: Human-readable summary of what happened
            output_data: Final output of the run
        """
        run = self._runs.get(execution_id)
        if run is None:
            logger.warning(f"end_run called but no run for execution {execution_id}")
            return

        status = RunStatus.COMPLETED if success else RunStatus.FAILED
        run.output_data = output_data or {}
        run.complete(status, narrative)

        # Persist in the background when an event loop is running; keep a
        # strong reference so the task isn't garbage-collected mid-flight.
        # When called from sync code with no loop, save synchronously
        # instead of raising RuntimeError.
        try:
            task = asyncio.create_task(self._save_run(execution_id, run))
        except RuntimeError:
            asyncio.run(self._save_run(execution_id, run))
        else:
            self._save_tasks.add(task)
            task.add_done_callback(self._save_tasks.discard)

        logger.debug(f"Ended run {run.id} for execution {execution_id}: {status.value}")

    async def _save_run(self, execution_id: str, run: Run) -> None:
        """Save run to storage and clean up per-execution tracking state."""
        try:
            await self._storage.save_run(run)
        except Exception as e:
            # Persistence failure must not crash the stream; log and move on.
            logger.error(f"Failed to save run {run.id}: {e}")
        finally:
            # Clean up regardless of save outcome so memory doesn't grow.
            self._runs.pop(execution_id, None)
            self._run_locks.pop(execution_id, None)
            self._current_nodes.pop(execution_id, None)

    def set_node(self, execution_id: str, node_id: str) -> None:
        """Set the current node context for an execution."""
        self._current_nodes[execution_id] = node_id

    def get_run(self, execution_id: str) -> Run | None:
        """Get the current run for an execution."""
        return self._runs.get(execution_id)

    # === DECISION RECORDING ===

    def decide(
        self,
        execution_id: str,
        intent: str,
        options: list[dict[str, Any]],
        chosen: str,
        reasoning: str,
        node_id: str | None = None,
        decision_type: DecisionType = DecisionType.CUSTOM,
        constraints: list[str] | None = None,
        context: dict[str, Any] | None = None,
    ) -> str:
        """
        Record a decision for a specific execution.

        Thread-safe: Multiple executions can record decisions concurrently.

        Args:
            execution_id: Which execution is making this decision
            intent: What the agent was trying to accomplish
            options: List of options considered
            chosen: ID of the chosen option
            reasoning: Why the agent chose this option
            node_id: Which node made this decision
            decision_type: Type of decision
            constraints: Active constraints that influenced the decision
            context: Additional context available when deciding

        Returns:
            The decision ID, or empty string if no run in progress
        """
        run = self._runs.get(execution_id)
        if run is None:
            logger.warning(f"decide called but no run for execution {execution_id}: {intent}")
            return ""

        # Build Option objects from the plain-dict descriptions.
        option_objects = []
        for opt in options:
            option_objects.append(Option(
                id=opt["id"],
                description=opt.get("description", ""),
                action_type=opt.get("action_type", "unknown"),
                action_params=opt.get("action_params", {}),
                pros=opt.get("pros", []),
                cons=opt.get("cons", []),
                confidence=opt.get("confidence", 0.5),
            ))

        # Decision IDs are sequential within the run's decision list.
        decision_id = f"dec_{len(run.decisions)}"
        current_node = node_id or self._current_nodes.get(execution_id, "unknown")

        decision = Decision(
            id=decision_id,
            node_id=current_node,
            intent=intent,
            decision_type=decision_type,
            options=option_objects,
            chosen_option_id=chosen,
            reasoning=reasoning,
            active_constraints=constraints or [],
            input_context=context or {},
        )

        run.add_decision(decision)

        # Report to outcome aggregator if available
        if self._outcome_aggregator:
            self._outcome_aggregator.record_decision(
                stream_id=self.stream_id,
                execution_id=execution_id,
                decision=decision,
            )

        return decision_id

    def record_outcome(
        self,
        execution_id: str,
        decision_id: str,
        success: bool,
        result: Any = None,
        error: str | None = None,
        summary: str = "",
        state_changes: dict[str, Any] | None = None,
        tokens_used: int = 0,
        latency_ms: int = 0,
    ) -> None:
        """
        Record the outcome of a decision.

        Args:
            execution_id: Which execution
            decision_id: ID returned from decide()
            success: Whether the action succeeded
            result: The actual result/output
            error: Error message if failed
            summary: Human-readable summary of what happened
            state_changes: What state changed as a result
            tokens_used: LLM tokens consumed
            latency_ms: Time taken in milliseconds
        """
        run = self._runs.get(execution_id)
        if run is None:
            logger.warning(f"record_outcome called but no run for execution {execution_id}")
            return

        outcome = Outcome(
            success=success,
            result=result,
            error=error,
            summary=summary,
            state_changes=state_changes or {},
            tokens_used=tokens_used,
            latency_ms=latency_ms,
        )

        run.record_outcome(decision_id, outcome)

        # Report to outcome aggregator if available
        if self._outcome_aggregator:
            self._outcome_aggregator.record_outcome(
                stream_id=self.stream_id,
                execution_id=execution_id,
                decision_id=decision_id,
                outcome=outcome,
            )

    # === PROBLEM RECORDING ===

    def report_problem(
        self,
        execution_id: str,
        severity: str,
        description: str,
        decision_id: str | None = None,
        root_cause: str | None = None,
        suggested_fix: str | None = None,
    ) -> str:
        """
        Report a problem that occurred during an execution.

        Args:
            execution_id: Which execution
            severity: "critical", "warning", or "minor"
            description: What went wrong
            decision_id: Which decision caused this (if known)
            root_cause: Why it went wrong (if known)
            suggested_fix: What might fix it (if known)

        Returns:
            The problem ID, or empty string if no run in progress
        """
        run = self._runs.get(execution_id)
        if run is None:
            logger.warning(f"report_problem called but no run for execution {execution_id}: [{severity}] {description}")
            return ""

        return run.add_problem(
            severity=severity,
            description=description,
            decision_id=decision_id,
            root_cause=root_cause,
            suggested_fix=suggested_fix,
        )

    # === CONVENIENCE METHODS ===

    def quick_decision(
        self,
        execution_id: str,
        intent: str,
        action: str,
        reasoning: str,
        node_id: str | None = None,
    ) -> str:
        """
        Record a simple decision with a single action.

        Args:
            execution_id: Which execution
            intent: What the agent is trying to do
            action: What it's doing
            reasoning: Why
            node_id: Which node made this decision (optional)

        Returns:
            The decision ID
        """
        return self.decide(
            execution_id=execution_id,
            intent=intent,
            options=[{
                "id": "action",
                "description": action,
                "action_type": "execute",
            }],
            chosen="action",
            reasoning=reasoning,
            node_id=node_id,
        )

    # === STATS AND MONITORING ===

    def get_active_executions(self) -> list[str]:
        """Get list of active execution IDs."""
        return list(self._runs.keys())

    def get_stats(self) -> dict:
        """Get runtime statistics."""
        return {
            "stream_id": self.stream_id,
            "active_executions": len(self._runs),
            "execution_ids": list(self._runs.keys()),
        }
|
||||
|
||||
|
||||
class StreamRuntimeAdapter:
    """
    Adapter to make StreamRuntime compatible with existing Runtime interface.

    This allows StreamRuntime to be used with existing GraphExecutor code
    by providing the same API as Runtime but routing to a specific execution.

    Every method is a thin delegation to the wrapped StreamRuntime with
    this adapter's execution_id filled in.
    """

    def __init__(self, stream_runtime: StreamRuntime, execution_id: str):
        """
        Create adapter for a specific execution.

        Args:
            stream_runtime: The underlying stream runtime
            execution_id: Which execution this adapter is for
        """
        self._runtime = stream_runtime
        self._execution_id = execution_id
        # Local mirror of the last node set via set_node(); used as fallback
        # when decide()/quick_decision() callers omit node_id.
        self._current_node = "unknown"

    # Expose storage for compatibility
    @property
    def storage(self):
        # Existing Runtime callers access `runtime.storage` directly.
        return self._runtime._storage

    @property
    def current_run(self) -> Run | None:
        # The run tracked for this adapter's execution, or None if not started.
        return self._runtime.get_run(self._execution_id)

    def start_run(
        self,
        goal_id: str,
        goal_description: str = "",
        input_data: dict[str, Any] | None = None,
    ) -> str:
        """Start a run for this adapter's execution; returns the run ID."""
        return self._runtime.start_run(
            execution_id=self._execution_id,
            goal_id=goal_id,
            goal_description=goal_description,
            input_data=input_data,
        )

    def end_run(
        self,
        success: bool,
        narrative: str = "",
        output_data: dict[str, Any] | None = None,
    ) -> None:
        """End the run for this adapter's execution."""
        self._runtime.end_run(
            execution_id=self._execution_id,
            success=success,
            narrative=narrative,
            output_data=output_data,
        )

    def set_node(self, node_id: str) -> None:
        """Record the current node both locally and in the stream runtime."""
        self._current_node = node_id
        self._runtime.set_node(self._execution_id, node_id)

    def decide(
        self,
        intent: str,
        options: list[dict[str, Any]],
        chosen: str,
        reasoning: str,
        node_id: str | None = None,
        decision_type: DecisionType = DecisionType.CUSTOM,
        constraints: list[str] | None = None,
        context: dict[str, Any] | None = None,
    ) -> str:
        """Record a decision; falls back to the locally tracked node."""
        return self._runtime.decide(
            execution_id=self._execution_id,
            intent=intent,
            options=options,
            chosen=chosen,
            reasoning=reasoning,
            node_id=node_id or self._current_node,
            decision_type=decision_type,
            constraints=constraints,
            context=context,
        )

    def record_outcome(
        self,
        decision_id: str,
        success: bool,
        result: Any = None,
        error: str | None = None,
        summary: str = "",
        state_changes: dict[str, Any] | None = None,
        tokens_used: int = 0,
        latency_ms: int = 0,
    ) -> None:
        """Record the outcome of a previously recorded decision."""
        self._runtime.record_outcome(
            execution_id=self._execution_id,
            decision_id=decision_id,
            success=success,
            result=result,
            error=error,
            summary=summary,
            state_changes=state_changes,
            tokens_used=tokens_used,
            latency_ms=latency_ms,
        )

    def report_problem(
        self,
        severity: str,
        description: str,
        decision_id: str | None = None,
        root_cause: str | None = None,
        suggested_fix: str | None = None,
    ) -> str:
        """Report a problem for this adapter's execution; returns problem ID."""
        return self._runtime.report_problem(
            execution_id=self._execution_id,
            severity=severity,
            description=description,
            decision_id=decision_id,
            root_cause=root_cause,
            suggested_fix=suggested_fix,
        )

    def quick_decision(
        self,
        intent: str,
        action: str,
        reasoning: str,
        node_id: str | None = None,
    ) -> str:
        """Record a single-option decision; returns the decision ID."""
        return self._runtime.quick_decision(
            execution_id=self._execution_id,
            intent=intent,
            action=action,
            reasoning=reasoning,
            node_id=node_id or self._current_node,
        )
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for runtime components."""
|
||||
@@ -0,0 +1,631 @@
|
||||
"""
|
||||
Tests for AgentRuntime and multi-entry-point execution.
|
||||
|
||||
Tests:
|
||||
1. AgentRuntime creation and lifecycle
|
||||
2. Entry point registration
|
||||
3. Concurrent executions across streams
|
||||
4. SharedStateManager isolation levels
|
||||
5. OutcomeAggregator goal evaluation
|
||||
6. EventBus pub/sub
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from framework.graph import Goal
|
||||
from framework.graph.goal import SuccessCriterion, Constraint
|
||||
from framework.graph.edge import GraphSpec, EdgeSpec, EdgeCondition, AsyncEntryPointSpec
|
||||
from framework.graph.node import NodeSpec
|
||||
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
from framework.runtime.shared_state import SharedStateManager, IsolationLevel
|
||||
from framework.runtime.event_bus import EventBus, EventType, AgentEvent
|
||||
from framework.runtime.outcome_aggregator import OutcomeAggregator
|
||||
from framework.runtime.stream_runtime import StreamRuntime
|
||||
|
||||
|
||||
# === Test Fixtures ===
|
||||
|
||||
@pytest.fixture
def sample_goal():
    """A minimal Goal: one weighted success criterion, one hard constraint."""
    criteria = [
        SuccessCriterion(
            id="sc-1",
            description="Process all requests",
            metric="requests_processed",
            target="100%",
            weight=1.0,
        ),
    ]
    constraints = [
        Constraint(
            id="c-1",
            description="Must not exceed rate limits",
            constraint_type="hard",
            category="operational",
        ),
    ]
    return Goal(
        id="test-goal",
        name="Test Goal",
        description="A goal for testing multi-entry-point execution",
        success_criteria=criteria,
        constraints=constraints,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def sample_graph():
    """Create a sample graph with multiple entry points.

    Two independent entry nodes (webhook / API) converge on a single
    terminal node, and each entry node is exposed as an async entry point
    with shared state isolation.
    """
    nodes = [
        NodeSpec(
            id="process-webhook",
            name="Process Webhook",
            description="Process incoming webhook",
            node_type="llm_generate",
            input_keys=["webhook_data"],
            output_keys=["result"],
        ),
        NodeSpec(
            id="process-api",
            name="Process API Request",
            description="Process API request",
            node_type="llm_generate",
            input_keys=["request_data"],
            output_keys=["result"],
        ),
        NodeSpec(
            id="complete",
            name="Complete",
            description="Execution complete",
            node_type="terminal",
            input_keys=["result"],
            output_keys=["final_result"],
        ),
    ]

    # Both entry paths route to the shared terminal node on success.
    edges = [
        EdgeSpec(
            id="webhook-to-complete",
            source="process-webhook",
            target="complete",
            condition=EdgeCondition.ON_SUCCESS,
        ),
        EdgeSpec(
            id="api-to-complete",
            source="process-api",
            target="complete",
            condition=EdgeCondition.ON_SUCCESS,
        ),
    ]

    # Concurrent entry points used by the AgentRuntime tests below.
    async_entry_points = [
        AsyncEntryPointSpec(
            id="webhook",
            name="Webhook Handler",
            entry_node="process-webhook",
            trigger_type="webhook",
            isolation_level="shared",
        ),
        AsyncEntryPointSpec(
            id="api",
            name="API Handler",
            entry_node="process-api",
            trigger_type="api",
            isolation_level="shared",
        ),
    ]

    return GraphSpec(
        id="test-graph",
        goal_id="test-goal",
        version="1.0.0",
        # Traditional single entry point kept alongside the async ones.
        entry_node="process-webhook",
        entry_points={"start": "process-webhook"},
        async_entry_points=async_entry_points,
        terminal_nodes=["complete"],
        pause_nodes=[],
        nodes=nodes,
        edges=edges,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def temp_storage():
    """Yield a Path to a fresh temporary directory, removed after the test."""
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)
|
||||
|
||||
|
||||
# === SharedStateManager Tests ===
|
||||
|
||||
class TestSharedStateManager:
    """Tests for SharedStateManager."""

    def test_create_memory(self):
        """Test creating execution-scoped memory."""
        manager = SharedStateManager()
        memory = manager.create_memory(
            execution_id="exec-1",
            stream_id="webhook",
            isolation=IsolationLevel.SHARED,
        )
        assert memory is not None
        # Inspect private attrs to confirm the view is bound to the right
        # execution/stream.
        assert memory._execution_id == "exec-1"
        assert memory._stream_id == "webhook"

    @pytest.mark.asyncio
    async def test_isolated_state(self):
        """Test isolated state doesn't leak between executions."""
        manager = SharedStateManager()

        # Same stream, same key — isolation must keep the values apart.
        mem1 = manager.create_memory("exec-1", "stream-1", IsolationLevel.ISOLATED)
        mem2 = manager.create_memory("exec-2", "stream-1", IsolationLevel.ISOLATED)

        await mem1.write("key", "value1")
        await mem2.write("key", "value2")

        assert await mem1.read("key") == "value1"
        assert await mem2.read("key") == "value2"

    @pytest.mark.asyncio
    async def test_shared_state(self):
        """Test shared state is visible across executions."""
        manager = SharedStateManager()

        mem1 = manager.create_memory("exec-1", "stream-1", IsolationLevel.SHARED)
        mem2 = manager.create_memory("exec-2", "stream-1", IsolationLevel.SHARED)

        # Write to global scope
        await manager.write(
            key="global_key",
            value="global_value",
            execution_id="exec-1",
            stream_id="stream-1",
            isolation=IsolationLevel.SHARED,
            scope="global",
        )

        # Both should see it
        value1 = await manager.read("global_key", "exec-1", "stream-1", IsolationLevel.SHARED)
        value2 = await manager.read("global_key", "exec-2", "stream-1", IsolationLevel.SHARED)

        assert value1 == "global_value"
        assert value2 == "global_value"

    def test_cleanup_execution(self):
        """Test execution cleanup removes state."""
        manager = SharedStateManager()
        manager.create_memory("exec-1", "stream-1", IsolationLevel.ISOLATED)

        assert "exec-1" in manager._execution_state

        manager.cleanup_execution("exec-1")

        assert "exec-1" not in manager._execution_state
|
||||
|
||||
|
||||
# === EventBus Tests ===
|
||||
|
||||
class TestEventBus:
    """Tests for EventBus pub/sub."""

    @pytest.mark.asyncio
    async def test_publish_subscribe(self):
        """Test basic publish/subscribe."""
        bus = EventBus()
        received_events = []

        async def handler(event: AgentEvent):
            received_events.append(event)

        bus.subscribe(
            event_types=[EventType.EXECUTION_STARTED],
            handler=handler,
        )

        await bus.publish(AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="webhook",
            execution_id="exec-1",
            data={"test": "data"},
        ))

        # Allow handler to run
        # NOTE(review): fixed sleeps can be flaky on slow CI; consider
        # awaiting an event/condition instead.
        await asyncio.sleep(0.1)

        assert len(received_events) == 1
        assert received_events[0].type == EventType.EXECUTION_STARTED
        assert received_events[0].stream_id == "webhook"

    @pytest.mark.asyncio
    async def test_stream_filter(self):
        """Test filtering by stream ID."""
        bus = EventBus()
        received_events = []

        async def handler(event: AgentEvent):
            received_events.append(event)

        # Subscriber only wants events from the "webhook" stream.
        bus.subscribe(
            event_types=[EventType.EXECUTION_STARTED],
            handler=handler,
            filter_stream="webhook",
        )

        # Publish to webhook stream (should be received)
        await bus.publish(AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="webhook",
        ))

        # Publish to api stream (should NOT be received)
        await bus.publish(AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="api",
        ))

        await asyncio.sleep(0.1)

        assert len(received_events) == 1
        assert received_events[0].stream_id == "webhook"

    def test_unsubscribe(self):
        """Test unsubscribing from events."""
        bus = EventBus()

        async def handler(event: AgentEvent):
            pass

        sub_id = bus.subscribe(
            event_types=[EventType.EXECUTION_STARTED],
            handler=handler,
        )

        assert sub_id in bus._subscriptions

        result = bus.unsubscribe(sub_id)

        assert result is True
        assert sub_id not in bus._subscriptions

    @pytest.mark.asyncio
    async def test_wait_for(self):
        """Test waiting for a specific event."""
        bus = EventBus()

        # Start waiting in background
        async def wait_and_check():
            event = await bus.wait_for(
                event_type=EventType.EXECUTION_COMPLETED,
                timeout=1.0,
            )
            return event

        wait_task = asyncio.create_task(wait_and_check())

        # Publish the event after the waiter has had time to register.
        await asyncio.sleep(0.1)
        await bus.publish(AgentEvent(
            type=EventType.EXECUTION_COMPLETED,
            stream_id="webhook",
            execution_id="exec-1",
        ))

        event = await wait_task

        assert event is not None
        assert event.type == EventType.EXECUTION_COMPLETED
|
||||
|
||||
|
||||
# === OutcomeAggregator Tests ===
|
||||
|
||||
class TestOutcomeAggregator:
    """Tests for OutcomeAggregator."""

    def test_record_decision(self, sample_goal):
        """Test recording decisions."""
        aggregator = OutcomeAggregator(sample_goal)

        # Imported locally; only this test needs Decision/DecisionType.
        from framework.schemas.decision import Decision, DecisionType

        decision = Decision(
            id="dec-1",
            node_id="process-webhook",
            intent="Process incoming webhook",
            decision_type=DecisionType.PATH_CHOICE,
            options=[],
            chosen_option_id="opt-1",
            reasoning="Standard processing path",
        )

        aggregator.record_decision("webhook", "exec-1", decision)

        assert aggregator._total_decisions == 1
        assert len(aggregator._decisions) == 1

    @pytest.mark.asyncio
    async def test_evaluate_goal_progress(self, sample_goal):
        """Test goal progress evaluation."""
        aggregator = OutcomeAggregator(sample_goal)

        progress = await aggregator.evaluate_goal_progress()

        # Shape check only: no decisions recorded, so just assert the
        # report contains the expected sections.
        assert "overall_progress" in progress
        assert "criteria_status" in progress
        assert "constraint_violations" in progress
        assert "recommendation" in progress

    def test_record_constraint_violation(self, sample_goal):
        """Test recording constraint violations."""
        aggregator = OutcomeAggregator(sample_goal)

        aggregator.record_constraint_violation(
            constraint_id="c-1",
            description="Rate limit exceeded",
            violation_details="More than 100 requests/minute",
            stream_id="webhook",
            execution_id="exec-1",
        )

        assert len(aggregator._constraint_violations) == 1
        assert aggregator._constraint_violations[0].constraint_id == "c-1"
|
||||
|
||||
|
||||
# === AgentRuntime Tests ===
|
||||
|
||||
class TestAgentRuntime:
    """Tests for AgentRuntime orchestration."""

    def test_register_entry_point(self, sample_graph, sample_goal, temp_storage):
        """Test registering entry points."""
        runtime = AgentRuntime(
            graph=sample_graph,
            goal=sample_goal,
            storage_path=temp_storage,
        )

        entry_spec = EntryPointSpec(
            id="manual",
            name="Manual Trigger",
            entry_node="process-webhook",
            trigger_type="manual",
        )

        runtime.register_entry_point(entry_spec)

        assert "manual" in runtime._entry_points
        assert len(runtime.get_entry_points()) == 1

    def test_register_duplicate_entry_point_fails(self, sample_graph, sample_goal, temp_storage):
        """Test that duplicate entry point IDs fail."""
        runtime = AgentRuntime(
            graph=sample_graph,
            goal=sample_goal,
            storage_path=temp_storage,
        )

        entry_spec = EntryPointSpec(
            id="webhook",
            name="Webhook Handler",
            entry_node="process-webhook",
            trigger_type="webhook",
        )

        runtime.register_entry_point(entry_spec)

        # Second registration of the same ID must be rejected.
        with pytest.raises(ValueError, match="already registered"):
            runtime.register_entry_point(entry_spec)

    def test_register_invalid_entry_node_fails(self, sample_graph, sample_goal, temp_storage):
        """Test that invalid entry nodes fail."""
        runtime = AgentRuntime(
            graph=sample_graph,
            goal=sample_goal,
            storage_path=temp_storage,
        )

        entry_spec = EntryPointSpec(
            id="invalid",
            name="Invalid Entry",
            entry_node="nonexistent-node",
            trigger_type="manual",
        )

        with pytest.raises(ValueError, match="not found in graph"):
            runtime.register_entry_point(entry_spec)

    @pytest.mark.asyncio
    async def test_start_stop_lifecycle(self, sample_graph, sample_goal, temp_storage):
        """Test runtime start/stop lifecycle."""
        runtime = AgentRuntime(
            graph=sample_graph,
            goal=sample_goal,
            storage_path=temp_storage,
        )

        entry_spec = EntryPointSpec(
            id="webhook",
            name="Webhook Handler",
            entry_node="process-webhook",
            trigger_type="webhook",
        )

        runtime.register_entry_point(entry_spec)

        assert not runtime.is_running

        await runtime.start()

        # start() should spin up one stream per registered entry point.
        assert runtime.is_running
        assert "webhook" in runtime._streams

        await runtime.stop()

        # stop() tears the streams down again.
        assert not runtime.is_running
        assert len(runtime._streams) == 0

    @pytest.mark.asyncio
    async def test_trigger_requires_running(self, sample_graph, sample_goal, temp_storage):
        """Test that trigger fails if runtime not running."""
        runtime = AgentRuntime(
            graph=sample_graph,
            goal=sample_goal,
            storage_path=temp_storage,
        )

        entry_spec = EntryPointSpec(
            id="webhook",
            name="Webhook Handler",
            entry_node="process-webhook",
            trigger_type="webhook",
        )

        runtime.register_entry_point(entry_spec)

        # Registered but never started — trigger must refuse to run.
        with pytest.raises(RuntimeError, match="not running"):
            await runtime.trigger("webhook", {"test": "data"})
|
||||
|
||||
|
||||
# === GraphSpec Validation Tests ===
|
||||
|
||||
class TestGraphSpecValidation:
    """Tests for GraphSpec with async_entry_points."""

    @staticmethod
    def _one_valid_node():
        """A single well-formed node shared by the validation scenarios."""
        return [
            NodeSpec(
                id="valid-node",
                name="Valid Node",
                description="A valid node",
                node_type="llm_generate",
                input_keys=[],
                output_keys=[],
            ),
        ]

    @staticmethod
    def _graph_with(entry_point, nodes):
        """Build a GraphSpec wrapping a single async entry point over *nodes*."""
        return GraphSpec(
            id="test",
            goal_id="goal",
            entry_node="valid-node",
            async_entry_points=[entry_point],
            nodes=nodes,
            edges=[],
        )

    def test_has_async_entry_points(self, sample_graph):
        """has_async_entry_points() reflects whether any were declared."""
        assert sample_graph.has_async_entry_points() is True

        # A graph declared without async entry points reports False.
        bare = GraphSpec(
            id="simple",
            goal_id="goal",
            entry_node="start",
            nodes=[],
            edges=[],
        )
        assert bare.has_async_entry_points() is False

    def test_get_async_entry_point(self, sample_graph):
        """Lookup by ID returns the matching spec, or None when absent."""
        found = sample_graph.get_async_entry_point("webhook")
        assert found is not None
        assert found.id == "webhook"
        assert found.entry_node == "process-webhook"

        assert sample_graph.get_async_entry_point("nonexistent") is None

    def test_validate_async_entry_points(self):
        """validate() flags bad entry nodes, isolation levels, and triggers."""
        nodes = self._one_valid_node()

        # Entry node that does not exist in the graph.
        missing_node = self._graph_with(
            AsyncEntryPointSpec(
                id="invalid",
                name="Invalid",
                entry_node="nonexistent-node",
                trigger_type="webhook",
            ),
            nodes,
        )
        assert any("nonexistent-node" in e for e in missing_node.validate())

        # Isolation level outside the allowed set.
        bad_isolation = self._graph_with(
            AsyncEntryPointSpec(
                id="bad-isolation",
                name="Bad Isolation",
                entry_node="valid-node",
                trigger_type="webhook",
                isolation_level="invalid",
            ),
            nodes,
        )
        assert any("isolation_level" in e for e in bad_isolation.validate())

        # Trigger type outside the allowed set.
        bad_trigger = self._graph_with(
            AsyncEntryPointSpec(
                id="bad-trigger",
                name="Bad Trigger",
                entry_node="valid-node",
                trigger_type="invalid_trigger",
            ),
            nodes,
        )
        assert any("trigger_type" in e for e in bad_trigger.validate())
|
||||
|
||||
|
||||
# === Integration Tests ===
|
||||
|
||||
class TestCreateAgentRuntime:
    """Tests for the create_agent_runtime factory."""

    def test_create_with_entry_points(self, sample_graph, sample_goal, temp_storage):
        """The factory registers every entry point it is handed."""
        # trigger_type matches the id for both specs, so build them in one pass.
        specs = [
            EntryPointSpec(id=eid, name=label, entry_node=node, trigger_type=eid)
            for eid, label, node in (
                ("webhook", "Webhook", "process-webhook"),
                ("api", "API", "process-api"),
            )
        ]

        runtime = create_agent_runtime(
            graph=sample_graph,
            goal=sample_goal,
            storage_path=temp_storage,
            entry_points=specs,
        )

        assert len(runtime.get_entry_points()) == 2
        assert "webhook" in runtime._entry_points
        assert "api" in runtime._entry_points
|
||||
|
||||
|
||||
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
Concurrent Storage - Thread-safe storage backend with file locking.
|
||||
|
||||
Wraps FileStorage with:
|
||||
- Async file locking for atomic writes
|
||||
- Write batching for performance
|
||||
- Read caching for concurrent access
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.schemas.run import Run, RunSummary, RunStatus
|
||||
from framework.storage.backend import FileStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    """A cached value together with the moment it was stored."""

    value: Any        # the cached object (e.g. a Run or RunSummary)
    timestamp: float  # time.time() at insertion

    def is_expired(self, ttl: float) -> bool:
        """Return True once the entry is older than *ttl* seconds."""
        age = time.time() - self.timestamp
        return age > ttl
|
||||
|
||||
|
||||
class ConcurrentStorage:
    """
    Thread-safe storage backend with file locking and batch writes.

    Provides:
    - Async file locking to prevent concurrent write corruption
    - Write batching to reduce I/O overhead
    - Read caching for frequently accessed data
    - Compatible API with FileStorage

    Consistency note: save_run() populates the read cache immediately,
    while the batched disk write may land slightly later. A cached
    load_run() therefore sees the latest value, but load_run(use_cache=False)
    may briefly read an older on-disk state. Pass immediate=True to
    save_run() when the write must be durable before returning.

    Example:
        storage = ConcurrentStorage("/path/to/storage")
        await storage.start()  # Start batch writer

        # Async save with locking
        await storage.save_run(run)

        # Cached read
        run = await storage.load_run(run_id)

        await storage.stop()  # Stop batch writer
    """

    def __init__(
        self,
        base_path: str | Path,
        cache_ttl: float = 60.0,
        batch_interval: float = 0.1,
        max_batch_size: int = 100,
    ):
        """
        Initialize concurrent storage.

        Args:
            base_path: Base path for storage
            cache_ttl: Cache time-to-live in seconds
            batch_interval: Interval between batch flushes
            max_batch_size: Maximum items before forcing flush
        """
        self.base_path = Path(base_path)
        self._base_storage = FileStorage(base_path)

        # Caching: key -> CacheEntry; entries expire lazily on read.
        self._cache: dict[str, CacheEntry] = {}
        self._cache_ttl = cache_ttl

        # Batching: queue of (item_type, item) tuples drained by _batch_writer.
        self._write_queue: asyncio.Queue = asyncio.Queue()
        self._batch_interval = batch_interval
        self._max_batch_size = max_batch_size
        self._batch_task: asyncio.Task | None = None

        # Locking: one asyncio.Lock per logical file key, created on first use.
        # NOTE: the defaultdict grows with the number of distinct keys and is
        # never pruned; acceptable for typical run counts.
        self._file_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        self._global_lock = asyncio.Lock()  # reserved for whole-store operations

        # State
        self._running = False

    async def start(self) -> None:
        """Start the batch writer background task (idempotent)."""
        if self._running:
            return

        self._running = True
        self._batch_task = asyncio.create_task(self._batch_writer())
        logger.info(f"ConcurrentStorage started: {self.base_path}")

    async def stop(self) -> None:
        """Stop the batch writer and flush pending writes (idempotent)."""
        if not self._running:
            return

        self._running = False

        # Flush anything still sitting in the queue before tearing down.
        await self._flush_pending()

        # Cancel batch task; its CancelledError handler flushes any batch
        # it had already pulled off the queue.
        if self._batch_task:
            self._batch_task.cancel()
            try:
                await self._batch_task
            except asyncio.CancelledError:
                pass
            self._batch_task = None

        logger.info("ConcurrentStorage stopped")

    # === RUN OPERATIONS (Async, Thread-Safe) ===

    async def save_run(self, run: Run, immediate: bool = False) -> None:
        """
        Save a run to storage.

        Args:
            run: Run to save
            immediate: If True, save immediately (bypasses batching)
        """
        if immediate or not self._running:
            # No batch writer available (or caller wants durability): write now.
            await self._save_run_locked(run)
        else:
            await self._write_queue.put(("run", run))

        # Update cache eagerly so subsequent cached reads see this run even
        # before the batched write reaches disk.
        self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())

    async def _save_run_locked(self, run: Run) -> None:
        """Save a run while holding its per-run file lock."""
        lock_key = f"run:{run.id}"
        async with self._file_locks[lock_key]:
            # Run blocking file I/O in the default executor so the event
            # loop stays responsive.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._base_storage.save_run, run)

    async def load_run(self, run_id: str, use_cache: bool = True) -> Run | None:
        """
        Load a run from storage.

        Args:
            run_id: Run ID to load
            use_cache: Whether to use cached value if available

        Returns:
            Run object or None if not found
        """
        cache_key = f"run:{run_id}"

        # Check cache first; expired entries fall through to a disk read.
        if use_cache and cache_key in self._cache:
            entry = self._cache[cache_key]
            if not entry.is_expired(self._cache_ttl):
                return entry.value

        # Load from storage under the same lock writers hold, so we never
        # read a half-written file.
        lock_key = f"run:{run_id}"
        async with self._file_locks[lock_key]:
            loop = asyncio.get_running_loop()
            run = await loop.run_in_executor(
                None, self._base_storage.load_run, run_id
            )

        # Update cache
        if run:
            self._cache[cache_key] = CacheEntry(run, time.time())

        return run

    async def load_summary(self, run_id: str, use_cache: bool = True) -> RunSummary | None:
        """Load just the summary (faster than full run)."""
        cache_key = f"summary:{run_id}"

        # Check cache
        if use_cache and cache_key in self._cache:
            entry = self._cache[cache_key]
            if not entry.is_expired(self._cache_ttl):
                return entry.value

        # Read under the run's write lock for consistency with load_run.
        # NOTE(review): assumes summaries are persisted as part of save_run
        # under this same key — confirm against FileStorage.
        async with self._file_locks[f"run:{run_id}"]:
            loop = asyncio.get_running_loop()
            summary = await loop.run_in_executor(
                None, self._base_storage.load_summary, run_id
            )

        # Update cache
        if summary:
            self._cache[cache_key] = CacheEntry(summary, time.time())

        return summary

    async def delete_run(self, run_id: str) -> bool:
        """Delete a run from storage and drop its cached entries."""
        lock_key = f"run:{run_id}"
        async with self._file_locks[lock_key]:
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None, self._base_storage.delete_run, run_id
            )

        # Clear cache
        self._cache.pop(f"run:{run_id}", None)
        self._cache.pop(f"summary:{run_id}", None)

        return result

    # === QUERY OPERATIONS (Async, with Locking) ===

    async def get_runs_by_goal(self, goal_id: str) -> list[str]:
        """Get all run IDs for a goal."""
        async with self._file_locks[f"index:by_goal:{goal_id}"]:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None, self._base_storage.get_runs_by_goal, goal_id
            )

    async def get_runs_by_status(self, status: str | RunStatus) -> list[str]:
        """Get all run IDs with a status (enum values are normalized to str)."""
        if isinstance(status, RunStatus):
            status = status.value
        async with self._file_locks[f"index:by_status:{status}"]:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None, self._base_storage.get_runs_by_status, status
            )

    async def get_runs_by_node(self, node_id: str) -> list[str]:
        """Get all run IDs that executed a node."""
        async with self._file_locks[f"index:by_node:{node_id}"]:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None, self._base_storage.get_runs_by_node, node_id
            )

    async def list_all_runs(self) -> list[str]:
        """List all run IDs."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._base_storage.list_all_runs
        )

    async def list_all_goals(self) -> list[str]:
        """List all goal IDs that have runs."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._base_storage.list_all_goals
        )

    # === BATCH OPERATIONS ===

    async def _batch_writer(self) -> None:
        """Background task that batches writes for performance.

        Blocks up to batch_interval for the first item, then greedily
        drains the queue (up to max_batch_size) and flushes the batch.
        Survives per-item errors; flushes its local batch on cancellation.
        """
        batch: list[tuple[str, Any]] = []

        while self._running:
            try:
                # Collect items with timeout
                try:
                    item = await asyncio.wait_for(
                        self._write_queue.get(),
                        timeout=self._batch_interval,
                    )
                    batch.append(item)

                    # Keep collecting if more items available (up to max batch)
                    while len(batch) < self._max_batch_size:
                        try:
                            item = self._write_queue.get_nowait()
                            batch.append(item)
                        except asyncio.QueueEmpty:
                            break

                except asyncio.TimeoutError:
                    pass  # no new items this interval; flush whatever we have

                # Flush batch if we have items
                if batch:
                    await self._flush_batch(batch)
                    batch = []

            except asyncio.CancelledError:
                # Flush remaining before exit
                if batch:
                    await self._flush_batch(batch)
                raise
            except Exception as e:
                logger.error(f"Batch writer error: {e}")
                # Continue running despite errors

    async def _flush_batch(self, batch: list[tuple[str, Any]]) -> None:
        """Flush a batch of writes; per-item failures are logged, not raised."""
        if not batch:
            return

        logger.debug(f"Flushing batch of {len(batch)} items")

        for item_type, item in batch:
            try:
                if item_type == "run":
                    await self._save_run_locked(item)
            except Exception as e:
                logger.error(f"Failed to save {item_type}: {e}")

    async def _flush_pending(self) -> None:
        """Drain the write queue and flush everything still in it."""
        batch = []
        while True:
            try:
                item = self._write_queue.get_nowait()
                batch.append(item)
            except asyncio.QueueEmpty:
                break

        if batch:
            await self._flush_batch(batch)

    # === CACHE MANAGEMENT ===

    def clear_cache(self) -> None:
        """Clear all cached values."""
        self._cache.clear()

    def invalidate_cache(self, key: str) -> None:
        """Invalidate a specific cache entry (no-op if absent)."""
        self._cache.pop(key, None)

    def get_cache_stats(self) -> dict:
        """Get cache statistics (total / expired / valid entry counts)."""
        expired = sum(
            1 for entry in self._cache.values()
            if entry.is_expired(self._cache_ttl)
        )
        return {
            "total_entries": len(self._cache),
            "expired_entries": expired,
            "valid_entries": len(self._cache) - expired,
        }

    # === UTILITY ===

    async def get_stats(self) -> dict:
        """Get storage statistics (base storage stats + cache/queue state)."""
        loop = asyncio.get_running_loop()
        base_stats = await loop.run_in_executor(
            None, self._base_storage.get_stats
        )

        return {
            **base_stats,
            "cache": self.get_cache_stats(),
            "pending_writes": self._write_queue.qsize(),
            "running": self._running,
        }

    # === SYNC API (for backward compatibility) ===

    def save_run_sync(self, run: Run) -> None:
        """Synchronous save via the base storage.

        WARNING: bypasses the asyncio file locks and the write queue; do not
        mix with concurrent async writes for the same run.
        """
        self._base_storage.save_run(run)

    def load_run_sync(self, run_id: str) -> Run | None:
        """Synchronous load (uses base storage directly; no cache, no lock)."""
        return self._base_storage.load_run(run_id)
|
||||
@@ -0,0 +1,337 @@
|
||||
# Multi-Entry-Point Agent Architecture
|
||||
|
||||
## Executive Summary
|
||||
|
||||
This document explains the architectural improvements made to support agents with multiple asynchronous entry points, and why the initial patterns (single-entry execution, tools-as-shared-memory) were insufficient for production use cases.
|
||||
|
||||
---
|
||||
|
||||
## The Problem: Real-World Agents Need Multiple Entry Points
|
||||
|
||||
Consider a Tier-1 support agent that must:
|
||||
|
||||
1. **Listen for Zendesk webhooks** - New tickets arrive asynchronously
|
||||
2. **Handle API requests** - Users can query ticket status or submit follow-ups
|
||||
3. **Process timer events** - Escalation checks run every 5 minutes
|
||||
4. **Respond to internal events** - Other agents may delegate work
|
||||
|
||||
These are not sequential operations—they happen **concurrently and independently**. A webhook might fire while an API request is being processed. Two tickets might arrive simultaneously.
|
||||
|
||||
### Previous Architecture Limitations
|
||||
|
||||
The original framework had a fundamental constraint:
|
||||
|
||||
```python
|
||||
# In Runtime (core.py:58)
|
||||
class Runtime:
|
||||
def __init__(self, ...):
|
||||
self._current_run: Run | None = None # Only ONE run at a time
|
||||
```
|
||||
|
||||
This single `_current_run` meant:
|
||||
|
||||
- **No concurrent executions** - Processing one ticket blocked all others
|
||||
- **No multiple entry points** - Only `entry_node` could start execution
|
||||
- **State collision** - Concurrent attempts would overwrite each other's context
|
||||
|
||||
---
|
||||
|
||||
## Why Tools-as-Shared-Memory is an Anti-Pattern
|
||||
|
||||
A tempting workaround is using tools to manage shared state:
|
||||
|
||||
```python
|
||||
# Anti-pattern: Using tools for state management
|
||||
@tool
|
||||
def get_customer_context(customer_id: str) -> dict:
|
||||
"""Retrieve customer context from database."""
|
||||
return db.get_customer(customer_id)
|
||||
|
||||
@tool
|
||||
def update_ticket_status(ticket_id: str, status: str) -> bool:
|
||||
"""Update ticket status in database."""
|
||||
db.update_ticket(ticket_id, status)
|
||||
return True
|
||||
```
|
||||
|
||||
This seems to work—tools can read/write external storage, enabling "shared state" between executions. **But this approach has serious problems:**
|
||||
|
||||
### 1. Race Conditions Without Isolation Control
|
||||
|
||||
```
Execution A: get_customer_context("cust_123") → {tickets: 5}
Execution B: get_customer_context("cust_123") → {tickets: 5}
Execution A: writes tickets = 6 for "cust_123"
Execution B: writes tickets = 6 for "cust_123"   # Lost update — should be 7!
```
|
||||
|
||||
Tools have no concept of isolation levels. Every call goes directly to storage with no coordination. In high-concurrency scenarios, you get:
|
||||
|
||||
- **Lost updates** - Changes overwrite each other
|
||||
- **Dirty reads** - Reading partially-written state
|
||||
- **Phantom data** - State changes between reads in the same logical operation
|
||||
|
||||
### 2. No Transactional Boundaries
|
||||
|
||||
Tools execute independently with no transaction semantics:
|
||||
|
||||
```python
|
||||
# What if this fails halfway?
|
||||
@tool
|
||||
def process_refund(order_id: str) -> dict:
|
||||
mark_order_refunded(order_id) # ✓ Succeeds
|
||||
credit_customer_account(order_id) # ✗ Fails - network error
|
||||
send_confirmation_email(order_id) # Never runs
|
||||
# Now order is marked refunded but customer wasn't credited!
|
||||
```
|
||||
|
||||
With tools-as-state, there's no way to:
|
||||
|
||||
- Roll back partial changes
|
||||
- Ensure atomic operations
|
||||
- Coordinate multi-step state transitions
|
||||
|
||||
### 3. Invisible Dependencies Break Goal Evaluation
|
||||
|
||||
The goal-driven approach relies on tracking decisions and their outcomes:
|
||||
|
||||
```python
|
||||
# Decision: "Update customer tier based on purchase history"
|
||||
# Outcome: Success/Failure with observable state changes
|
||||
```
|
||||
|
||||
When state flows through tools, the framework loses visibility:
|
||||
|
||||
```python
|
||||
@tool
|
||||
def update_customer_tier(customer_id: str) -> str:
|
||||
# What state did this read? What did it change?
|
||||
# The framework has no idea—it just sees "tool returned 'gold'"
|
||||
history = get_purchase_history(customer_id) # Hidden read
|
||||
new_tier = calculate_tier(history) # Hidden logic
|
||||
save_tier(customer_id, new_tier) # Hidden write
|
||||
return new_tier
|
||||
```
|
||||
|
||||
This breaks:
|
||||
|
||||
- **Outcome aggregation** - Can't track what state changed across executions
|
||||
- **Constraint checking** - Can't verify invariants were maintained
|
||||
- **Goal progress evaluation** - Can't correlate actions to success criteria
|
||||
|
||||
### 4. No Execution Correlation
|
||||
|
||||
When multiple entry points trigger concurrently, you need to:
|
||||
|
||||
- Track which execution modified which state
|
||||
- Correlate related operations (e.g., webhook + follow-up API call for same ticket)
|
||||
- Debug issues by tracing execution flow
|
||||
|
||||
Tools provide none of this. Every tool call is independent with no execution context.
|
||||
|
||||
### 5. Testing Becomes Impossible
|
||||
|
||||
With tools-as-state:
|
||||
|
||||
- **Unit tests** can't isolate state—every test affects global storage
|
||||
- **Concurrent tests** interfere with each other
|
||||
- **Mocking** requires replacing actual database/API calls
|
||||
|
||||
Compare to proper state management:
|
||||
|
||||
```python
|
||||
# Isolated test - no external dependencies
|
||||
memory = manager.create_memory("test-exec", "test-stream", IsolationLevel.ISOLATED)
|
||||
await memory.write("key", "value")
|
||||
assert await memory.read("key") == "value"
|
||||
# Other tests unaffected
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## The Solution: Explicit State Management Architecture
|
||||
|
||||
The new architecture introduces explicit state management with proper isolation:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ AgentRuntime │
|
||||
│ - Manages agent lifecycle │
|
||||
│ - Coordinates ExecutionStreams │
|
||||
│ - Aggregates outcomes for goal evaluation │
|
||||
├─────────────────────────────────────────────────────┤
|
||||
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
|
||||
│ │ Stream A │ │ Stream B │ │ Stream C │ │
|
||||
│ │ (webhook) │ │ (api) │ │ (timer) │ │
|
||||
│ │ │ │ │ │ │ │
|
||||
│ │ Concurrent │ │ Concurrent │ │ Concurrent │ │
|
||||
│ │ Executions │ │ Executions │ │ Executions │ │
|
||||
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
|
||||
│ └────────────────┼────────────────┘ │
|
||||
│ ↓ │
|
||||
│ SharedStateManager │
|
||||
│ (Isolation Levels) │
|
||||
│ │
|
||||
│ OutcomeAggregator │
|
||||
│ (Cross-Stream Goals) │
|
||||
└─────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Key Components
|
||||
|
||||
#### 1. SharedStateManager with Isolation Levels
|
||||
|
||||
```python
|
||||
class IsolationLevel(Enum):
|
||||
ISOLATED = "isolated" # Private state per execution
|
||||
SHARED = "shared" # Visible across executions (eventual consistency)
|
||||
SYNCHRONIZED = "synchronized" # Shared with write locks (strong consistency)
|
||||
```
|
||||
|
||||
Each execution gets explicit control over state visibility:
|
||||
|
||||
```python
|
||||
# Execution-local state (safe from interference)
|
||||
await memory.write("scratch_data", value, scope=StateScope.EXECUTION)
|
||||
|
||||
# Stream-shared state (visible to all executions in this stream)
|
||||
await memory.write("stream_counter", count, scope=StateScope.STREAM)
|
||||
|
||||
# Global state (visible everywhere, use carefully)
|
||||
await memory.write("system_config", config, scope=StateScope.GLOBAL)
|
||||
```
|
||||
|
||||
#### 2. StreamRuntime with Execution Tracking
|
||||
|
||||
```python
|
||||
class StreamRuntime:
|
||||
def __init__(self, stream_id, storage, outcome_aggregator):
|
||||
# Track runs by execution_id, not single _current_run
|
||||
self._runs: dict[str, Run] = {}
|
||||
```
|
||||
|
||||
Now multiple executions can run concurrently without collision:
|
||||
|
||||
```python
|
||||
# Execution A
|
||||
runtime.start_run(execution_id="exec-A", goal_id="support")
|
||||
runtime.decide(execution_id="exec-A", intent="classify ticket", ...)
|
||||
|
||||
# Execution B (concurrent, no collision)
|
||||
runtime.start_run(execution_id="exec-B", goal_id="support")
|
||||
runtime.decide(execution_id="exec-B", intent="classify ticket", ...)
|
||||
```
|
||||
|
||||
#### 3. OutcomeAggregator for Cross-Stream Goals
|
||||
|
||||
```python
|
||||
class OutcomeAggregator:
|
||||
def record_decision(self, stream_id, execution_id, decision) -> None
|
||||
def record_outcome(self, stream_id, execution_id, decision_id, outcome) -> None
|
||||
async def evaluate_goal_progress(self) -> dict
|
||||
```
|
||||
|
||||
The framework now tracks all decisions across all streams, enabling:
|
||||
|
||||
- Unified goal progress evaluation
|
||||
- Constraint violation detection across executions
|
||||
- Success criteria tracking with proper attribution
|
||||
|
||||
#### 4. EventBus for Coordination
|
||||
|
||||
```python
|
||||
# Stream A publishes
|
||||
await bus.publish(AgentEvent(
|
||||
type=EventType.EXECUTION_COMPLETED,
|
||||
stream_id="webhook",
|
||||
execution_id="exec-123",
|
||||
data={"ticket_resolved": True},
|
||||
))
|
||||
|
||||
# Stream B subscribes
|
||||
bus.subscribe(
|
||||
event_types=[EventType.EXECUTION_COMPLETED],
|
||||
handler=on_ticket_resolved,
|
||||
filter_stream="webhook",
|
||||
)
|
||||
```
|
||||
|
||||
Streams can coordinate without tight coupling or shared mutable state.
|
||||
|
||||
---
|
||||
|
||||
## When Tools ARE Appropriate
|
||||
|
||||
Tools remain the right choice for:
|
||||
|
||||
1. **External system integration** - Calling APIs, databases, services
|
||||
2. **Side effects** - Sending emails, creating resources
|
||||
3. **Data retrieval** - Fetching information needed for decisions
|
||||
|
||||
The key distinction:
|
||||
|
||||
| Use Case | Correct Approach |
|
||||
| ------------------------------------ | --------------------------------- |
|
||||
| Coordinate between executions | SharedStateManager |
|
||||
| Track decision outcomes | StreamRuntime + OutcomeAggregator |
|
||||
| Call external API | Tool |
|
||||
| Persist business data | Tool (to external storage) |
|
||||
| Share scratch state during execution | StreamMemory |
|
||||
| Publish events to other streams | EventBus |
|
||||
|
||||
---
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### Before (Anti-Pattern)
|
||||
|
||||
```python
|
||||
# tools.py - State hidden in tools
|
||||
@tool
|
||||
def get_processing_count() -> int:
|
||||
return redis.get("processing_count") or 0
|
||||
|
||||
@tool
|
||||
def increment_processing_count() -> int:
|
||||
return redis.incr("processing_count")
|
||||
```
|
||||
|
||||
### After (Proper Architecture)
|
||||
|
||||
```python
|
||||
# In node execution
|
||||
async def execute(self, context, memory):
|
||||
# Read from managed state
|
||||
count = await memory.read("processing_count") or 0
|
||||
|
||||
# Update with proper isolation
|
||||
await memory.write(
|
||||
"processing_count",
|
||||
count + 1,
|
||||
scope=StateScope.STREAM, # Explicit scope
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
| Aspect | Tools-as-State | Explicit State Management |
|
||||
| ------------- | ---------------- | ------------------------- |
|
||||
| Concurrency | Race conditions | Isolation levels |
|
||||
| Transactions | None | Execution-scoped |
|
||||
| Visibility | Hidden | Observable |
|
||||
| Testing | Requires mocking | Isolated by design |
|
||||
| Goal tracking | Broken | Full attribution |
|
||||
| Debugging | Opaque | Traceable |
|
||||
|
||||
The multi-entry-point architecture doesn't just enable concurrent execution—it provides the foundation for **reliable, observable, goal-driven agents** that can operate safely in production environments.
|
||||
|
||||
---
|
||||
|
||||
## References
|
||||
|
||||
- [core/framework/runtime/agent_runtime.py](../../core/framework/runtime/agent_runtime.py) - AgentRuntime implementation
|
||||
- [core/framework/runtime/shared_state.py](../../core/framework/runtime/shared_state.py) - SharedStateManager
|
||||
- [core/framework/runtime/outcome_aggregator.py](../../core/framework/runtime/outcome_aggregator.py) - Cross-stream goal evaluation
|
||||
- [core/framework/runtime/tests/test_agent_runtime.py](../../core/framework/runtime/tests/test_agent_runtime.py) - Test examples
|
||||
Reference in New Issue
Block a user