Merge pull request #4720 from TimothyZhang7/feature/event-source
Feature/event source
This commit is contained in:
@@ -703,6 +703,17 @@ class EventLoopNode(NodeProtocol):
|
||||
fb_preview,
|
||||
)
|
||||
|
||||
# Publish judge verdict event
|
||||
judge_type = "custom" if self._judge is not None else "implicit"
|
||||
await self._publish_judge_verdict(
|
||||
stream_id,
|
||||
node_id,
|
||||
action=verdict.action,
|
||||
feedback=fb_preview,
|
||||
judge_type=judge_type,
|
||||
iteration=iteration,
|
||||
)
|
||||
|
||||
if verdict.action == "ACCEPT":
|
||||
# Check for missing output keys
|
||||
missing = self._get_missing_output_keys(
|
||||
@@ -1058,13 +1069,20 @@ class EventLoopNode(NodeProtocol):
|
||||
user_input_requested,
|
||||
)
|
||||
|
||||
# Execute tool calls — separate real tools from set_output
|
||||
# Execute tool calls — framework tools (set_output, ask_user)
|
||||
# run inline; real MCP tools run in parallel.
|
||||
real_tool_results: list[dict] = []
|
||||
limit_hit = False
|
||||
executed_in_batch = 0
|
||||
hard_limit = int(
|
||||
self._config.max_tool_calls_per_turn * (1 + self._config.tool_call_overflow_margin)
|
||||
)
|
||||
|
||||
# Phase 1: triage — handle framework tools immediately,
|
||||
# queue real tools for parallel execution.
|
||||
results_by_id: dict[str, ToolResult] = {}
|
||||
pending_real: list[ToolCallEvent] = []
|
||||
|
||||
for tc in tool_calls:
|
||||
tool_call_count += 1
|
||||
if tool_call_count > hard_limit:
|
||||
@@ -1072,11 +1090,9 @@ class EventLoopNode(NodeProtocol):
|
||||
break
|
||||
executed_in_batch += 1
|
||||
|
||||
# Publish tool call started
|
||||
await self._publish_tool_started(
|
||||
stream_id, node_id, tc.tool_use_id, tc.tool_name, tc.tool_input
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[%s] tool_call: %s(%s)",
|
||||
node_id,
|
||||
@@ -1107,6 +1123,7 @@ class EventLoopNode(NodeProtocol):
|
||||
key = tc.tool_input.get("key", "")
|
||||
await accumulator.set(key, value)
|
||||
outputs_set_this_turn.append(key)
|
||||
await self._publish_output_key_set(stream_id, node_id, key)
|
||||
logged_tool_calls.append(
|
||||
{
|
||||
"tool_use_id": tc.tool_use_id,
|
||||
@@ -1116,6 +1133,8 @@ class EventLoopNode(NodeProtocol):
|
||||
"is_error": result.is_error,
|
||||
}
|
||||
)
|
||||
results_by_id[tc.tool_use_id] = result
|
||||
|
||||
elif tc.tool_name == "ask_user":
|
||||
# --- Framework-level ask_user handling ---
|
||||
user_input_requested = True
|
||||
@@ -1124,10 +1143,10 @@ class EventLoopNode(NodeProtocol):
|
||||
content="Waiting for user input...",
|
||||
is_error=False,
|
||||
)
|
||||
results_by_id[tc.tool_use_id] = result
|
||||
|
||||
else:
|
||||
# --- Real tool execution ---
|
||||
# Guard: detect truncated tool arguments (_raw fallback
|
||||
# from litellm when json.loads fails on max_tokens hit).
|
||||
# --- Real tool: check for truncated args, else queue ---
|
||||
if "_raw" in tc.tool_input:
|
||||
result = ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
@@ -1143,9 +1162,36 @@ class EventLoopNode(NodeProtocol):
|
||||
node_id,
|
||||
tc.tool_name,
|
||||
)
|
||||
results_by_id[tc.tool_use_id] = result
|
||||
else:
|
||||
result = await self._execute_tool(tc)
|
||||
result = self._truncate_tool_result(result, tc.tool_name)
|
||||
pending_real.append(tc)
|
||||
|
||||
# Phase 2: execute real tools in parallel.
|
||||
if pending_real:
|
||||
raw_results = await asyncio.gather(
|
||||
*(self._execute_tool(tc) for tc in pending_real),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for tc, raw in zip(pending_real, raw_results, strict=True):
|
||||
if isinstance(raw, BaseException):
|
||||
result = ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=f"Tool '{tc.tool_name}' raised: {raw}",
|
||||
is_error=True,
|
||||
)
|
||||
else:
|
||||
result = raw
|
||||
results_by_id[tc.tool_use_id] = self._truncate_tool_result(result, tc.tool_name)
|
||||
|
||||
# Phase 3: record results into conversation in original order,
|
||||
# build logged/real lists, and publish completed events.
|
||||
for tc in tool_calls[:executed_in_batch]:
|
||||
result = results_by_id.get(tc.tool_use_id)
|
||||
if result is None:
|
||||
continue # shouldn't happen
|
||||
|
||||
# Build log entries for real tools
|
||||
if tc.tool_name not in ("set_output", "ask_user"):
|
||||
tool_entry = {
|
||||
"tool_use_id": tc.tool_use_id,
|
||||
"tool_name": tc.tool_name,
|
||||
@@ -1156,15 +1202,11 @@ class EventLoopNode(NodeProtocol):
|
||||
real_tool_results.append(tool_entry)
|
||||
logged_tool_calls.append(tool_entry)
|
||||
|
||||
# Record tool result in conversation (both real and set_output
|
||||
# go into the conversation for LLM context continuity)
|
||||
await conversation.add_tool_result(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=result.content,
|
||||
is_error=result.is_error,
|
||||
)
|
||||
|
||||
# Publish tool call completed
|
||||
await self._publish_tool_completed(
|
||||
stream_id,
|
||||
node_id,
|
||||
@@ -1617,10 +1659,27 @@ class EventLoopNode(NodeProtocol):
|
||||
return result
|
||||
|
||||
# load_data is the designated mechanism for reading spilled files.
|
||||
# The LLM controls chunk size via offset/limit — re-spilling its
|
||||
# result would create a circular loop.
|
||||
# Don't re-spill (circular), but DO truncate with a pagination hint.
|
||||
if tool_name == "load_data":
|
||||
return result
|
||||
preview_chars = max(limit - 300, limit // 2)
|
||||
preview = result.content[:preview_chars]
|
||||
truncated = (
|
||||
f"[load_data result: {len(result.content)} chars — "
|
||||
f"too large for context. Use offset and limit parameters "
|
||||
f"to read smaller chunks, e.g. "
|
||||
f"load_data(filename=..., offset=0, limit=50).]\n\n"
|
||||
f"Preview:\n{preview}…"
|
||||
)
|
||||
logger.info(
|
||||
"load_data result truncated: %d → %d chars (use offset/limit to paginate)",
|
||||
len(result.content),
|
||||
len(truncated),
|
||||
)
|
||||
return ToolResult(
|
||||
tool_use_id=result.tool_use_id,
|
||||
content=truncated,
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
# Determine a preview size — leave room for the metadata wrapper
|
||||
preview_chars = max(limit - 300, limit // 2)
|
||||
@@ -2121,3 +2180,35 @@ class EventLoopNode(NodeProtocol):
|
||||
result=result,
|
||||
is_error=is_error,
|
||||
)
|
||||
|
||||
async def _publish_judge_verdict(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
action: str,
|
||||
feedback: str = "",
|
||||
judge_type: str = "implicit",
|
||||
iteration: int = 0,
|
||||
) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_judge_verdict(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
action=action,
|
||||
feedback=feedback,
|
||||
judge_type=judge_type,
|
||||
iteration=iteration,
|
||||
)
|
||||
|
||||
async def _publish_output_key_set(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
key: str,
|
||||
) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_output_key_set(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
key=key,
|
||||
)
|
||||
|
||||
@@ -462,6 +462,13 @@ class GraphExecutor:
|
||||
if session_state and current_node_id != graph.entry_node:
|
||||
self.logger.info(f"🔄 Resuming from: {current_node_id}")
|
||||
|
||||
# Emit resume event
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_execution_resumed(
|
||||
stream_id=self._stream_id,
|
||||
node_id=current_node_id,
|
||||
)
|
||||
|
||||
# Start run
|
||||
_run_id = self.runtime.start_run(
|
||||
goal_id=goal.id,
|
||||
@@ -498,6 +505,14 @@ class GraphExecutor:
|
||||
if self._pause_requested.is_set():
|
||||
self.logger.info("⏸ Pause detected - stopping at node boundary")
|
||||
|
||||
# Emit pause event
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_execution_paused(
|
||||
stream_id=self._stream_id,
|
||||
node_id=current_node_id,
|
||||
reason="User requested pause (Ctrl+Z)",
|
||||
)
|
||||
|
||||
# Create session state for pause
|
||||
saved_memory = memory.read_all()
|
||||
pause_session_state: dict[str, Any] = {
|
||||
@@ -782,6 +797,17 @@ class GraphExecutor:
|
||||
self.logger.info(
|
||||
f" ↻ Retrying ({node_retry_counts[current_node_id]}/{max_retries})..."
|
||||
)
|
||||
|
||||
# Emit retry event
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_node_retry(
|
||||
stream_id=self._stream_id,
|
||||
node_id=current_node_id,
|
||||
retry_count=retry_count,
|
||||
max_retries=max_retries,
|
||||
error=result.error or "",
|
||||
)
|
||||
|
||||
_is_retry = True
|
||||
continue
|
||||
else:
|
||||
@@ -868,6 +894,15 @@ class GraphExecutor:
|
||||
# This must happen BEFORE determining next node, since pause nodes may have no edges
|
||||
if node_spec.id in graph.pause_nodes:
|
||||
self.logger.info("💾 Saving session state after pause node")
|
||||
|
||||
# Emit pause event
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_execution_paused(
|
||||
stream_id=self._stream_id,
|
||||
node_id=node_spec.id,
|
||||
reason="HITL pause node",
|
||||
)
|
||||
|
||||
saved_memory = memory.read_all()
|
||||
session_state_out = {
|
||||
"paused_at": node_spec.id,
|
||||
@@ -923,6 +958,16 @@ class GraphExecutor:
|
||||
if result.next_node:
|
||||
# Router explicitly set next node
|
||||
self.logger.info(f" → Router directing to: {result.next_node}")
|
||||
|
||||
# Emit edge traversed event for router-directed edge
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_edge_traversed(
|
||||
stream_id=self._stream_id,
|
||||
source_node=current_node_id,
|
||||
target_node=result.next_node,
|
||||
edge_condition="router",
|
||||
)
|
||||
|
||||
current_node_id = result.next_node
|
||||
self._write_progress(current_node_id, path, memory, node_visit_counts)
|
||||
else:
|
||||
@@ -946,6 +991,18 @@ class GraphExecutor:
|
||||
targets = [e.target for e in traversable_edges]
|
||||
fan_in_node = self._find_convergence_node(graph, targets)
|
||||
|
||||
# Emit edge traversed events for fan-out branches
|
||||
if self._event_bus:
|
||||
for edge in traversable_edges:
|
||||
await self._event_bus.emit_edge_traversed(
|
||||
stream_id=self._stream_id,
|
||||
source_node=current_node_id,
|
||||
target_node=edge.target,
|
||||
edge_condition=edge.condition.value
|
||||
if hasattr(edge.condition, "value")
|
||||
else str(edge.condition),
|
||||
)
|
||||
|
||||
# Execute branches in parallel
|
||||
(
|
||||
_branch_results,
|
||||
@@ -989,6 +1046,14 @@ class GraphExecutor:
|
||||
next_spec = graph.get_node(next_node)
|
||||
self.logger.info(f" → Next: {next_spec.name if next_spec else next_node}")
|
||||
|
||||
# Emit edge traversed event
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_edge_traversed(
|
||||
stream_id=self._stream_id,
|
||||
source_node=current_node_id,
|
||||
target_node=next_node,
|
||||
)
|
||||
|
||||
# CHECKPOINT: node_complete (after determining next node)
|
||||
if (
|
||||
checkpoint_store
|
||||
|
||||
@@ -8,7 +8,7 @@ while preserving the goal-driven approach.
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -39,6 +39,11 @@ class AgentRuntimeConfig:
|
||||
max_history: int = 1000
|
||||
execution_result_max: int = 1000
|
||||
execution_result_ttl_seconds: float | None = None
|
||||
# Webhook server config (only starts if webhook_routes is non-empty)
|
||||
webhook_host: str = "127.0.0.1"
|
||||
webhook_port: int = 8080
|
||||
webhook_routes: list[dict] = field(default_factory=list)
|
||||
# Each dict: {"source_id": str, "path": str, "methods": ["POST"], "secret": str|None}
|
||||
|
||||
|
||||
class AgentRuntime:
|
||||
@@ -150,6 +155,11 @@ class AgentRuntime:
|
||||
self._entry_points: dict[str, EntryPointSpec] = {}
|
||||
self._streams: dict[str, ExecutionStream] = {}
|
||||
|
||||
# Webhook server (created on start if webhook_routes configured)
|
||||
self._webhook_server: Any = None
|
||||
# Event-driven entry point subscriptions
|
||||
self._event_subscriptions: list[str] = []
|
||||
|
||||
# State
|
||||
self._running = False
|
||||
self._lock = asyncio.Lock()
|
||||
@@ -234,6 +244,63 @@ class AgentRuntime:
|
||||
await stream.start()
|
||||
self._streams[ep_id] = stream
|
||||
|
||||
# Start webhook server if routes are configured
|
||||
if self._config.webhook_routes:
|
||||
from framework.runtime.webhook_server import (
|
||||
WebhookRoute,
|
||||
WebhookServer,
|
||||
WebhookServerConfig,
|
||||
)
|
||||
|
||||
wh_config = WebhookServerConfig(
|
||||
host=self._config.webhook_host,
|
||||
port=self._config.webhook_port,
|
||||
)
|
||||
self._webhook_server = WebhookServer(self._event_bus, wh_config)
|
||||
|
||||
for rc in self._config.webhook_routes:
|
||||
route = WebhookRoute(
|
||||
source_id=rc["source_id"],
|
||||
path=rc["path"],
|
||||
methods=rc.get("methods", ["POST"]),
|
||||
secret=rc.get("secret"),
|
||||
)
|
||||
self._webhook_server.add_route(route)
|
||||
|
||||
await self._webhook_server.start()
|
||||
|
||||
# Subscribe event-driven entry points to EventBus
|
||||
from framework.runtime.event_bus import EventType as _ET
|
||||
|
||||
for ep_id, spec in self._entry_points.items():
|
||||
if spec.trigger_type != "event":
|
||||
continue
|
||||
|
||||
tc = spec.trigger_config
|
||||
event_types = [_ET(et) for et in tc.get("event_types", [])]
|
||||
if not event_types:
|
||||
logger.warning(
|
||||
f"Entry point '{ep_id}' has trigger_type='event' "
|
||||
"but no event_types in trigger_config"
|
||||
)
|
||||
continue
|
||||
|
||||
# Capture ep_id in closure
|
||||
def _make_handler(entry_point_id: str):
|
||||
async def _on_event(event):
|
||||
if self._running and entry_point_id in self._streams:
|
||||
await self.trigger(entry_point_id, {"event": event.to_dict()})
|
||||
|
||||
return _on_event
|
||||
|
||||
sub_id = self._event_bus.subscribe(
|
||||
event_types=event_types,
|
||||
handler=_make_handler(ep_id),
|
||||
filter_stream=tc.get("filter_stream"),
|
||||
filter_node=tc.get("filter_node"),
|
||||
)
|
||||
self._event_subscriptions.append(sub_id)
|
||||
|
||||
self._running = True
|
||||
logger.info(f"AgentRuntime started with {len(self._streams)} streams")
|
||||
|
||||
@@ -243,6 +310,16 @@ class AgentRuntime:
|
||||
return
|
||||
|
||||
async with self._lock:
|
||||
# Unsubscribe event-driven entry points
|
||||
for sub_id in self._event_subscriptions:
|
||||
self._event_bus.unsubscribe(sub_id)
|
||||
self._event_subscriptions.clear()
|
||||
|
||||
# Stop webhook server
|
||||
if self._webhook_server:
|
||||
await self._webhook_server.stop()
|
||||
self._webhook_server = None
|
||||
|
||||
# Stop all streams
|
||||
for stream in self._streams.values():
|
||||
await stream.stop()
|
||||
@@ -448,6 +525,11 @@ class AgentRuntime:
|
||||
"""Access the outcome aggregator."""
|
||||
return self._outcome_aggregator
|
||||
|
||||
@property
|
||||
def webhook_server(self) -> Any:
|
||||
"""Access the webhook server (None if no webhook entry points)."""
|
||||
return self._webhook_server
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if runtime is running."""
|
||||
|
||||
@@ -63,9 +63,22 @@ class EventType(StrEnum):
|
||||
NODE_INPUT_BLOCKED = "node_input_blocked"
|
||||
NODE_STALLED = "node_stalled"
|
||||
|
||||
# Judge decisions
|
||||
JUDGE_VERDICT = "judge_verdict"
|
||||
|
||||
# Output tracking
|
||||
OUTPUT_KEY_SET = "output_key_set"
|
||||
|
||||
# Retry / edge tracking
|
||||
NODE_RETRY = "node_retry"
|
||||
EDGE_TRAVERSED = "edge_traversed"
|
||||
|
||||
# Context management
|
||||
CONTEXT_COMPACTED = "context_compacted"
|
||||
|
||||
# External triggers
|
||||
WEBHOOK_RECEIVED = "webhook_received"
|
||||
|
||||
# Custom events
|
||||
CUSTOM = "custom"
|
||||
|
||||
@@ -636,6 +649,158 @@ class EventBus:
|
||||
)
|
||||
)
|
||||
|
||||
# === JUDGE / OUTPUT / RETRY / EDGE PUBLISHERS ===
|
||||
|
||||
async def emit_judge_verdict(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
action: str,
|
||||
feedback: str = "",
|
||||
judge_type: str = "implicit",
|
||||
iteration: int = 0,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit judge verdict event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.JUDGE_VERDICT,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"action": action,
|
||||
"feedback": feedback,
|
||||
"judge_type": judge_type,
|
||||
"iteration": iteration,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_output_key_set(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
key: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit output key set event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.OUTPUT_KEY_SET,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"key": key},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_retry(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
retry_count: int,
|
||||
max_retries: int,
|
||||
error: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node retry event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_RETRY,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"retry_count": retry_count,
|
||||
"max_retries": max_retries,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_edge_traversed(
|
||||
self,
|
||||
stream_id: str,
|
||||
source_node: str,
|
||||
target_node: str,
|
||||
edge_condition: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit edge traversed event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EDGE_TRAVERSED,
|
||||
stream_id=stream_id,
|
||||
node_id=source_node,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"source_node": source_node,
|
||||
"target_node": target_node,
|
||||
"edge_condition": edge_condition,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_execution_paused(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
reason: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit execution paused event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_PAUSED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"reason": reason},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_execution_resumed(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit execution resumed event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_RESUMED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_webhook_received(
|
||||
self,
|
||||
source_id: str,
|
||||
path: str,
|
||||
method: str,
|
||||
headers: dict[str, str],
|
||||
payload: dict[str, Any],
|
||||
query_params: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Emit webhook received event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.WEBHOOK_RECEIVED,
|
||||
stream_id=source_id,
|
||||
data={
|
||||
"path": path,
|
||||
"method": method,
|
||||
"headers": headers,
|
||||
"payload": payload,
|
||||
"query_params": query_params or {},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# === QUERY OPERATIONS ===
|
||||
|
||||
def get_history(
|
||||
|
||||
@@ -0,0 +1,717 @@
|
||||
"""
|
||||
Tests for WebhookServer and event-driven entry points.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac as hmac_mod
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig
|
||||
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
from framework.runtime.webhook_server import (
|
||||
WebhookRoute,
|
||||
WebhookServer,
|
||||
WebhookServerConfig,
|
||||
)
|
||||
|
||||
|
||||
def _make_server(event_bus: EventBus, routes: list[WebhookRoute] | None = None):
|
||||
"""Helper to create a WebhookServer with port=0 for OS-assigned port."""
|
||||
config = WebhookServerConfig(host="127.0.0.1", port=0)
|
||||
server = WebhookServer(event_bus, config)
|
||||
for route in routes or []:
|
||||
server.add_route(route)
|
||||
return server
|
||||
|
||||
|
||||
def _base_url(server: WebhookServer) -> str:
|
||||
"""Get the base URL for a running server."""
|
||||
return f"http://127.0.0.1:{server.port}"
|
||||
|
||||
|
||||
class TestWebhookServerLifecycle:
|
||||
"""Tests for server start/stop."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_stop(self):
|
||||
bus = EventBus()
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="test", path="/webhooks/test", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
|
||||
await server.start()
|
||||
assert server.is_running
|
||||
assert server.port is not None
|
||||
|
||||
await server.stop()
|
||||
assert not server.is_running
|
||||
assert server.port is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_routes_skips_start(self):
|
||||
bus = EventBus()
|
||||
server = _make_server(bus) # no routes
|
||||
|
||||
await server.start()
|
||||
assert not server.is_running
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_when_not_started(self):
|
||||
bus = EventBus()
|
||||
server = _make_server(bus)
|
||||
|
||||
# Should be a no-op, not raise
|
||||
await server.stop()
|
||||
assert not server.is_running
|
||||
|
||||
|
||||
class TestWebhookEventPublishing:
|
||||
"""Tests for HTTP request -> EventBus event publishing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_publishes_webhook_received(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="gh", path="/webhooks/github", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/github",
|
||||
json={"action": "opened", "number": 42},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
body = await resp.json()
|
||||
assert body["status"] == "accepted"
|
||||
|
||||
# Give event bus time to dispatch
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(received) == 1
|
||||
event = received[0]
|
||||
assert event.type == EventType.WEBHOOK_RECEIVED
|
||||
assert event.stream_id == "gh"
|
||||
assert event.data["path"] == "/webhooks/github"
|
||||
assert event.data["method"] == "POST"
|
||||
assert event.data["payload"] == {"action": "opened", "number": 42}
|
||||
assert isinstance(event.data["headers"], dict)
|
||||
assert event.data["query_params"] == {}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_params_included(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="hook", path="/webhooks/hook", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/hook?source=test&v=2",
|
||||
json={"data": "hello"},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["query_params"] == {"source": "test", "v": "2"}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_json_body(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="raw", path="/webhooks/raw", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/raw",
|
||||
data=b"plain text body",
|
||||
headers={"Content-Type": "text/plain"},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["payload"] == {"raw_body": "plain text body"}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_body(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="empty", path="/webhooks/empty", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{_base_url(server)}/webhooks/empty") as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["payload"] == {}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_routes(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="a", path="/webhooks/a", methods=["POST"]),
|
||||
WebhookRoute(source_id="b", path="/webhooks/b", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/a", json={"from": "a"}
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/b", json={"from": "b"}
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(received) == 2
|
||||
stream_ids = {e.stream_id for e in received}
|
||||
assert stream_ids == {"a", "b"}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_stream_subscription(self):
|
||||
"""Subscribers can filter by stream_id (source_id)."""
|
||||
bus = EventBus()
|
||||
a_events = []
|
||||
b_events = []
|
||||
|
||||
async def handle_a(event):
|
||||
a_events.append(event)
|
||||
|
||||
async def handle_b(event):
|
||||
b_events.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handle_a, filter_stream="a")
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handle_b, filter_stream="b")
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(source_id="a", path="/webhooks/a", methods=["POST"]),
|
||||
WebhookRoute(source_id="b", path="/webhooks/b", methods=["POST"]),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await session.post(f"{_base_url(server)}/webhooks/a", json={"x": 1})
|
||||
await session.post(f"{_base_url(server)}/webhooks/b", json={"x": 2})
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(a_events) == 1
|
||||
assert a_events[0].data["payload"] == {"x": 1}
|
||||
assert len(b_events) == 1
|
||||
assert b_events[0].data["payload"] == {"x": 2}
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
|
||||
class TestHMACVerification:
|
||||
"""Tests for HMAC-SHA256 signature verification."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_signature_accepted(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
secret = "test-secret-key"
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(
|
||||
source_id="secure",
|
||||
path="/webhooks/secure",
|
||||
methods=["POST"],
|
||||
secret=secret,
|
||||
),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
body = json.dumps({"event": "push"}).encode()
|
||||
sig = hmac_mod.new(secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/secure",
|
||||
data=body,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"X-Hub-Signature-256": f"sha256={sig}",
|
||||
},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
assert len(received) == 1
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_signature_rejected(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(
|
||||
source_id="secure",
|
||||
path="/webhooks/secure",
|
||||
methods=["POST"],
|
||||
secret="real-secret",
|
||||
),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/secure",
|
||||
json={"event": "push"},
|
||||
headers={"X-Hub-Signature-256": "sha256=invalidsignature"},
|
||||
) as resp:
|
||||
assert resp.status == 401
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
assert len(received) == 0 # No event published
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_signature_rejected(self):
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(
|
||||
source_id="secure",
|
||||
path="/webhooks/secure",
|
||||
methods=["POST"],
|
||||
secret="my-secret",
|
||||
),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# No X-Hub-Signature-256 header
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/secure",
|
||||
json={"event": "push"},
|
||||
) as resp:
|
||||
assert resp.status == 401
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
assert len(received) == 0
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_secret_skips_verification(self):
|
||||
"""Routes without a secret accept any request."""
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe([EventType.WEBHOOK_RECEIVED], handler)
|
||||
|
||||
server = _make_server(
|
||||
bus,
|
||||
[
|
||||
WebhookRoute(
|
||||
source_id="open",
|
||||
path="/webhooks/open",
|
||||
methods=["POST"],
|
||||
secret=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
await server.start()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{_base_url(server)}/webhooks/open",
|
||||
json={"data": "test"},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
assert len(received) == 1
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
|
||||
class TestEventDrivenEntryPoints:
|
||||
"""Tests for event-driven entry points wired through AgentRuntime."""
|
||||
|
||||
def _make_graph_and_goal(self):
|
||||
"""Minimal graph + goal for testing entry point triggering."""
|
||||
from framework.graph import Goal
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.goal import SuccessCriterion
|
||||
from framework.graph.node import NodeSpec
|
||||
|
||||
nodes = [
|
||||
NodeSpec(
|
||||
id="process-event",
|
||||
name="Process Event",
|
||||
description="Process incoming event",
|
||||
node_type="llm_generate",
|
||||
input_keys=["event"],
|
||||
output_keys=["result"],
|
||||
),
|
||||
]
|
||||
graph = GraphSpec(
|
||||
id="test-graph",
|
||||
goal_id="test-goal",
|
||||
version="1.0.0",
|
||||
entry_node="process-event",
|
||||
entry_points={"start": "process-event"},
|
||||
async_entry_points=[],
|
||||
terminal_nodes=[],
|
||||
pause_nodes=[],
|
||||
nodes=nodes,
|
||||
edges=[],
|
||||
)
|
||||
goal = Goal(
|
||||
id="test-goal",
|
||||
name="Test Goal",
|
||||
description="Test",
|
||||
success_criteria=[
|
||||
SuccessCriterion(
|
||||
id="sc-1",
|
||||
description="Done",
|
||||
metric="done",
|
||||
target="yes",
|
||||
weight=1.0,
|
||||
),
|
||||
],
|
||||
)
|
||||
return graph, goal
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_entry_point_subscribes_to_bus(self):
|
||||
"""Entry point with trigger_type='event' subscribes and triggers on matching events."""
|
||||
graph, goal = self._make_graph_and_goal()
|
||||
|
||||
config = AgentRuntimeConfig(
|
||||
webhook_host="127.0.0.1",
|
||||
webhook_port=0,
|
||||
webhook_routes=[
|
||||
{"source_id": "gh", "path": "/webhooks/github"},
|
||||
],
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
runtime = AgentRuntime(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
storage_path=Path(tmpdir),
|
||||
config=config,
|
||||
)
|
||||
|
||||
runtime.register_entry_point(
|
||||
EntryPointSpec(
|
||||
id="gh-handler",
|
||||
name="GitHub Handler",
|
||||
entry_node="process-event",
|
||||
trigger_type="event",
|
||||
trigger_config={
|
||||
"event_types": ["webhook_received"],
|
||||
"filter_stream": "gh",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
trigger_calls = []
|
||||
|
||||
async def mock_trigger(ep_id, data, **kwargs):
|
||||
trigger_calls.append((ep_id, data))
|
||||
|
||||
with patch.object(runtime, "trigger", side_effect=mock_trigger):
|
||||
await runtime.start()
|
||||
|
||||
try:
|
||||
assert runtime.webhook_server is not None
|
||||
assert runtime.webhook_server.is_running
|
||||
|
||||
port = runtime.webhook_server.port
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://127.0.0.1:{port}/webhooks/github",
|
||||
json={"action": "push", "ref": "main"},
|
||||
) as resp:
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(trigger_calls) == 1
|
||||
ep_id, data = trigger_calls[0]
|
||||
assert ep_id == "gh-handler"
|
||||
assert "event" in data
|
||||
assert data["event"]["type"] == "webhook_received"
|
||||
assert data["event"]["stream_id"] == "gh"
|
||||
assert data["event"]["data"]["payload"] == {
|
||||
"action": "push",
|
||||
"ref": "main",
|
||||
}
|
||||
finally:
|
||||
await runtime.stop()
|
||||
|
||||
assert runtime.webhook_server is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_entry_point_filter_stream(self):
|
||||
"""Entry point only triggers for matching stream_id (source_id)."""
|
||||
graph, goal = self._make_graph_and_goal()
|
||||
|
||||
config = AgentRuntimeConfig(
|
||||
webhook_routes=[
|
||||
{"source_id": "github", "path": "/webhooks/github"},
|
||||
{"source_id": "stripe", "path": "/webhooks/stripe"},
|
||||
],
|
||||
webhook_port=0,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
runtime = AgentRuntime(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
storage_path=Path(tmpdir),
|
||||
config=config,
|
||||
)
|
||||
|
||||
runtime.register_entry_point(
|
||||
EntryPointSpec(
|
||||
id="gh-only",
|
||||
name="GitHub Only",
|
||||
entry_node="process-event",
|
||||
trigger_type="event",
|
||||
trigger_config={
|
||||
"event_types": ["webhook_received"],
|
||||
"filter_stream": "github",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
trigger_calls = []
|
||||
|
||||
async def mock_trigger(ep_id, data, **kwargs):
|
||||
trigger_calls.append((ep_id, data))
|
||||
|
||||
with patch.object(runtime, "trigger", side_effect=mock_trigger):
|
||||
await runtime.start()
|
||||
|
||||
try:
|
||||
port = runtime.webhook_server.port
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# POST to stripe — should NOT trigger
|
||||
await session.post(
|
||||
f"http://127.0.0.1:{port}/webhooks/stripe",
|
||||
json={"type": "payment"},
|
||||
)
|
||||
# POST to github — should trigger
|
||||
await session.post(
|
||||
f"http://127.0.0.1:{port}/webhooks/github",
|
||||
json={"action": "opened"},
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(trigger_calls) == 1
|
||||
assert trigger_calls[0][0] == "gh-only"
|
||||
finally:
|
||||
await runtime.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_webhook_routes_skips_server(self):
|
||||
"""Runtime without webhook_routes does not start a webhook server."""
|
||||
graph, goal = self._make_graph_and_goal()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
runtime = AgentRuntime(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
storage_path=Path(tmpdir),
|
||||
)
|
||||
|
||||
runtime.register_entry_point(
|
||||
EntryPointSpec(
|
||||
id="manual",
|
||||
name="Manual",
|
||||
entry_node="process-event",
|
||||
trigger_type="manual",
|
||||
)
|
||||
)
|
||||
|
||||
await runtime.start()
|
||||
try:
|
||||
assert runtime.webhook_server is None
|
||||
finally:
|
||||
await runtime.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_entry_point_custom_event(self):
|
||||
"""Entry point can subscribe to CUSTOM events, not just webhooks."""
|
||||
graph, goal = self._make_graph_and_goal()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
runtime = AgentRuntime(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
storage_path=Path(tmpdir),
|
||||
)
|
||||
|
||||
runtime.register_entry_point(
|
||||
EntryPointSpec(
|
||||
id="custom-handler",
|
||||
name="Custom Handler",
|
||||
entry_node="process-event",
|
||||
trigger_type="event",
|
||||
trigger_config={
|
||||
"event_types": ["custom"],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
trigger_calls = []
|
||||
|
||||
async def mock_trigger(ep_id, data, **kwargs):
|
||||
trigger_calls.append((ep_id, data))
|
||||
|
||||
with patch.object(runtime, "trigger", side_effect=mock_trigger):
|
||||
await runtime.start()
|
||||
|
||||
try:
|
||||
await runtime.event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CUSTOM,
|
||||
stream_id="some-source",
|
||||
data={"key": "value"},
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(trigger_calls) == 1
|
||||
assert trigger_calls[0][0] == "custom-handler"
|
||||
assert trigger_calls[0][1]["event"]["type"] == "custom"
|
||||
assert trigger_calls[0][1]["event"]["data"]["key"] == "value"
|
||||
finally:
|
||||
await runtime.stop()
|
||||
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Webhook HTTP Server - Receives HTTP requests and publishes them as EventBus events.
|
||||
|
||||
Only starts if webhook-type entry points are registered. Uses aiohttp for
|
||||
a lightweight embedded HTTP server that runs within the existing asyncio loop.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebhookRoute:
|
||||
"""A registered webhook route derived from an EntryPointSpec."""
|
||||
|
||||
source_id: str
|
||||
path: str
|
||||
methods: list[str]
|
||||
secret: str | None = None # For HMAC-SHA256 signature verification
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebhookServerConfig:
|
||||
"""Configuration for the webhook HTTP server."""
|
||||
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 8080
|
||||
|
||||
|
||||
class WebhookServer:
|
||||
"""
|
||||
Embedded HTTP server that receives webhook requests and publishes
|
||||
them as WEBHOOK_RECEIVED events on the EventBus.
|
||||
|
||||
The server's only job is: receive HTTP -> publish AgentEvent.
|
||||
Subscribers decide what to do with the event.
|
||||
|
||||
Lifecycle:
|
||||
server = WebhookServer(event_bus, config)
|
||||
server.add_route(WebhookRoute(...))
|
||||
await server.start()
|
||||
# ... server running ...
|
||||
await server.stop()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_bus: EventBus,
|
||||
config: WebhookServerConfig | None = None,
|
||||
):
|
||||
self._event_bus = event_bus
|
||||
self._config = config or WebhookServerConfig()
|
||||
self._routes: dict[str, WebhookRoute] = {} # path -> route
|
||||
self._app: web.Application | None = None
|
||||
self._runner: web.AppRunner | None = None
|
||||
self._site: web.TCPSite | None = None
|
||||
|
||||
def add_route(self, route: WebhookRoute) -> None:
|
||||
"""Register a webhook route."""
|
||||
self._routes[route.path] = route
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the HTTP server. No-op if no routes registered."""
|
||||
if not self._routes:
|
||||
logger.debug("No webhook routes registered, skipping server start")
|
||||
return
|
||||
|
||||
self._app = web.Application()
|
||||
|
||||
for path, route in self._routes.items():
|
||||
for method in route.methods:
|
||||
self._app.router.add_route(method, path, self._handle_request)
|
||||
|
||||
self._runner = web.AppRunner(self._app)
|
||||
await self._runner.setup()
|
||||
self._site = web.TCPSite(
|
||||
self._runner,
|
||||
self._config.host,
|
||||
self._config.port,
|
||||
)
|
||||
await self._site.start()
|
||||
logger.info(
|
||||
f"Webhook server started on {self._config.host}:{self._config.port} "
|
||||
f"with {len(self._routes)} route(s)"
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the HTTP server gracefully."""
|
||||
if self._runner:
|
||||
await self._runner.cleanup()
|
||||
self._runner = None
|
||||
self._app = None
|
||||
self._site = None
|
||||
logger.info("Webhook server stopped")
|
||||
|
||||
async def _handle_request(self, request: web.Request) -> web.Response:
|
||||
"""Handle an incoming webhook request."""
|
||||
path = request.path
|
||||
route = self._routes.get(path)
|
||||
|
||||
if route is None:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
|
||||
# Read body
|
||||
try:
|
||||
body = await request.read()
|
||||
except Exception:
|
||||
return web.json_response(
|
||||
{"error": "Failed to read request body"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
# Verify HMAC signature if secret is configured
|
||||
if route.secret:
|
||||
if not self._verify_signature(request, body, route.secret):
|
||||
return web.json_response({"error": "Invalid signature"}, status=401)
|
||||
|
||||
# Parse body as JSON (fall back to raw text for non-JSON)
|
||||
try:
|
||||
payload = json.loads(body) if body else {}
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
payload = {"raw_body": body.decode("utf-8", errors="replace")}
|
||||
|
||||
# Publish event to bus
|
||||
await self._event_bus.emit_webhook_received(
|
||||
source_id=route.source_id,
|
||||
path=path,
|
||||
method=request.method,
|
||||
headers=dict(request.headers),
|
||||
payload=payload,
|
||||
query_params=dict(request.query),
|
||||
)
|
||||
|
||||
return web.json_response({"status": "accepted"}, status=202)
|
||||
|
||||
def _verify_signature(
|
||||
self,
|
||||
request: web.Request,
|
||||
body: bytes,
|
||||
secret: str,
|
||||
) -> bool:
|
||||
"""Verify HMAC-SHA256 signature from X-Hub-Signature-256 header."""
|
||||
signature_header = request.headers.get("X-Hub-Signature-256", "")
|
||||
if not signature_header.startswith("sha256="):
|
||||
return False
|
||||
|
||||
expected_sig = signature_header[7:] # strip "sha256="
|
||||
computed_sig = hmac.new(
|
||||
secret.encode("utf-8"),
|
||||
body,
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
return hmac.compare_digest(expected_sig, computed_sig)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the server is running."""
|
||||
return self._site is not None
|
||||
|
||||
@property
|
||||
def port(self) -> int | None:
|
||||
"""Return the actual listening port (useful when configured with port=0)."""
|
||||
if self._site and self._site._server and self._site._server.sockets:
|
||||
return self._site._server.sockets[0].getsockname()[1]
|
||||
return None
|
||||
@@ -351,6 +351,13 @@ class AdenTUI(App):
|
||||
EventType.STATE_CHANGED,
|
||||
EventType.NODE_INPUT_BLOCKED,
|
||||
EventType.CONTEXT_COMPACTED,
|
||||
EventType.NODE_INTERNAL_OUTPUT,
|
||||
EventType.JUDGE_VERDICT,
|
||||
EventType.OUTPUT_KEY_SET,
|
||||
EventType.NODE_RETRY,
|
||||
EventType.EDGE_TRAVERSED,
|
||||
EventType.EXECUTION_PAUSED,
|
||||
EventType.EXECUTION_RESUMED,
|
||||
]
|
||||
|
||||
_LOG_PANE_EVENTS = frozenset(_EVENT_TYPES) - {
|
||||
@@ -409,6 +416,31 @@ class AdenTUI(App):
|
||||
event.node_id or event.data.get("node_id", ""),
|
||||
)
|
||||
|
||||
# Track active node in chat_repl for mid-execution input
|
||||
if et == EventType.NODE_LOOP_STARTED:
|
||||
self.chat_repl.handle_node_started(event.node_id or "")
|
||||
elif et == EventType.NODE_LOOP_COMPLETED:
|
||||
self.chat_repl.handle_node_completed(event.node_id or "")
|
||||
|
||||
# Non-client-facing node output → chat repl
|
||||
if et == EventType.NODE_INTERNAL_OUTPUT:
|
||||
content = event.data.get("content", "")
|
||||
if content.strip():
|
||||
self.chat_repl.handle_internal_output(event.node_id or "", content)
|
||||
|
||||
# Execution paused/resumed → chat repl
|
||||
if et == EventType.EXECUTION_PAUSED:
|
||||
reason = event.data.get("reason", "")
|
||||
self.chat_repl.handle_execution_paused(event.node_id or "", reason)
|
||||
elif et == EventType.EXECUTION_RESUMED:
|
||||
self.chat_repl.handle_execution_resumed(event.node_id or "")
|
||||
|
||||
# Goal achieved / constraint violation → chat repl
|
||||
if et == EventType.GOAL_ACHIEVED:
|
||||
self.chat_repl.handle_goal_achieved(event.data)
|
||||
elif et == EventType.CONSTRAINT_VIOLATION:
|
||||
self.chat_repl.handle_constraint_violation(event.data)
|
||||
|
||||
# --- Graph view events ---
|
||||
if et in (
|
||||
EventType.EXECUTION_STARTED,
|
||||
@@ -445,6 +477,13 @@ class AdenTUI(App):
|
||||
started=False,
|
||||
)
|
||||
|
||||
# Edge traversal → graph view
|
||||
if et == EventType.EDGE_TRAVERSED:
|
||||
self.graph_view.handle_edge_traversed(
|
||||
event.data.get("source_node", ""),
|
||||
event.data.get("target_node", ""),
|
||||
)
|
||||
|
||||
# --- Status bar events ---
|
||||
if et == EventType.EXECUTION_STARTED:
|
||||
entry_node = event.data.get("entry_node") or (
|
||||
@@ -469,6 +508,20 @@ class AdenTUI(App):
|
||||
before = event.data.get("usage_before", "?")
|
||||
after = event.data.get("usage_after", "?")
|
||||
self.status_bar.set_node_detail(f"compacted: {before}% \u2192 {after}%")
|
||||
elif et == EventType.JUDGE_VERDICT:
|
||||
action = event.data.get("action", "?")
|
||||
self.status_bar.set_node_detail(f"judge: {action}")
|
||||
elif et == EventType.OUTPUT_KEY_SET:
|
||||
key = event.data.get("key", "?")
|
||||
self.status_bar.set_node_detail(f"set: {key}")
|
||||
elif et == EventType.NODE_RETRY:
|
||||
retry = event.data.get("retry_count", "?")
|
||||
max_r = event.data.get("max_retries", "?")
|
||||
self.status_bar.set_node_detail(f"retry {retry}/{max_r}")
|
||||
elif et == EventType.EXECUTION_PAUSED:
|
||||
self.status_bar.set_node_detail("paused")
|
||||
elif et == EventType.EXECUTION_RESUMED:
|
||||
self.status_bar.set_node_detail("resumed")
|
||||
|
||||
# --- Log pane events ---
|
||||
if et in self._LOG_PANE_EVENTS:
|
||||
|
||||
@@ -110,6 +110,7 @@ class ChatRepl(Vertical):
|
||||
self._waiting_for_input: bool = False
|
||||
self._input_node_id: str | None = None
|
||||
self._pending_ask_question: str = ""
|
||||
self._active_node_id: str | None = None # Currently executing node
|
||||
self._resume_session = resume_session
|
||||
self._resume_checkpoint = resume_checkpoint
|
||||
self._session_index: list[str] = [] # IDs from last listing
|
||||
@@ -813,7 +814,22 @@ class ChatRepl(Vertical):
|
||||
self._write_history(f"[bold red]Error delivering input:[/bold red] {e}")
|
||||
return
|
||||
|
||||
# Double-submit guard: reject input while an execution is in-flight
|
||||
# Mid-execution input: inject into the active node's conversation
|
||||
if self._current_exec_id is not None and self._active_node_id:
|
||||
self._write_history(f"[bold green]You:[/bold green] {user_input}")
|
||||
message.input.value = ""
|
||||
node_id = self._active_node_id
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self.runtime.inject_input(node_id, user_input),
|
||||
self._agent_loop,
|
||||
)
|
||||
await asyncio.wrap_future(future)
|
||||
except Exception as e:
|
||||
self._write_history(f"[bold red]Error delivering input:[/bold red] {e}")
|
||||
return
|
||||
|
||||
# Double-submit guard: no active node to inject into
|
||||
if self._current_exec_id is not None:
|
||||
self._write_history("[dim]Agent is still running — please wait.[/dim]")
|
||||
return
|
||||
@@ -941,6 +957,7 @@ class ChatRepl(Vertical):
|
||||
self._streaming_snapshot = ""
|
||||
self._waiting_for_input = False
|
||||
self._input_node_id = None
|
||||
self._active_node_id = None
|
||||
self._pending_ask_question = ""
|
||||
|
||||
# Re-enable input
|
||||
@@ -962,6 +979,7 @@ class ChatRepl(Vertical):
|
||||
self._waiting_for_input = False
|
||||
self._pending_ask_question = ""
|
||||
self._input_node_id = None
|
||||
self._active_node_id = None
|
||||
|
||||
# Re-enable input
|
||||
chat_input = self.query_one("#chat-input", ChatTextArea)
|
||||
@@ -999,3 +1017,36 @@ class ChatRepl(Vertical):
|
||||
chat_input.disabled = False
|
||||
chat_input.placeholder = "Type your response..."
|
||||
chat_input.focus()
|
||||
|
||||
def handle_node_started(self, node_id: str) -> None:
|
||||
"""Track which node is currently executing."""
|
||||
self._active_node_id = node_id
|
||||
|
||||
def handle_node_completed(self, node_id: str) -> None:
|
||||
"""Clear active node when it finishes."""
|
||||
if self._active_node_id == node_id:
|
||||
self._active_node_id = None
|
||||
|
||||
def handle_internal_output(self, node_id: str, content: str) -> None:
|
||||
"""Show output from non-client-facing nodes."""
|
||||
self._write_history(f"[dim cyan]⟨{node_id}⟩[/dim cyan] {content}")
|
||||
|
||||
def handle_execution_paused(self, node_id: str, reason: str) -> None:
|
||||
"""Show that execution has been paused."""
|
||||
msg = f"[bold yellow]⏸ Paused[/bold yellow] at [cyan]{node_id}[/cyan]"
|
||||
if reason:
|
||||
msg += f" [dim]({reason})[/dim]"
|
||||
self._write_history(msg)
|
||||
|
||||
def handle_execution_resumed(self, node_id: str) -> None:
|
||||
"""Show that execution has been resumed."""
|
||||
self._write_history(f"[bold green]▶ Resumed[/bold green] from [cyan]{node_id}[/cyan]")
|
||||
|
||||
def handle_goal_achieved(self, data: dict[str, Any]) -> None:
|
||||
"""Show goal achievement prominently."""
|
||||
self._write_history("[bold green]★ Goal achieved![/bold green]")
|
||||
|
||||
def handle_constraint_violation(self, data: dict[str, Any]) -> None:
|
||||
"""Show constraint violation as a warning."""
|
||||
desc = data.get("description", "Unknown constraint")
|
||||
self._write_history(f"[bold red]⚠ Constraint violation:[/bold red] {desc}")
|
||||
|
||||
@@ -192,3 +192,8 @@ class GraphOverview(Vertical):
|
||||
"""Highlight a stalled node."""
|
||||
self._node_status[node_id] = f"[red]stalled: {reason}[/red]"
|
||||
self._display_graph()
|
||||
|
||||
def handle_edge_traversed(self, source_node: str, target_node: str) -> None:
|
||||
"""Highlight an edge being traversed."""
|
||||
self._node_status[source_node] = f"[dim]→ {target_node}[/dim]"
|
||||
self._display_graph()
|
||||
|
||||
@@ -20,6 +20,7 @@ dependencies = [
|
||||
|
||||
[project.optional-dependencies]
|
||||
tui = ["textual>=0.75.0"]
|
||||
webhook = ["aiohttp>=3.9.0"]
|
||||
|
||||
[project.scripts]
|
||||
hive = "framework.cli:main"
|
||||
|
||||
@@ -143,6 +143,18 @@ class FakeEventBus:
|
||||
async def emit_node_loop_completed(self, **kwargs):
|
||||
self.events.append(("completed", kwargs))
|
||||
|
||||
async def emit_edge_traversed(self, **kwargs):
|
||||
self.events.append(("edge_traversed", kwargs))
|
||||
|
||||
async def emit_execution_paused(self, **kwargs):
|
||||
self.events.append(("execution_paused", kwargs))
|
||||
|
||||
async def emit_execution_resumed(self, **kwargs):
|
||||
self.events.append(("execution_resumed", kwargs))
|
||||
|
||||
async def emit_node_retry(self, **kwargs):
|
||||
self.events.append(("node_retry", kwargs))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_executor_emits_node_events():
|
||||
@@ -201,15 +213,19 @@ async def test_executor_emits_node_events():
|
||||
assert result.success is True
|
||||
assert result.path == ["n1", "n2"]
|
||||
|
||||
# Should have 4 events: started/completed for n1, then started/completed for n2
|
||||
assert len(event_bus.events) == 4
|
||||
# Should have 5 events: started/completed for n1, edge_traversed, then started/completed for n2
|
||||
assert len(event_bus.events) == 5
|
||||
assert event_bus.events[0] == ("started", {"stream_id": "test-stream", "node_id": "n1"})
|
||||
assert event_bus.events[1] == (
|
||||
"completed",
|
||||
{"stream_id": "test-stream", "node_id": "n1", "iterations": 1},
|
||||
)
|
||||
assert event_bus.events[2] == ("started", {"stream_id": "test-stream", "node_id": "n2"})
|
||||
assert event_bus.events[3] == (
|
||||
assert event_bus.events[2] == (
|
||||
"edge_traversed",
|
||||
{"stream_id": "test-stream", "source_node": "n1", "target_node": "n2"},
|
||||
)
|
||||
assert event_bus.events[3] == ("started", {"stream_id": "test-stream", "node_id": "n2"})
|
||||
assert event_bus.events[4] == (
|
||||
"completed",
|
||||
{"stream_id": "test-stream", "node_id": "n2", "iterations": 1},
|
||||
)
|
||||
|
||||
@@ -783,6 +783,9 @@ dependencies = [
|
||||
tui = [
|
||||
{ name = "textual" },
|
||||
]
|
||||
webhook = [
|
||||
{ name = "aiohttp" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
@@ -792,6 +795,7 @@ dev = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiohttp", marker = "extra == 'webhook'", specifier = ">=3.9.0" },
|
||||
{ name = "anthropic", specifier = ">=0.40.0" },
|
||||
{ name = "fastmcp", specifier = ">=2.0.0" },
|
||||
{ name = "httpx", specifier = ">=0.27.0" },
|
||||
@@ -805,7 +809,7 @@ requires-dist = [
|
||||
{ name = "textual", marker = "extra == 'tui'", specifier = ">=0.75.0" },
|
||||
{ name = "tools", editable = "tools" },
|
||||
]
|
||||
provides-extras = ["tui"]
|
||||
provides-extras = ["tui", "webhook"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
|
||||
Reference in New Issue
Block a user