# Source file: hive/core/framework/runtime/agent_runtime.py
# Snapshot: 2026-02-12 19:52:15 -08:00 (603 lines, 20 KiB, Python)

"""
Agent Runtime - Top-level orchestrator for multi-entry-point agents.
Manages agent lifecycle and coordinates multiple execution streams
while preserving the goal-driven approach.
"""
import asyncio
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any
from framework.graph.checkpoint_config import CheckpointConfig
from framework.graph.executor import ExecutionResult
from framework.runtime.event_bus import EventBus
from framework.runtime.execution_stream import EntryPointSpec, ExecutionStream
from framework.runtime.outcome_aggregator import OutcomeAggregator
from framework.runtime.shared_state import SharedStateManager
from framework.storage.concurrent import ConcurrentStorage
from framework.storage.session_store import SessionStore
if TYPE_CHECKING:
from framework.graph.edge import GraphSpec
from framework.graph.goal import Goal
from framework.llm.provider import LLMProvider, Tool
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class AgentRuntimeConfig:
"""Configuration for AgentRuntime."""
max_concurrent_executions: int = 100
cache_ttl: float = 60.0
batch_interval: float = 0.1
max_history: int = 1000
execution_result_max: int = 1000
execution_result_ttl_seconds: float | None = None
# Webhook server config (only starts if webhook_routes is non-empty)
webhook_host: str = "127.0.0.1"
webhook_port: int = 8080
webhook_routes: list[dict] = field(default_factory=list)
# Each dict: {"source_id": str, "path": str, "methods": ["POST"], "secret": str|None}
class AgentRuntime:
    """
    Top-level runtime that manages agent lifecycle and concurrent executions.

    Responsibilities:
    - Register and manage multiple entry points
    - Coordinate execution streams (one ExecutionStream per entry point)
    - Manage shared state across streams
    - Aggregate decisions/outcomes for goal evaluation
    - Handle lifecycle events (start, stop)

    Example:
        # Create runtime
        runtime = AgentRuntime(
            graph=support_agent_graph,
            goal=support_agent_goal,
            storage_path=Path("./storage"),
            llm=llm_provider,
        )

        # Register entry points
        runtime.register_entry_point(EntryPointSpec(
            id="webhook",
            name="Zendesk Webhook",
            entry_node="process-webhook",
            trigger_type="webhook",
            isolation_level="shared",
        ))
        runtime.register_entry_point(EntryPointSpec(
            id="api",
            name="API Handler",
            entry_node="process-request",
            trigger_type="api",
            isolation_level="shared",
        ))

        # Start runtime
        await runtime.start()

        # Trigger executions (non-blocking)
        exec_1 = await runtime.trigger("webhook", {"ticket_id": "123"})
        exec_2 = await runtime.trigger("api", {"query": "help"})

        # Check goal progress
        progress = await runtime.get_goal_progress()
        print(f"Progress: {progress['overall_progress']:.1%}")

        # Stop runtime
        await runtime.stop()
    """

    def __init__(
        self,
        graph: "GraphSpec",
        goal: "Goal",
        storage_path: str | Path,
        llm: "LLMProvider | None" = None,
        tools: list["Tool"] | None = None,
        tool_executor: Callable | None = None,
        config: AgentRuntimeConfig | None = None,
        runtime_log_store: Any = None,
        checkpoint_config: CheckpointConfig | None = None,
    ):
        """
        Initialize agent runtime.

        Nothing is started here; storage, streams, the webhook server and
        event subscriptions are only brought up by ``start()``.

        Args:
            graph: Graph specification for this agent
            goal: Goal driving execution
            storage_path: Path for persistent storage
            llm: LLM provider for nodes
            tools: Available tools (``None`` is treated as an empty list)
            tool_executor: Function to execute tools
            config: Optional runtime configuration; defaults to
                ``AgentRuntimeConfig()``
            runtime_log_store: Optional RuntimeLogStore for per-execution logging
            checkpoint_config: Optional checkpoint configuration for resumable sessions
        """
        self.graph = graph
        self.goal = goal
        self._config = config or AgentRuntimeConfig()
        self._runtime_log_store = runtime_log_store
        self._checkpoint_config = checkpoint_config

        # Initialize storage
        storage_path_obj = Path(storage_path) if isinstance(storage_path, str) else storage_path
        self._storage = ConcurrentStorage(
            base_path=storage_path_obj,
            cache_ttl=self._config.cache_ttl,
            batch_interval=self._config.batch_interval,
        )

        # Initialize SessionStore for unified sessions (always enabled)
        self._session_store = SessionStore(storage_path_obj)

        # Initialize shared components
        self._state_manager = SharedStateManager()
        self._event_bus = EventBus(max_history=self._config.max_history)
        self._outcome_aggregator = OutcomeAggregator(goal, self._event_bus)

        # LLM and tools
        self._llm = llm
        self._tools = tools or []
        self._tool_executor = tool_executor

        # Entry points and streams (streams are keyed by entry point ID)
        self._entry_points: dict[str, EntryPointSpec] = {}
        self._streams: dict[str, ExecutionStream] = {}

        # Webhook server (created on start if webhook_routes configured)
        self._webhook_server: Any = None

        # Subscription IDs for event-driven entry points, kept so stop()
        # can unsubscribe them
        self._event_subscriptions: list[str] = []

        # State; the lock serializes start() and stop()
        self._running = False
        self._lock = asyncio.Lock()

        # Optional greeting shown to user on TUI load (set by AgentRunner)
        self.intro_message: str = ""

    def register_entry_point(self, spec: EntryPointSpec) -> None:
        """
        Register a named entry point for the agent.

        Args:
            spec: Entry point specification

        Raises:
            ValueError: If entry point ID already registered, or if
                ``spec.entry_node`` is not a node of the graph
            RuntimeError: If runtime is already running
        """
        if self._running:
            raise RuntimeError("Cannot register entry points while runtime is running")

        if spec.id in self._entry_points:
            raise ValueError(f"Entry point '{spec.id}' already registered")

        # Validate entry node exists in graph
        if self.graph.get_node(spec.entry_node) is None:
            raise ValueError(f"Entry node '{spec.entry_node}' not found in graph")

        self._entry_points[spec.id] = spec
        logger.info(f"Registered entry point: {spec.id} -> {spec.entry_node}")

    def unregister_entry_point(self, entry_point_id: str) -> bool:
        """
        Unregister an entry point.

        Args:
            entry_point_id: Entry point to remove

        Returns:
            True if removed, False if not found

        Raises:
            RuntimeError: If runtime is running
        """
        if self._running:
            raise RuntimeError("Cannot unregister entry points while runtime is running")

        if entry_point_id in self._entry_points:
            del self._entry_points[entry_point_id]
            return True
        return False

    async def start(self) -> None:
        """Start the agent runtime and all registered entry points.

        Order of operations: storage, one stream per entry point, the
        webhook server (only if routes are configured), then event-bus
        subscriptions for event-driven entry points. Calling ``start()``
        on a running runtime is a no-op.
        """
        # Fast path: already running, no need to touch the lock.
        if self._running:
            return

        async with self._lock:
            # Re-check under the lock. A concurrent start() may have passed
            # the fast-path check above and completed while this call was
            # waiting for the lock; without this guard storage would be
            # started twice and every stream rebuilt.
            if self._running:
                return

            await self._storage.start()
            await self._start_streams()
            await self._start_webhook_server()
            self._subscribe_event_entry_points()

            self._running = True
            logger.info(f"AgentRuntime started with {len(self._streams)} streams")

    async def _start_streams(self) -> None:
        """Create and start one ExecutionStream per registered entry point."""
        for ep_id, spec in self._entry_points.items():
            stream = ExecutionStream(
                stream_id=ep_id,
                entry_spec=spec,
                graph=self.graph,
                goal=self.goal,
                state_manager=self._state_manager,
                storage=self._storage,
                outcome_aggregator=self._outcome_aggregator,
                event_bus=self._event_bus,
                llm=self._llm,
                tools=self._tools,
                tool_executor=self._tool_executor,
                result_retention_max=self._config.execution_result_max,
                result_retention_ttl_seconds=self._config.execution_result_ttl_seconds,
                runtime_log_store=self._runtime_log_store,
                session_store=self._session_store,
                checkpoint_config=self._checkpoint_config,
            )
            await stream.start()
            # Only published once started, so lookups never see a
            # half-initialized stream.
            self._streams[ep_id] = stream

    async def _start_webhook_server(self) -> None:
        """Create and start the webhook server if any routes are configured."""
        if not self._config.webhook_routes:
            return

        # Imported lazily so the webhook module is only loaded when used.
        from framework.runtime.webhook_server import (
            WebhookRoute,
            WebhookServer,
            WebhookServerConfig,
        )

        wh_config = WebhookServerConfig(
            host=self._config.webhook_host,
            port=self._config.webhook_port,
        )
        self._webhook_server = WebhookServer(self._event_bus, wh_config)

        for rc in self._config.webhook_routes:
            route = WebhookRoute(
                source_id=rc["source_id"],
                path=rc["path"],
                methods=rc.get("methods", ["POST"]),
                secret=rc.get("secret"),
            )
            self._webhook_server.add_route(route)

        await self._webhook_server.start()

    def _subscribe_event_entry_points(self) -> None:
        """Subscribe every ``trigger_type == "event"`` entry point to the EventBus."""
        event_specs = [
            (ep_id, spec)
            for ep_id, spec in self._entry_points.items()
            if spec.trigger_type == "event"
        ]
        if not event_specs:
            return

        from framework.runtime.event_bus import EventType

        for ep_id, spec in event_specs:
            tc = spec.trigger_config
            # EventType(...) raises ValueError for an unknown event type name.
            event_types = [EventType(et) for et in tc.get("event_types", [])]
            if not event_types:
                logger.warning(
                    f"Entry point '{ep_id}' has trigger_type='event' "
                    "but no event_types in trigger_config"
                )
                continue

            sub_id = self._event_bus.subscribe(
                event_types=event_types,
                handler=self._make_event_handler(ep_id),
                filter_stream=tc.get("filter_stream"),
                filter_node=tc.get("filter_node"),
            )
            self._event_subscriptions.append(sub_id)

    def _make_event_handler(self, entry_point_id: str):
        """Build an EventBus handler that triggers ``entry_point_id`` on each event.

        A factory is used so each handler captures its own entry point ID
        rather than a loop variable.
        """

        async def _on_event(event):
            # Events that arrive before start() completes or after stop()
            # are dropped silently.
            if self._running and entry_point_id in self._streams:
                await self.trigger(entry_point_id, {"event": event.to_dict()})

        return _on_event

    async def stop(self) -> None:
        """Stop the agent runtime and all streams.

        Teardown runs in reverse dependency order: event subscriptions,
        webhook server, streams, then storage. Calling ``stop()`` on a
        runtime that is not running is a no-op.
        """
        # Fast path: nothing to stop.
        if not self._running:
            return

        async with self._lock:
            # Re-check under the lock so a concurrent stop() that finished
            # while this call was waiting does not stop storage twice.
            if not self._running:
                return

            # Unsubscribe event-driven entry points
            for sub_id in self._event_subscriptions:
                self._event_bus.unsubscribe(sub_id)
            self._event_subscriptions.clear()

            # Stop webhook server
            if self._webhook_server:
                await self._webhook_server.stop()
                self._webhook_server = None

            # Stop all streams
            for stream in self._streams.values():
                await stream.stop()
            self._streams.clear()

            # Stop storage
            await self._storage.stop()

            self._running = False
            logger.info("AgentRuntime stopped")

    async def trigger(
        self,
        entry_point_id: str,
        input_data: dict[str, Any],
        correlation_id: str | None = None,
        session_state: dict[str, Any] | None = None,
    ) -> str:
        """
        Trigger execution at a specific entry point.

        Returns the execution ID produced by the stream's ``execute()``;
        use ``trigger_and_wait()`` to wait for the result.

        Args:
            entry_point_id: Which entry point to trigger
            input_data: Input data for the execution
            correlation_id: Optional ID to correlate related executions
            session_state: Optional session state to resume from (with paused_at, memory)

        Returns:
            Execution ID for tracking

        Raises:
            ValueError: If entry point not found
            RuntimeError: If runtime not running
        """
        if not self._running:
            raise RuntimeError("AgentRuntime is not running")

        stream = self._streams.get(entry_point_id)
        if stream is None:
            raise ValueError(f"Entry point '{entry_point_id}' not found")

        return await stream.execute(input_data, correlation_id, session_state)

    async def trigger_and_wait(
        self,
        entry_point_id: str,
        input_data: dict[str, Any],
        timeout: float | None = None,
        session_state: dict[str, Any] | None = None,
    ) -> ExecutionResult | None:
        """
        Trigger execution and wait for completion.

        Args:
            entry_point_id: Which entry point to trigger
            input_data: Input data for the execution
            timeout: Maximum time to wait (seconds)
            session_state: Optional session state to resume from (with paused_at, memory)

        Returns:
            ExecutionResult or None if timeout

        Raises:
            ValueError: If entry point not found
            RuntimeError: If runtime not running
        """
        exec_id = await self.trigger(entry_point_id, input_data, session_state=session_state)

        # Looked up again after the await: the stream may be gone if the
        # runtime was stopped in the meantime.
        stream = self._streams.get(entry_point_id)
        if stream is None:
            raise ValueError(f"Entry point '{entry_point_id}' not found")

        return await stream.wait_for_completion(exec_id, timeout)

    async def inject_input(self, node_id: str, content: str) -> bool:
        """Inject user input into a running client-facing node.

        Offers the input to each active stream in turn and stops at the
        first one that accepts it. Used by the TUI ChatRepl to deliver
        user responses during client-facing node execution.

        Args:
            node_id: The node currently waiting for input
            content: The user's input text

        Returns:
            True if input was delivered, False if no matching node found
        """
        for stream in self._streams.values():
            if await stream.inject_input(node_id, content):
                return True
        return False

    async def get_goal_progress(self) -> dict[str, Any]:
        """
        Evaluate goal progress across all streams.

        Delegates to the outcome aggregator.

        Returns:
            Progress report including overall progress, criteria status,
            constraint violations, and metrics.
        """
        return await self._outcome_aggregator.evaluate_goal_progress()

    async def cancel_execution(
        self,
        entry_point_id: str,
        execution_id: str,
    ) -> bool:
        """
        Cancel a running execution.

        Args:
            entry_point_id: Stream containing the execution
            execution_id: Execution to cancel

        Returns:
            True if cancelled, False if not found (including when no
            stream exists for ``entry_point_id``)
        """
        stream = self._streams.get(entry_point_id)
        if stream is None:
            return False
        return await stream.cancel_execution(execution_id)

    # === QUERY OPERATIONS ===

    def get_entry_points(self) -> list[EntryPointSpec]:
        """Get all registered entry points (as a new list)."""
        return list(self._entry_points.values())

    def get_stream(self, entry_point_id: str) -> ExecutionStream | None:
        """Get a specific execution stream, or None if there is none for that ID."""
        return self._streams.get(entry_point_id)

    def get_execution_result(
        self,
        entry_point_id: str,
        execution_id: str,
    ) -> ExecutionResult | None:
        """Get result of a completed execution.

        Returns None when no stream exists for ``entry_point_id``;
        otherwise returns whatever the stream's ``get_result()`` returns.
        """
        stream = self._streams.get(entry_point_id)
        if stream:
            return stream.get_result(execution_id)
        return None

    # === EVENT SUBSCRIPTIONS ===

    def subscribe_to_events(
        self,
        event_types: list,
        handler: Callable,
        filter_stream: str | None = None,
    ) -> str:
        """
        Subscribe to agent events.

        Subscriptions made here are not tracked by the runtime, so
        ``stop()`` does not remove them; callers unsubscribe themselves.

        Args:
            event_types: Types of events to receive
            handler: Async function to call when event occurs
            filter_stream: Only receive events from this stream

        Returns:
            Subscription ID (use to unsubscribe)
        """
        return self._event_bus.subscribe(
            event_types=event_types,
            handler=handler,
            filter_stream=filter_stream,
        )

    def unsubscribe_from_events(self, subscription_id: str) -> bool:
        """Unsubscribe from events; returns the event bus's result."""
        return self._event_bus.unsubscribe(subscription_id)

    # === STATS AND MONITORING ===

    def get_stats(self) -> dict:
        """Get comprehensive runtime statistics.

        Returns:
            Dict with keys ``running``, ``entry_points`` (count),
            ``streams`` (per-stream stats keyed by entry point ID),
            ``goal_id``, ``outcome_aggregator``, ``event_bus`` and
            ``state_manager``.
        """
        stream_stats = {ep_id: stream.get_stats() for ep_id, stream in self._streams.items()}

        return {
            "running": self._running,
            "entry_points": len(self._entry_points),
            "streams": stream_stats,
            "goal_id": self.goal.id,
            "outcome_aggregator": self._outcome_aggregator.get_stats(),
            "event_bus": self._event_bus.get_stats(),
            "state_manager": self._state_manager.get_stats(),
        }

    # === PROPERTIES ===

    @property
    def state_manager(self) -> SharedStateManager:
        """Access the shared state manager."""
        return self._state_manager

    @property
    def event_bus(self) -> EventBus:
        """Access the event bus."""
        return self._event_bus

    @property
    def outcome_aggregator(self) -> OutcomeAggregator:
        """Access the outcome aggregator."""
        return self._outcome_aggregator

    @property
    def webhook_server(self) -> Any:
        """Access the webhook server.

        None until ``start()`` creates one, which only happens when
        ``config.webhook_routes`` is non-empty; reset to None by ``stop()``.
        """
        return self._webhook_server

    @property
    def is_running(self) -> bool:
        """Check if runtime is running."""
        return self._running
# === CONVENIENCE FACTORY ===
def create_agent_runtime(
    graph: "GraphSpec",
    goal: "Goal",
    storage_path: str | Path,
    entry_points: list[EntryPointSpec],
    llm: "LLMProvider | None" = None,
    tools: list["Tool"] | None = None,
    tool_executor: Callable | None = None,
    config: AgentRuntimeConfig | None = None,
    runtime_log_store: Any = None,
    enable_logging: bool = True,
    checkpoint_config: CheckpointConfig | None = None,
) -> AgentRuntime:
    """
    Build an AgentRuntime and register the given entry points on it.

    One-call alternative to constructing AgentRuntime and calling
    ``register_entry_point`` for each spec. The returned runtime has not
    been started.

    Args:
        graph: Graph specification
        goal: Goal driving execution
        storage_path: Path for persistent storage
        entry_points: Entry point specifications, registered in list order
        llm: LLM provider
        tools: Available tools
        tool_executor: Tool executor function
        config: Runtime configuration
        runtime_log_store: Optional RuntimeLogStore for per-execution logging.
            When None and ``enable_logging`` is true, one is created under
            ``<storage_path>/runtime_logs``. A store passed explicitly is
            always forwarded, regardless of ``enable_logging``.
        enable_logging: Whether to auto-create a runtime log store when
            none is supplied (default: True).
        checkpoint_config: Optional checkpoint configuration for resumable
            sessions; forwarded unchanged to AgentRuntime.

    Returns:
        Configured AgentRuntime (not yet started)

    Raises:
        ValueError: Propagated from ``register_entry_point`` for a duplicate
            entry point ID or an entry node missing from the graph.
    """
    log_store = runtime_log_store

    # Only build a default log store when the caller neither supplied one
    # nor opted out of logging.
    if log_store is None and enable_logging:
        from framework.runtime.runtime_log_store import RuntimeLogStore

        base_dir = storage_path
        if isinstance(base_dir, str):
            base_dir = Path(base_dir)
        log_store = RuntimeLogStore(base_dir / "runtime_logs")

    agent_runtime = AgentRuntime(
        graph=graph,
        goal=goal,
        storage_path=storage_path,
        llm=llm,
        tools=tools,
        tool_executor=tool_executor,
        config=config,
        runtime_log_store=log_store,
        checkpoint_config=checkpoint_config,
    )

    for entry_spec in entry_points:
        agent_runtime.register_entry_point(entry_spec)

    return agent_runtime