Compare commits

...

23 Commits

Author SHA1 Message Date
bryan d49e858d32 lint update 2026-02-02 11:12:09 -08:00
bryan d7afa5dcf2 wp-12 2026-02-02 10:41:12 -08:00
Timothy 22e816bf86 chore: update gitignore 2026-02-02 10:30:03 -08:00
Timothy @aden 3240616808 Merge pull request #3250 from adenhq/feat/validation-client-facing
(micro-fix): added graph validation for client-facing nodes [WP-10]
2026-02-02 10:02:38 -08:00
Timothy @aden b9f83d4d61 Merge pull request #3244 from TimothyZhang7/feature/aden-sync-by-provider
Feature/aden sync by provider
2026-02-02 09:39:00 -08:00
Timothy @aden 9c16826ad3 Merge pull request #3137 from adenhq/feat/clientIO-gateway
implemented clientIO gateway [WP-9]
2026-02-02 07:29:03 -08:00
Timothy df4d0ad3fd feat: aden provider credential store by provider 2026-02-01 20:34:21 -08:00
bryan 9034d1dc71 lint fix 2026-02-01 20:26:36 -08:00
bryan 537172d8ce implemented clientIO gateway [WP-9] 2026-02-01 20:23:26 -08:00
Timothy 20b2e4b3dd fix: robust compaction logic 2026-02-01 19:59:27 -08:00
Timothy @aden a86043a2ec Merge pull request #3127 from TimothyZhang7/feature/event-loop-wp8
Feature/event loop wp8
2026-02-01 19:07:33 -08:00
Timothy 3947da2cf1 Merge upstream/event-loop-arch into feature/event-loop-wp8
Brings in upstream changes: email tool, csv/pdf fixes, docs updates,
agent builder export atomicity fix, JSON extraction validation bugfix.
No conflicts.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-01 19:06:30 -08:00
Timothy 17caab6563 feature: remove hard failure on schema mismatch for context hand off 2026-02-01 18:55:41 -08:00
Timothy @aden a5ae071a03 Merge pull request #723 from trinh31201/bugfix/json-extraction-validation
micro-fix(graph): validate LLM JSON extraction to prevent empty/fabricated data
2026-02-01 18:51:51 -08:00
bryan 9c33da7b8d added graph validation for client-facing nodes [WP-10] 2026-02-01 18:45:35 -08:00
Timothy 94d31743b0 fix: sync with wp7 2026-02-01 18:14:04 -08:00
Timothy 70db618c6e feat: event loop node implementation 2026-02-01 17:16:18 -08:00
Anshumaan Saraf 23146c8dae docs: remove duplicate entry in Edge Protocol docstring (#2994)
Fixes #2717

The Edge Types list in edge.py had 'always' listed twice.
Removed the duplicate line.
2026-02-01 15:19:11 +08:00
Timothy c52ce6bb49 Merge branch 'feature/event-loop-framework' into test/wp1-wp2-wp6-combined 2026-01-30 16:34:12 -08:00
Timothy bcddd4ce77 Merge branch 'feature/credential-manager-aden-provider' into test/wp1-wp2-wp6-combined 2026-01-30 16:30:54 -08:00
Timothy 017872f71b feat: emit bus events 2026-01-30 16:27:39 -08:00
Timothy 7e670ce0a8 feat: event loop WP1-4 2026-01-30 11:43:19 -08:00
trinh31201 3ee6d98905 fix(graph): validate LLM JSON extraction to prevent empty/fabricated data 2026-01-29 01:04:08 +07:00
30 changed files with 9094 additions and 57 deletions
+7 -3
View File
@@ -32,9 +32,13 @@
"mcp__agent-builder__verify_credentials",
"Bash(PYTHONPATH=/home/timothy/oss/hive/core:/home/timothy/oss/hive/exports python:*)",
"Bash(PYTHONPATH=core:exports:tools/src python -m hubspot_input:*)",
"mcp__agent-builder__export_graph"
"mcp__agent-builder__export_graph",
"Bash(python3:*)"
]
},
"enabledMcpjsonServers": ["agent-builder", "tools"],
"enableAllProjectMcpServers": true
"enableAllProjectMcpServers": true,
"enabledMcpjsonServers": [
"agent-builder",
"tools"
]
}
+4 -1
View File
@@ -69,4 +69,7 @@ exports/*
.agent-builder-sessions/*
.venv
.venv
docs/github-issues/*
core/tests/*dumps/*
+779
View File
@@ -0,0 +1,779 @@
#!/usr/bin/env python3
"""
EventLoopNode WebSocket Demo
Real LLM, real FileConversationStore, real EventBus.
Streams EventLoopNode execution to a browser via WebSocket.
Usage:
cd /home/timothy/oss/hive/core
python demos/event_loop_wss_demo.py
Then open http://localhost:8765 in your browser.
"""
import asyncio
import json
import logging
import sys
import tempfile
from http import HTTPStatus
from pathlib import Path
import httpx
import websockets
from bs4 import BeautifulSoup
from websockets.http11 import Request, Response
# Add core, tools, and hive root to path
_CORE_DIR = Path(__file__).resolve().parent.parent
_HIVE_DIR = _CORE_DIR.parent
sys.path.insert(0, str(_CORE_DIR)) # framework.*
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
import os # noqa: E402
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
from core.framework.credentials import CredentialStore # noqa: E402
from framework.credentials.storage import ( # noqa: E402
CompositeStorage,
EncryptedFileStorage,
EnvVarStorage,
)
from framework.graph.event_loop_node import EventLoopNode, JudgeVerdict, LoopConfig # noqa: E402
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
from framework.llm.litellm import LiteLLMProvider # noqa: E402
from framework.llm.provider import Tool # noqa: E402
from framework.runner.tool_registry import ToolRegistry # noqa: E402
from framework.runtime.core import Runtime # noqa: E402
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
from framework.storage.conversation_store import FileConversationStore # noqa: E402
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
logger = logging.getLogger("demo")
# -------------------------------------------------------------------------
# Persistent state (shared across WebSocket connections)
# -------------------------------------------------------------------------
# Fresh temp dir per process; conversation state does NOT survive restarts.
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_demo_"))
STORE = FileConversationStore(STORE_DIR / "conversation")
RUNTIME = Runtime(STORE_DIR / "runtime")
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
# -------------------------------------------------------------------------
# Tool Registry — real tools via ToolRegistry (same pattern as GraphExecutor)
# -------------------------------------------------------------------------
TOOL_REGISTRY = ToolRegistry()
# Credential store: Aden sync (OAuth2 tokens) + encrypted files + env var fallback
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
_local_storage = CompositeStorage(
    primary=EncryptedFileStorage(),
    fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
)
# Aden sync is opt-in via ADEN_API_KEY; any failure falls back to local-only
# storage so the demo still starts without network/credentials.
if os.environ.get("ADEN_API_KEY"):
    try:
        # Imported lazily so the aden extras are only required when enabled.
        from framework.credentials.aden import (  # noqa: E402
            AdenCachedStorage,
            AdenClientConfig,
            AdenCredentialClient,
            AdenSyncProvider,
        )
        _client = AdenCredentialClient(AdenClientConfig(base_url="https://api.adenhq.com"))
        _provider = AdenSyncProvider(client=_client)
        _storage = AdenCachedStorage(
            local_storage=_local_storage,
            aden_provider=_provider,
        )
        _cred_store = CredentialStore(storage=_storage, providers=[_provider], auto_refresh=True)
        _synced = _provider.sync_all(_cred_store)
        logger.info("Synced %d credentials from Aden", _synced)
    except Exception as e:
        # Broad catch is deliberate: sync is best-effort, never fatal.
        logger.warning("Aden sync unavailable: %s", e)
        _cred_store = CredentialStore(storage=_local_storage)
else:
    logger.info("ADEN_API_KEY not set, using local credential storage")
    _cred_store = CredentialStore(storage=_local_storage)
CREDENTIALS = CredentialStoreAdapter(_cred_store)
# Debug: log which credentials resolved
for _name in ["brave_search", "hubspot", "anthropic"]:
    _val = CREDENTIALS.get(_name)
    if _val:
        logger.debug("credential %s: OK (len=%d)", _name, len(_val))
    else:
        logger.debug("credential %s: not found", _name)
# --- web_search (Brave Search API) ---
# Registration declares the JSON schema shown to the LLM; the actual work is
# done by _exec_web_search (defined below — the lambda defers name lookup).
TOOL_REGISTRY.register(
    name="web_search",
    tool=Tool(
        name="web_search",
        description=(
            "Search the web for current information. "
            "Returns titles, URLs, and snippets from search results."
        ),
        parameters={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query (1-500 characters)",
                },
                "num_results": {
                    "type": "integer",
                    "description": "Number of results to return (1-20, default 10)",
                },
            },
            "required": ["query"],
        },
    ),
    executor=lambda inputs: _exec_web_search(inputs),
)
def _exec_web_search(inputs: dict) -> dict:
    """Execute the web_search tool against the Brave Search API.

    Args:
        inputs: Tool input with "query" (str) and optional "num_results"
            (int, clamped to the 1-20 range the tool schema documents).

    Returns:
        On success: {"query", "results": [{"title","url","snippet"}, ...],
        "total"}. On failure: {"error": <message>} — errors are returned
        rather than raised so the tool loop can surface them to the LLM,
        matching the sibling tools (_exec_web_scrape, _exec_hubspot_search),
        which previously handled timeouts while this one let them escape.
    """
    api_key = CREDENTIALS.get("brave_search")
    if not api_key:
        return {"error": "brave_search credential not configured"}
    query = inputs.get("query", "")
    # Clamp into the documented 1-20 range (schema default is 10).
    num_results = max(1, min(inputs.get("num_results", 10), 20))
    try:
        resp = httpx.get(
            "https://api.search.brave.com/res/v1/web/search",
            params={"q": query, "count": num_results},
            headers={"X-Subscription-Token": api_key, "Accept": "application/json"},
            timeout=30.0,
        )
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"Web search failed: {e}"}
    if resp.status_code != 200:
        return {"error": f"Brave API HTTP {resp.status_code}"}
    data = resp.json()
    results = [
        {
            "title": item.get("title", ""),
            "url": item.get("url", ""),
            "snippet": item.get("description", ""),
        }
        for item in data.get("web", {}).get("results", [])[:num_results]
    ]
    return {"query": query, "results": results, "total": len(results)}
# --- web_scrape (httpx + BeautifulSoup, no playwright for sync compat) ---
# Schema-only registration; _exec_web_scrape (below) implements the tool.
TOOL_REGISTRY.register(
    name="web_scrape",
    tool=Tool(
        name="web_scrape",
        description=(
            "Scrape and extract text content from a webpage URL. "
            "Returns the page title and main text content."
        ),
        parameters={
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "URL of the webpage to scrape",
                },
                "max_length": {
                    "type": "integer",
                    "description": "Maximum text length (default 50000)",
                },
            },
            "required": ["url"],
        },
    ),
    executor=lambda inputs: _exec_web_scrape(inputs),
)
# Browser-like headers: some sites reject requests with httpx's default
# User-Agent, so the scraper impersonates a desktop Chrome.
_SCRAPE_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/131.0.0.0 Safari/537.36"
    ),
    "Accept": "text/html,application/xhtml+xml",
}
def _exec_web_scrape(inputs: dict) -> dict:
    """Fetch a webpage and return its title plus whitespace-normalized text.

    Strips script/style/navigation chrome, prefers a semantic content
    container (<article>, <main>, role="main") over <body>, and truncates
    the extracted text at the requested max_length. All failures come back
    as {"error": ...} dicts rather than exceptions.
    """
    target = inputs.get("url", "")
    # Clamp requested length into [1000, 500000]; schema default is 50000.
    limit = max(1000, min(inputs.get("max_length", 50000), 500000))
    if not target.startswith(("http://", "https://")):
        target = "https://" + target
    try:
        response = httpx.get(target, timeout=30.0, follow_redirects=True, headers=_SCRAPE_HEADERS)
        if response.status_code != 200:
            return {"error": f"HTTP {response.status_code}"}
        page = BeautifulSoup(response.text, "html.parser")
        # Remove non-content elements before extracting text.
        for junk in page(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
            junk.decompose()
        title = page.title.get_text(strip=True) if page.title else ""
        content_root = (
            page.find("article")
            or page.find("main")
            or page.find(attrs={"role": "main"})
            or page.find("body")
        )
        body_text = content_root.get_text(separator=" ", strip=True) if content_root else ""
        body_text = " ".join(body_text.split())
        if len(body_text) > limit:
            body_text = body_text[:limit] + "..."
        return {"url": target, "title": title, "content": body_text, "length": len(body_text)}
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"Scrape failed: {e}"}
# --- HubSpot CRM tools (optional, requires HUBSPOT_ACCESS_TOKEN) ---
_HUBSPOT_API = "https://api.hubapi.com"
def _hubspot_headers() -> dict | None:
    """Build HubSpot REST headers, or return None when no token resolves.

    Logs a redacted view of the token (at DEBUG) either way, so credential
    resolution can be traced without exposing the full secret.
    """
    token = CREDENTIALS.get("hubspot")
    if not token:
        logger.debug("HubSpot token: not found")
        return None
    logger.debug("HubSpot token: %s...%s (len=%d)", token[:8], token[-4:], len(token))
    return {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
def _exec_hubspot_search(inputs: dict) -> dict:
    """Search HubSpot CRM objects via the v3 search endpoint.

    Returns the raw HubSpot JSON payload on success, or an {"error": ...}
    dict when credentials are missing or the request fails.
    """
    request_headers = _hubspot_headers()
    if not request_headers:
        return {"error": "HUBSPOT_ACCESS_TOKEN not set"}
    object_type = inputs.get("object_type", "contacts")
    search_text = inputs.get("query", "")
    # The search API caps page size at 100; schema default is 10.
    page_size = min(inputs.get("limit", 10), 100)
    payload: dict = {"limit": page_size}
    if search_text:
        payload["query"] = search_text
    try:
        response = httpx.post(
            f"{_HUBSPOT_API}/crm/v3/objects/{object_type}/search",
            headers=request_headers,
            json=payload,
            timeout=30.0,
        )
        if response.status_code != 200:
            return {"error": f"HubSpot API HTTP {response.status_code}: {response.text[:200]}"}
        return response.json()
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"HubSpot error: {e}"}
# --- hubspot_search registration (executor defined above) ---
TOOL_REGISTRY.register(
    name="hubspot_search",
    tool=Tool(
        name="hubspot_search",
        description=(
            "Search HubSpot CRM objects (contacts, companies, or deals). "
            "Returns matching records with their properties."
        ),
        parameters={
            "type": "object",
            "properties": {
                "object_type": {
                    "type": "string",
                    "description": "CRM object type: 'contacts', 'companies', or 'deals'",
                },
                "query": {
                    "type": "string",
                    "description": "Search query (name, email, domain, etc.)",
                },
                "limit": {
                    "type": "integer",
                    "description": "Max results (1-100, default 10)",
                },
            },
            "required": ["object_type"],
        },
    ),
    executor=lambda inputs: _exec_hubspot_search(inputs),
)
# Startup summary of everything registered above.
logger.info(
    "ToolRegistry loaded: %s",
    ", ".join(TOOL_REGISTRY.get_registered_names()),
)
# -------------------------------------------------------------------------
# ChatJudge — keeps the event loop alive between user messages
# -------------------------------------------------------------------------
class ChatJudge:
    """Loop judge that parks the event loop between user turns.

    After the LLM finishes a turn, ``evaluate`` blocks on an asyncio.Event
    until either a new user message is injected (-> RETRY, keep looping) or
    shutdown is requested (-> ACCEPT, let the loop finish).
    """

    def __init__(self, on_ready=None):
        # Set whenever a new user message has been injected into the node.
        self._message_ready = asyncio.Event()
        self._shutdown = False
        # Optional async callback fired when the judge starts waiting.
        self._on_ready = on_ready

    async def evaluate(self, context: dict) -> JudgeVerdict:
        # Tell the client the LLM is idle and awaiting the next input.
        if self._on_ready is not None:
            await self._on_ready()
        # Park here until signal_message() or signal_shutdown() fires.
        self._message_ready.clear()
        await self._message_ready.wait()
        action = "ACCEPT" if self._shutdown else "RETRY"
        return JudgeVerdict(action=action)

    def signal_message(self):
        """Wake the judge — a fresh user message has been injected."""
        self._message_ready.set()

    def signal_shutdown(self):
        """Wake the judge and have the loop exit cleanly."""
        self._shutdown = True
        self._message_ready.set()
# -------------------------------------------------------------------------
# HTML page (embedded)
# -------------------------------------------------------------------------
HTML_PAGE = ( # noqa: E501
"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>EventLoopNode Live Demo</title>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body {
font-family: 'SF Mono', 'Fira Code', monospace;
background: #0d1117; color: #c9d1d9;
height: 100vh; display: flex; flex-direction: column;
}
header {
background: #161b22; padding: 12px 20px;
border-bottom: 1px solid #30363d;
display: flex; align-items: center; gap: 16px;
}
header h1 { font-size: 16px; color: #58a6ff; font-weight: 600; }
.status {
font-size: 12px; padding: 3px 10px; border-radius: 12px;
background: #21262d; color: #8b949e;
}
.status.running { background: #1a4b2e; color: #3fb950; }
.status.done { background: #1a3a5c; color: #58a6ff; }
.status.error { background: #4b1a1a; color: #f85149; }
.chat { flex: 1; overflow-y: auto; padding: 16px; }
.msg {
margin: 8px 0; padding: 10px 14px; border-radius: 8px;
line-height: 1.6; white-space: pre-wrap; word-wrap: break-word;
}
.msg.user { background: #1a3a5c; color: #58a6ff; }
.msg.assistant { background: #161b22; color: #c9d1d9; }
.msg.event {
background: transparent; color: #8b949e; font-size: 11px;
padding: 4px 14px; border-left: 3px solid #30363d;
}
.msg.event.loop { border-left-color: #58a6ff; }
.msg.event.tool { border-left-color: #d29922; }
.msg.event.stall { border-left-color: #f85149; }
.input-bar {
padding: 12px 16px; background: #161b22;
border-top: 1px solid #30363d; display: flex; gap: 8px;
}
.input-bar input {
flex: 1; background: #0d1117; border: 1px solid #30363d;
color: #c9d1d9; padding: 8px 12px; border-radius: 6px;
font-family: inherit; font-size: 14px; outline: none;
}
.input-bar input:focus { border-color: #58a6ff; }
.input-bar button {
background: #238636; color: #fff; border: none;
padding: 8px 20px; border-radius: 6px; cursor: pointer;
font-family: inherit; font-weight: 600;
}
.input-bar button:hover { background: #2ea043; }
.input-bar button:disabled {
background: #21262d; color: #484f58; cursor: not-allowed;
}
.input-bar button.clear { background: #da3633; }
.input-bar button.clear:hover { background: #f85149; }
</style>
</head>
<body>
<header>
<h1>EventLoopNode Live</h1>
<span id="status" class="status">Idle</span>
<span id="iter" class="status" style="display:none">Step 0</span>
</header>
<div id="chat" class="chat"></div>
<div class="input-bar">
<input id="input" type="text"
placeholder="Ask anything..." autofocus />
<button id="go" onclick="run()">Send</button>
<button class="clear"
onclick="clearConversation()">Clear</button>
</div>
<script>
let ws = null;
let currentAssistantEl = null;
let iterCount = 0;
const chat = document.getElementById('chat');
const status = document.getElementById('status');
const iterEl = document.getElementById('iter');
const goBtn = document.getElementById('go');
const inputEl = document.getElementById('input');
inputEl.addEventListener('keydown', e => {
if (e.key === 'Enter') run();
});
function setStatus(text, cls) {
status.textContent = text;
status.className = 'status ' + cls;
}
function addMsg(text, cls) {
const el = document.createElement('div');
el.className = 'msg ' + cls;
el.textContent = text;
chat.appendChild(el);
chat.scrollTop = chat.scrollHeight;
return el;
}
function connect() {
ws = new WebSocket('ws://' + location.host + '/ws');
ws.onopen = () => {
setStatus('Ready', 'done');
goBtn.disabled = false;
};
ws.onmessage = handleEvent;
ws.onerror = () => { setStatus('Error', 'error'); };
ws.onclose = () => {
setStatus('Reconnecting...', '');
goBtn.disabled = true;
setTimeout(connect, 2000);
};
}
function handleEvent(msg) {
const evt = JSON.parse(msg.data);
if (evt.type === 'llm_text_delta') {
if (currentAssistantEl) {
currentAssistantEl.textContent += evt.content;
chat.scrollTop = chat.scrollHeight;
}
}
else if (evt.type === 'ready') {
setStatus('Ready', 'done');
if (currentAssistantEl && !currentAssistantEl.textContent)
currentAssistantEl.remove();
goBtn.disabled = false;
}
else if (evt.type === 'node_loop_iteration') {
iterCount = evt.iteration || (iterCount + 1);
iterEl.textContent = 'Step ' + iterCount;
iterEl.style.display = '';
}
else if (evt.type === 'tool_call_started') {
var info = evt.tool_name + '('
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
addMsg('TOOL ' + info, 'event tool');
}
else if (evt.type === 'tool_call_completed') {
var preview = (evt.result || '').slice(0, 200);
var cls = evt.is_error ? 'stall' : 'tool';
addMsg('RESULT ' + evt.tool_name + ': ' + preview,
'event ' + cls);
currentAssistantEl = addMsg('', 'assistant');
}
else if (evt.type === 'result') {
setStatus('Session ended', evt.success ? 'done' : 'error');
if (evt.error) addMsg('ERROR ' + evt.error, 'event stall');
if (currentAssistantEl && !currentAssistantEl.textContent)
currentAssistantEl.remove();
goBtn.disabled = false;
}
else if (evt.type === 'node_stalled') {
addMsg('STALLED ' + evt.reason, 'event stall');
}
else if (evt.type === 'cleared') {
chat.innerHTML = '';
iterCount = 0;
iterEl.textContent = 'Step 0';
iterEl.style.display = 'none';
setStatus('Ready', 'done');
goBtn.disabled = false;
}
}
function run() {
const text = inputEl.value.trim();
if (!text || !ws || ws.readyState !== 1) return;
addMsg(text, 'user');
currentAssistantEl = addMsg('', 'assistant');
inputEl.value = '';
setStatus('Running', 'running');
goBtn.disabled = true;
ws.send(JSON.stringify({ topic: text }));
}
function clearConversation() {
if (ws && ws.readyState === 1) {
ws.send(JSON.stringify({ command: 'clear' }));
}
}
connect();
</script>
</body>
</html>"""
)
# -------------------------------------------------------------------------
# WebSocket handler
# -------------------------------------------------------------------------
async def handle_ws(websocket):
    """Persistent WebSocket: long-lived EventLoopNode kept alive by ChatJudge.

    Per-connection lifecycle:
      1. Subscribe a fresh EventBus and forward selected events as JSON.
      2. First "topic" message spins up an EventLoopNode background task.
      3. Later messages are injected into the running node; the ChatJudge is
         signalled so its evaluate() unblocks and the loop continues.
      4. {"command": "clear"} stops the loop and resets the conversation
         store; socket close stops the loop via the finally block.
    """
    global STORE  # reassigned by the "clear" command below
    # -- Event forwarding (WebSocket ← EventBus) ----------------------------
    bus = EventBus()
    async def forward_event(event):
        try:
            payload = {"type": event.type.value, **event.data}
            if event.node_id:
                payload["node_id"] = event.node_id
            await websocket.send(json.dumps(payload))
        except Exception:
            # Best-effort: a dead socket must not kill the event loop task.
            pass
    bus.subscribe(
        event_types=[
            EventType.NODE_LOOP_STARTED,
            EventType.NODE_LOOP_ITERATION,
            EventType.NODE_LOOP_COMPLETED,
            EventType.LLM_TEXT_DELTA,
            EventType.TOOL_CALL_STARTED,
            EventType.TOOL_CALL_COMPLETED,
            EventType.NODE_STALLED,
        ],
        handler=forward_event,
    )
    # -- Ready callback (tells browser the LLM is done, waiting for input) --
    async def send_ready():
        try:
            await websocket.send(json.dumps({"type": "ready"}))
        except Exception:
            pass
    # -- Per-connection state -----------------------------------------------
    # NOTE: start_loop/stop_loop read `judge` through the closure, so the
    # rebinding in the "clear" branch affects subsequently started loops.
    judge = ChatJudge(on_ready=send_ready)
    node = None
    loop_task = None
    tools = list(TOOL_REGISTRY.get_tools().values())
    tool_executor = TOOL_REGISTRY.get_executor()
    node_spec = NodeSpec(
        id="assistant",
        name="Chat Assistant",
        description="A conversational assistant that remembers context across messages",
        node_type="event_loop",
        system_prompt=(
            "You are a helpful assistant with access to tools. "
            "You can search the web, scrape webpages, and query HubSpot CRM. "
            "Use tools when the user asks for current information or external data. "
            "You have full conversation history, so you can reference previous messages."
        ),
    )
    async def start_loop(first_message: str):
        """Create an EventLoopNode and run it as a background task."""
        nonlocal node, loop_task
        memory = SharedMemory()
        ctx = NodeContext(
            runtime=RUNTIME,
            node_id="assistant",
            node_spec=node_spec,
            memory=memory,
            input_data={},
            llm=LLM,
            available_tools=tools,
        )
        node = EventLoopNode(
            event_bus=bus,
            judge=judge,
            config=LoopConfig(max_iterations=10_000, max_history_tokens=32_000),
            conversation_store=STORE,
            tool_executor=tool_executor,
        )
        # Seed the loop with the first user message before execution starts.
        await node.inject_event(first_message)
        async def _run():
            try:
                result = await node.execute(ctx)
                # Report the final result; ignore send failures (socket gone).
                try:
                    await websocket.send(
                        json.dumps(
                            {
                                "type": "result",
                                "success": result.success,
                                "output": result.output,
                                "error": result.error,
                                "tokens": result.tokens_used,
                            }
                        )
                    )
                except Exception:
                    pass
                logger.info(f"Loop ended: success={result.success}, tokens={result.tokens_used}")
            except websockets.exceptions.ConnectionClosed:
                logger.info("Loop stopped: WebSocket closed")
            except Exception as e:
                logger.exception("Loop error")
                try:
                    await websocket.send(
                        json.dumps(
                            {
                                "type": "result",
                                "success": False,
                                "error": str(e),
                                "output": {},
                            }
                        )
                    )
                except Exception:
                    pass
        loop_task = asyncio.create_task(_run())
    async def stop_loop():
        """Signal the judge and wait for the loop task to finish."""
        nonlocal node, loop_task
        if loop_task and not loop_task.done():
            judge.signal_shutdown()
            try:
                await asyncio.wait_for(loop_task, timeout=5.0)
            # NOTE(review): on Python <3.11, asyncio.wait_for raises
            # asyncio.TimeoutError, which is NOT builtins.TimeoutError —
            # confirm the target runtime is 3.11+ or this catch misses it.
            except (TimeoutError, asyncio.CancelledError):
                loop_task.cancel()
        node = None
        loop_task = None
    # -- Message loop (runs for the lifetime of this WebSocket) -------------
    try:
        async for raw in websocket:
            try:
                msg = json.loads(raw)
            except Exception:
                # Ignore non-JSON frames.
                continue
            # Clear command
            if msg.get("command") == "clear":
                import shutil
                await stop_loop()
                await STORE.close()
                conv_dir = STORE_DIR / "conversation"
                if conv_dir.exists():
                    shutil.rmtree(conv_dir)
                # Replace the module-level store so later connections see
                # the fresh (empty) conversation directory too.
                STORE = FileConversationStore(conv_dir)
                # Reset judge for next session
                judge = ChatJudge(on_ready=send_ready)
                await websocket.send(json.dumps({"type": "cleared"}))
                logger.info("Conversation cleared")
                continue
            topic = msg.get("topic", "")
            if not topic:
                continue
            if node is None:
                # First message — spin up the loop
                logger.info(f"Starting persistent loop: {topic}")
                await start_loop(topic)
            else:
                # Subsequent message — inject and unblock the judge
                logger.info(f"Injecting message: {topic}")
                await node.inject_event(topic)
                judge.signal_message()
    except websockets.exceptions.ConnectionClosed:
        pass
    finally:
        await stop_loop()
        logger.info("WebSocket closed, loop stopped")
# -------------------------------------------------------------------------
# HTTP handler for serving the HTML page
# -------------------------------------------------------------------------
async def process_request(connection, request: Request):
    """Route plain HTTP requests: /ws upgrades, everything else gets the UI.

    Returning None lets the websockets server continue with the WebSocket
    handshake; returning a Response short-circuits it and serves the page.
    """
    if request.path == "/ws":
        # Fall through to the normal WebSocket upgrade.
        return None
    html_headers = websockets.Headers({"Content-Type": "text/html; charset=utf-8"})
    return Response(HTTPStatus.OK, "OK", html_headers, HTML_PAGE.encode())
# -------------------------------------------------------------------------
# Main
# -------------------------------------------------------------------------
async def main():
    """Start the combined HTTP + WebSocket server and block forever."""
    port = 8765
    server = websockets.serve(
        handle_ws,
        "0.0.0.0",
        port,
        process_request=process_request,
    )
    async with server:
        logger.info(f"Demo running at http://localhost:{port}")
        logger.info("Open in your browser and enter a topic to research.")
        # Await a future that never completes so the server stays up.
        await asyncio.Future()


if __name__ == "__main__":
    asyncio.run(main())
+930
View File
@@ -0,0 +1,930 @@
#!/usr/bin/env python3
"""
Two-Node ContextHandoff Demo
Demonstrates ContextHandoff between two EventLoopNode instances:
Node A (Researcher) ContextHandoff Node B (Analyst)
Real LLM, real FileConversationStore, real EventBus.
Streams both nodes to a browser via WebSocket.
Usage:
cd /home/timothy/oss/hive/core
python demos/handoff_demo.py
Then open http://localhost:8766 in your browser.
"""
import asyncio
import json
import logging
import sys
import tempfile
from http import HTTPStatus
from pathlib import Path
import httpx
import websockets
from bs4 import BeautifulSoup
from websockets.http11 import Request, Response
# Add core, tools, and hive root to path
_CORE_DIR = Path(__file__).resolve().parent.parent
_HIVE_DIR = _CORE_DIR.parent
sys.path.insert(0, str(_CORE_DIR)) # framework.*
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
from core.framework.credentials import CredentialStore # noqa: E402
from framework.credentials.storage import ( # noqa: E402
CompositeStorage,
EncryptedFileStorage,
EnvVarStorage,
)
from framework.graph.context_handoff import ContextHandoff # noqa: E402
from framework.graph.conversation import NodeConversation # noqa: E402
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
from framework.llm.litellm import LiteLLMProvider # noqa: E402
from framework.llm.provider import Tool # noqa: E402
from framework.runner.tool_registry import ToolRegistry # noqa: E402
from framework.runtime.core import Runtime # noqa: E402
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
from framework.storage.conversation_store import FileConversationStore # noqa: E402
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
logger = logging.getLogger("handoff_demo")
# -------------------------------------------------------------------------
# Persistent state
# -------------------------------------------------------------------------
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_handoff_"))
RUNTIME = Runtime(STORE_DIR / "runtime")
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
# -------------------------------------------------------------------------
# Credentials
# -------------------------------------------------------------------------
# Composite credential store: encrypted files (primary) + env vars (fallback)
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
_composite = CompositeStorage(
primary=EncryptedFileStorage(),
fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
)
CREDENTIALS = CredentialStoreAdapter(CredentialStore(storage=_composite))
for _name in ["brave_search", "hubspot"]:
_val = CREDENTIALS.get(_name)
if _val:
logger.debug("credential %s: OK (len=%d)", _name, len(_val))
else:
logger.debug("credential %s: not found", _name)
# -------------------------------------------------------------------------
# Tool Registry — web_search + web_scrape for Node A (Researcher)
# -------------------------------------------------------------------------
TOOL_REGISTRY = ToolRegistry()
def _exec_web_search(inputs: dict) -> dict:
    """Execute the web_search tool against the Brave Search API.

    Args:
        inputs: Tool input with "query" (str) and optional "num_results"
            (int, clamped to the 1-20 range the tool schema documents).

    Returns:
        On success: {"query", "results": [{"title","url","snippet"}, ...],
        "total"}. On failure: {"error": <message>} — errors are returned
        rather than raised so the tool loop can surface them to the LLM,
        matching _exec_web_scrape, which previously handled timeouts while
        this function let them escape.
    """
    api_key = CREDENTIALS.get("brave_search")
    if not api_key:
        return {"error": "brave_search credential not configured"}
    query = inputs.get("query", "")
    # Clamp into the documented 1-20 range (schema default is 10).
    num_results = max(1, min(inputs.get("num_results", 10), 20))
    try:
        resp = httpx.get(
            "https://api.search.brave.com/res/v1/web/search",
            params={"q": query, "count": num_results},
            headers={
                "X-Subscription-Token": api_key,
                "Accept": "application/json",
            },
            timeout=30.0,
        )
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"Web search failed: {e}"}
    if resp.status_code != 200:
        return {"error": f"Brave API HTTP {resp.status_code}"}
    data = resp.json()
    results = [
        {
            "title": item.get("title", ""),
            "url": item.get("url", ""),
            "snippet": item.get("description", ""),
        }
        for item in data.get("web", {}).get("results", [])[:num_results]
    ]
    return {"query": query, "results": results, "total": len(results)}
TOOL_REGISTRY.register(
name="web_search",
tool=Tool(
name="web_search",
description=(
"Search the web for current information. "
"Returns titles, URLs, and snippets from search results."
),
parameters={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query (1-500 characters)",
},
"num_results": {
"type": "integer",
"description": "Number of results (1-20, default 10)",
},
},
"required": ["query"],
},
),
executor=lambda inputs: _exec_web_search(inputs),
)
_SCRAPE_HEADERS = {
"User-Agent": (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/131.0.0.0 Safari/537.36"
),
"Accept": "text/html,application/xhtml+xml",
}
def _exec_web_scrape(inputs: dict) -> dict:
    """Fetch a webpage and return its title plus cleaned main text.

    Args:
        inputs: Tool input dict with "url" and an optional "max_length"
            (clamped to the 1000-500000 range, default 50000).

    Returns:
        {"url", "title", "content", "length"} on success, otherwise an
        {"error": <message>} dict.
    """
    target = inputs.get("url", "")
    limit = max(1000, min(inputs.get("max_length", 50000), 500000))
    # Default to HTTPS when the caller omitted the scheme.
    if not target.startswith(("http://", "https://")):
        target = "https://" + target
    try:
        response = httpx.get(
            target,
            timeout=30.0,
            follow_redirects=True,
            headers=_SCRAPE_HEADERS,
        )
        if response.status_code != 200:
            return {"error": f"HTTP {response.status_code}"}
        page = BeautifulSoup(response.text, "html.parser")
        # Drop boilerplate elements before extracting visible text.
        for node in page(
            ["script", "style", "nav", "footer", "header", "aside", "noscript"]
        ):
            node.decompose()
        page_title = page.title.get_text(strip=True) if page.title else ""
        # Prefer semantic content containers, falling back to <body>.
        container = (
            page.find("article")
            or page.find("main")
            or page.find(attrs={"role": "main"})
            or page.find("body")
        )
        body_text = ""
        if container:
            body_text = container.get_text(separator=" ", strip=True)
        # Collapse runs of whitespace, then truncate to the limit.
        body_text = " ".join(body_text.split())
        if len(body_text) > limit:
            body_text = body_text[:limit] + "..."
        return {
            "url": target,
            "title": page_title,
            "content": body_text,
            "length": len(body_text),
        }
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"Scrape failed: {e}"}
# Expose the scrape executor as the researcher node's second tool.
TOOL_REGISTRY.register(
    name="web_scrape",
    tool=Tool(
        name="web_scrape",
        description=(
            "Scrape and extract text content from a webpage URL. "
            "Returns the page title and main text content."
        ),
        # JSON Schema describing the tool's input for the LLM.
        parameters={
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "URL of the webpage to scrape",
                },
                "max_length": {
                    "type": "integer",
                    "description": "Maximum text length (default 50000)",
                },
            },
            "required": ["url"],
        },
    ),
    executor=lambda inputs: _exec_web_scrape(inputs),
)

# Log the final tool roster so startup output shows what the nodes can use.
logger.info(
    "ToolRegistry loaded: %s",
    ", ".join(TOOL_REGISTRY.get_registered_names()),
)
# -------------------------------------------------------------------------
# Node Specs
# -------------------------------------------------------------------------
# Node A: an event-loop node that may call web_search/web_scrape and must
# write its findings to the "research_summary" output key.
RESEARCHER_SPEC = NodeSpec(
    id="researcher",
    name="Researcher",
    description="Researches a topic using web search and scraping tools",
    node_type="event_loop",
    input_keys=["topic"],
    output_keys=["research_summary"],
    system_prompt=(
        "You are a thorough research assistant. Your job is to research "
        "the given topic using the web_search and web_scrape tools.\n\n"
        "1. Search for relevant information on the topic\n"
        "2. Scrape 1-2 of the most promising URLs for details\n"
        "3. Synthesize your findings into a comprehensive summary\n"
        "4. Use set_output with key='research_summary' to save your "
        "findings\n\n"
        "Be thorough but efficient. Aim for 2-4 search/scrape calls, "
        "then summarize and set_output."
    ),
)
# Node B: a tool-less event-loop node that receives the handoff context as
# its "context" input and must write to the "analysis" output key.
ANALYST_SPEC = NodeSpec(
    id="analyst",
    name="Analyst",
    description="Analyzes research findings and provides insights",
    node_type="event_loop",
    input_keys=["context"],
    output_keys=["analysis"],
    system_prompt=(
        "You are a strategic analyst. You receive research findings from "
        "a previous researcher and must:\n\n"
        "1. Identify key themes and patterns\n"
        "2. Assess the reliability and significance of the findings\n"
        "3. Provide actionable insights and recommendations\n"
        "4. Use set_output with key='analysis' to save your analysis\n\n"
        "Be concise but insightful. Focus on what matters most."
    ),
)
# -------------------------------------------------------------------------
# HTML page
# -------------------------------------------------------------------------
# Single-page UI served at "/": inline CSS plus a vanilla-JS WebSocket
# client.  The page renders pipeline phases, streamed LLM text, tool
# activity, the handoff banner, and the final research/analysis results.
# The string is served verbatim, so its content is runtime behavior.
HTML_PAGE = ( # noqa: E501
"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>ContextHandoff Demo</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: 'SF Mono', 'Fira Code', monospace;
background: #0d1117;
color: #c9d1d9;
height: 100vh;
display: flex;
flex-direction: column;
}
header {
background: #161b22;
padding: 12px 20px;
border-bottom: 1px solid #30363d;
display: flex;
align-items: center;
gap: 16px;
}
header h1 {
font-size: 16px;
color: #58a6ff;
font-weight: 600;
}
.badge {
font-size: 12px;
padding: 3px 10px;
border-radius: 12px;
background: #21262d;
color: #8b949e;
}
.badge.researcher {
background: #1a3a5c;
color: #58a6ff;
}
.badge.analyst {
background: #1a4b2e;
color: #3fb950;
}
.badge.handoff {
background: #3d1f00;
color: #d29922;
}
.badge.done {
background: #21262d;
color: #8b949e;
}
.badge.error {
background: #4b1a1a;
color: #f85149;
}
.chat {
flex: 1;
overflow-y: auto;
padding: 16px;
}
.msg {
margin: 8px 0;
padding: 10px 14px;
border-radius: 8px;
line-height: 1.6;
white-space: pre-wrap;
word-wrap: break-word;
}
.msg.user {
background: #1a3a5c;
color: #58a6ff;
}
.msg.assistant {
background: #161b22;
color: #c9d1d9;
}
.msg.assistant.analyst-msg {
border-left: 3px solid #3fb950;
}
.msg.event {
background: transparent;
color: #8b949e;
font-size: 11px;
padding: 4px 14px;
border-left: 3px solid #30363d;
}
.msg.event.loop {
border-left-color: #58a6ff;
}
.msg.event.tool {
border-left-color: #d29922;
}
.msg.event.stall {
border-left-color: #f85149;
}
.handoff-banner {
margin: 16px 0;
padding: 16px;
background: #1c1200;
border: 1px solid #d29922;
border-radius: 8px;
text-align: center;
}
.handoff-banner h3 {
color: #d29922;
font-size: 14px;
margin-bottom: 8px;
}
.handoff-banner p, .result-banner p {
color: #8b949e;
font-size: 12px;
line-height: 1.5;
max-height: 200px;
overflow-y: auto;
white-space: pre-wrap;
text-align: left;
}
.result-banner {
margin: 16px 0;
padding: 16px;
background: #0a2614;
border: 1px solid #3fb950;
border-radius: 8px;
}
.result-banner h3 {
color: #3fb950;
font-size: 14px;
margin-bottom: 8px;
text-align: center;
}
.result-banner .label {
color: #58a6ff;
font-size: 11px;
font-weight: 600;
margin-top: 10px;
margin-bottom: 2px;
}
.result-banner .tokens {
color: #484f58;
font-size: 11px;
text-align: center;
margin-top: 10px;
}
.input-bar {
padding: 12px 16px;
background: #161b22;
border-top: 1px solid #30363d;
display: flex;
gap: 8px;
}
.input-bar input {
flex: 1;
background: #0d1117;
border: 1px solid #30363d;
color: #c9d1d9;
padding: 8px 12px;
border-radius: 6px;
font-family: inherit;
font-size: 14px;
outline: none;
}
.input-bar input:focus {
border-color: #58a6ff;
}
.input-bar button {
background: #238636;
color: #fff;
border: none;
padding: 8px 20px;
border-radius: 6px;
cursor: pointer;
font-family: inherit;
font-weight: 600;
}
.input-bar button:hover {
background: #2ea043;
}
.input-bar button:disabled {
background: #21262d;
color: #484f58;
cursor: not-allowed;
}
</style>
</head>
<body>
<header>
<h1>ContextHandoff Demo</h1>
<span id="phase" class="badge">Idle</span>
<span id="iter" class="badge" style="display:none">Step 0</span>
</header>
<div id="chat" class="chat"></div>
<div class="input-bar">
<input id="input" type="text"
placeholder="Enter a research topic..." autofocus />
<button id="go" onclick="run()">Research</button>
</div>
<script>
let ws = null;
let currentAssistantEl = null;
let iterCount = 0;
let currentPhase = 'idle';
const chat = document.getElementById('chat');
const phase = document.getElementById('phase');
const iterEl = document.getElementById('iter');
const goBtn = document.getElementById('go');
const inputEl = document.getElementById('input');
inputEl.addEventListener('keydown', e => {
if (e.key === 'Enter') run();
});
function setPhase(text, cls) {
phase.textContent = text;
phase.className = 'badge ' + cls;
currentPhase = cls;
}
function addMsg(text, cls) {
const el = document.createElement('div');
el.className = 'msg ' + cls;
el.textContent = text;
chat.appendChild(el);
chat.scrollTop = chat.scrollHeight;
return el;
}
function addHandoffBanner(summary) {
const banner = document.createElement('div');
banner.className = 'handoff-banner';
const h3 = document.createElement('h3');
h3.textContent = 'Context Handoff: Researcher -> Analyst';
const p = document.createElement('p');
p.textContent = summary || 'Passing research context...';
banner.appendChild(h3);
banner.appendChild(p);
chat.appendChild(banner);
chat.scrollTop = chat.scrollHeight;
}
function addResultBanner(researcher, analyst, tokens) {
const banner = document.createElement('div');
banner.className = 'result-banner';
const h3 = document.createElement('h3');
h3.textContent = 'Pipeline Complete';
banner.appendChild(h3);
if (researcher && researcher.research_summary) {
const lbl = document.createElement('div');
lbl.className = 'label';
lbl.textContent = 'RESEARCH SUMMARY';
banner.appendChild(lbl);
const p = document.createElement('p');
p.textContent = researcher.research_summary;
banner.appendChild(p);
}
if (analyst && analyst.analysis) {
const lbl = document.createElement('div');
lbl.className = 'label';
lbl.textContent = 'ANALYSIS';
lbl.style.color = '#3fb950';
banner.appendChild(lbl);
const p = document.createElement('p');
p.textContent = analyst.analysis;
banner.appendChild(p);
}
if (tokens) {
const t = document.createElement('div');
t.className = 'tokens';
t.textContent = 'Total tokens: ' + tokens.toLocaleString();
banner.appendChild(t);
}
chat.appendChild(banner);
chat.scrollTop = chat.scrollHeight;
}
function connect() {
ws = new WebSocket('ws://' + location.host + '/ws');
ws.onopen = () => {
setPhase('Ready', 'done');
goBtn.disabled = false;
};
ws.onmessage = handleEvent;
ws.onerror = () => { setPhase('Error', 'error'); };
ws.onclose = () => {
setPhase('Reconnecting...', '');
goBtn.disabled = true;
setTimeout(connect, 2000);
};
}
function handleEvent(msg) {
const evt = JSON.parse(msg.data);
if (evt.type === 'phase') {
if (evt.phase === 'researcher') {
setPhase('Researcher', 'researcher');
} else if (evt.phase === 'handoff') {
setPhase('Handoff', 'handoff');
} else if (evt.phase === 'analyst') {
setPhase('Analyst', 'analyst');
}
iterCount = 0;
iterEl.style.display = 'none';
}
else if (evt.type === 'llm_text_delta') {
if (currentAssistantEl) {
currentAssistantEl.textContent += evt.content;
chat.scrollTop = chat.scrollHeight;
}
}
else if (evt.type === 'node_loop_iteration') {
iterCount = evt.iteration || (iterCount + 1);
iterEl.textContent = 'Step ' + iterCount;
iterEl.style.display = '';
}
else if (evt.type === 'tool_call_started') {
var info = evt.tool_name + '('
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
addMsg('TOOL ' + info, 'event tool');
}
else if (evt.type === 'tool_call_completed') {
var preview = (evt.result || '').slice(0, 200);
var cls = evt.is_error ? 'stall' : 'tool';
addMsg(
'RESULT ' + evt.tool_name + ': ' + preview,
'event ' + cls
);
var assistCls = currentPhase === 'analyst'
? 'assistant analyst-msg' : 'assistant';
currentAssistantEl = addMsg('', assistCls);
}
else if (evt.type === 'handoff_context') {
addHandoffBanner(evt.summary);
var assistCls = 'assistant analyst-msg';
currentAssistantEl = addMsg('', assistCls);
}
else if (evt.type === 'node_result') {
if (evt.node_id === 'researcher') {
if (currentAssistantEl
&& !currentAssistantEl.textContent) {
currentAssistantEl.remove();
}
}
}
else if (evt.type === 'done') {
setPhase('Done', 'done');
iterEl.style.display = 'none';
if (currentAssistantEl
&& !currentAssistantEl.textContent) {
currentAssistantEl.remove();
}
currentAssistantEl = null;
addResultBanner(
evt.researcher, evt.analyst, evt.total_tokens
);
goBtn.disabled = false;
inputEl.placeholder = 'Enter another topic...';
}
else if (evt.type === 'error') {
setPhase('Error', 'error');
addMsg('ERROR ' + evt.message, 'event stall');
goBtn.disabled = false;
}
else if (evt.type === 'node_stalled') {
addMsg('STALLED ' + evt.reason, 'event stall');
}
}
function run() {
const text = inputEl.value.trim();
if (!text || !ws || ws.readyState !== 1) return;
chat.innerHTML = '';
addMsg(text, 'user');
currentAssistantEl = addMsg('', 'assistant');
inputEl.value = '';
goBtn.disabled = true;
ws.send(JSON.stringify({ topic: text }));
}
connect();
</script>
</body>
</html>"""
)
# -------------------------------------------------------------------------
# WebSocket handler — sequential Node A → Handoff → Node B
# -------------------------------------------------------------------------
async def handle_ws(websocket):
    """Handle one WebSocket client: each valid message starts a pipeline run."""
    try:
        async for frame in websocket:
            # Silently skip frames that are not valid JSON.
            try:
                payload = json.loads(frame)
            except Exception:
                continue
            topic = payload.get("topic", "")
            if not topic:
                continue
            logger.info(f"Starting handoff pipeline for: {topic}")
            try:
                await _run_pipeline(websocket, topic)
            except websockets.exceptions.ConnectionClosed:
                # Client went away mid-run; stop handling this socket.
                logger.info("WebSocket closed during pipeline")
                return
            except Exception as e:
                logger.exception("Pipeline error")
                # Best-effort error report back to the client.
                try:
                    await websocket.send(
                        json.dumps({"type": "error", "message": str(e)})
                    )
                except Exception:
                    pass
    except websockets.exceptions.ConnectionClosed:
        # Normal disconnect while idle — nothing to clean up.
        pass
async def _run_pipeline(websocket, topic: str):
    """Execute: Node A (research) → ContextHandoff → Node B (analysis).

    Creates a temporary directory for the per-run conversation stores and
    guarantees it is removed on every exit path.  (Previously cleanup only
    ran on the success path, so early returns — researcher failure,
    restore failure — and exceptions leaked the directory.)
    """
    import shutil

    # Fresh stores for each run
    run_dir = Path(tempfile.mkdtemp(prefix="hive_run_", dir=STORE_DIR))
    try:
        await _execute_pipeline(websocket, topic, run_dir)
    finally:
        # Clean up temp stores regardless of how the run ended.
        try:
            shutil.rmtree(run_dir)
        except Exception:
            pass


async def _execute_pipeline(websocket, topic: str, run_dir: Path) -> None:
    """Run the researcher → handoff → analyst phases, streaming WS events.

    Args:
        websocket: Open WebSocket connection used to push progress events.
        topic: The research topic supplied by the client.
        run_dir: Per-run temp directory that holds the conversation stores.
    """
    store_a = FileConversationStore(run_dir / "node_a")
    store_b = FileConversationStore(run_dir / "node_b")
    # Shared event bus: node/tool/LLM events are forwarded to the browser.
    bus = EventBus()

    async def forward_event(event):
        # Best-effort forwarding; a closed socket must not kill the run.
        try:
            payload = {"type": event.type.value, **event.data}
            if event.node_id:
                payload["node_id"] = event.node_id
            await websocket.send(json.dumps(payload))
        except Exception:
            pass

    bus.subscribe(
        event_types=[
            EventType.NODE_LOOP_STARTED,
            EventType.NODE_LOOP_ITERATION,
            EventType.NODE_LOOP_COMPLETED,
            EventType.LLM_TEXT_DELTA,
            EventType.TOOL_CALL_STARTED,
            EventType.TOOL_CALL_COMPLETED,
            EventType.NODE_STALLED,
        ],
        handler=forward_event,
    )
    tools = list(TOOL_REGISTRY.get_tools().values())
    tool_executor = TOOL_REGISTRY.get_executor()

    # ---- Phase 1: Researcher ------------------------------------------------
    await websocket.send(json.dumps({"type": "phase", "phase": "researcher"}))
    node_a = EventLoopNode(
        event_bus=bus,
        judge=None,  # implicit judge: accept when output_keys filled
        config=LoopConfig(
            max_iterations=20,
            max_tool_calls_per_turn=10,
            max_history_tokens=32_000,
        ),
        conversation_store=store_a,
        tool_executor=tool_executor,
    )
    ctx_a = NodeContext(
        runtime=RUNTIME,
        node_id="researcher",
        node_spec=RESEARCHER_SPEC,
        memory=SharedMemory(),
        input_data={"topic": topic},
        llm=LLM,
        available_tools=tools,
    )
    result_a = await node_a.execute(ctx_a)
    logger.info(
        "Researcher done: success=%s, tokens=%s",
        result_a.success,
        result_a.tokens_used,
    )
    await websocket.send(
        json.dumps(
            {
                "type": "node_result",
                "node_id": "researcher",
                "success": result_a.success,
                "output": result_a.output,
            }
        )
    )
    if not result_a.success:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": f"Researcher failed: {result_a.error}",
                }
            )
        )
        return

    # ---- Phase 2: Context Handoff -------------------------------------------
    await websocket.send(json.dumps({"type": "phase", "phase": "handoff"}))
    # Restore the researcher's conversation from store
    conversation_a = await NodeConversation.restore(store_a)
    if conversation_a is None:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "Failed to restore researcher conversation",
                }
            )
        )
        return
    handoff_engine = ContextHandoff(llm=LLM)
    handoff_context = handoff_engine.summarize_conversation(
        conversation=conversation_a,
        node_id="researcher",
        output_keys=["research_summary"],
    )
    formatted_handoff = ContextHandoff.format_as_input(handoff_context)
    logger.info(
        "Handoff: %d turns, ~%d tokens, keys=%s",
        handoff_context.turn_count,
        handoff_context.total_tokens_used,
        list(handoff_context.key_outputs.keys()),
    )
    # Send handoff context to browser (summary truncated for display).
    await websocket.send(
        json.dumps(
            {
                "type": "handoff_context",
                "summary": handoff_context.summary[:500],
                "turn_count": handoff_context.turn_count,
                "tokens": handoff_context.total_tokens_used,
                "key_outputs": handoff_context.key_outputs,
            }
        )
    )

    # ---- Phase 3: Analyst ---------------------------------------------------
    await websocket.send(json.dumps({"type": "phase", "phase": "analyst"}))
    node_b = EventLoopNode(
        event_bus=bus,
        judge=None,  # implicit judge
        config=LoopConfig(
            max_iterations=10,
            max_tool_calls_per_turn=5,
            max_history_tokens=32_000,
        ),
        conversation_store=store_b,
    )
    ctx_b = NodeContext(
        runtime=RUNTIME,
        node_id="analyst",
        node_spec=ANALYST_SPEC,
        memory=SharedMemory(),
        input_data={"context": formatted_handoff},
        llm=LLM,
        available_tools=[],  # analyst works purely from the handoff context
    )
    result_b = await node_b.execute(ctx_b)
    logger.info(
        "Analyst done: success=%s, tokens=%s",
        result_b.success,
        result_b.tokens_used,
    )

    # ---- Done ---------------------------------------------------------------
    await websocket.send(
        json.dumps(
            {
                "type": "done",
                "researcher": result_a.output,
                "analyst": result_b.output,
                "total_tokens": ((result_a.tokens_used or 0) + (result_b.tokens_used or 0)),
            }
        )
    )
# -------------------------------------------------------------------------
# HTTP handler
# -------------------------------------------------------------------------
async def process_request(connection, request: Request):
    """Serve the single-page UI over HTTP; let /ws upgrade to WebSocket."""
    # Returning None tells the websockets server to proceed with the
    # WebSocket handshake instead of answering over plain HTTP.
    if request.path == "/ws":
        return None
    headers = websockets.Headers({"Content-Type": "text/html; charset=utf-8"})
    body = HTML_PAGE.encode()
    return Response(HTTPStatus.OK, "OK", headers, body)
# -------------------------------------------------------------------------
# Main
# -------------------------------------------------------------------------
async def main():
    """Start the combined HTTP + WebSocket demo server and serve forever."""
    listen_port = 8766
    server = websockets.serve(
        handle_ws,
        "0.0.0.0",
        listen_port,
        process_request=process_request,
    )
    async with server:
        logger.info(f"Handoff demo at http://localhost:{listen_port}")
        logger.info("Enter a research topic to start the pipeline.")
        # Park the coroutine indefinitely; the server runs until the
        # process is stopped.
        await asyncio.Future()


if __name__ == "__main__":
    asyncio.run(main())
+88 -6
View File
@@ -64,6 +64,8 @@ class AdenCachedStorage(CredentialStorage):
- **Reads**: Try local cache first, fallback to Aden if stale/missing
- **Writes**: Always write to local cache
- **Offline resilience**: Uses cached credentials when Aden is unreachable
- **Provider-based lookup**: Match credentials by provider name (e.g., "hubspot")
when direct ID lookup fails, since Aden uses hash-based IDs internally.
The cache TTL determines how long to trust local credentials before
checking with the Aden server for updates. This balances:
@@ -85,6 +87,7 @@ class AdenCachedStorage(CredentialStorage):
# First access fetches from Aden
# Subsequent accesses use cache until TTL expires
# Can look up by provider name OR credential ID
token = store.get_key("hubspot", "access_token")
"""
@@ -111,21 +114,24 @@ class AdenCachedStorage(CredentialStorage):
self._cache_ttl = timedelta(seconds=cache_ttl_seconds)
self._prefer_local = prefer_local
self._cache_timestamps: dict[str, datetime] = {}
# Index: provider name (e.g., "hubspot") -> credential hash ID
self._provider_index: dict[str, str] = {}
def save(self, credential: CredentialObject) -> None:
"""
Save credential to local cache.
Save credential to local cache and update provider index.
Args:
credential: The credential to save.
"""
self._local.save(credential)
self._cache_timestamps[credential.id] = datetime.now(UTC)
self._index_provider(credential)
logger.debug(f"Cached credential '{credential.id}'")
def load(self, credential_id: str) -> CredentialObject | None:
"""
Load credential from cache, with Aden fallback.
Load credential from cache, with Aden fallback and provider-based lookup.
The loading strategy depends on the `prefer_local` setting:
@@ -141,8 +147,37 @@ class AdenCachedStorage(CredentialStorage):
2. Update local cache with response
3. Fall back to local cache only if Aden fails
Provider-based lookup:
When a provider index mapping exists for the credential_id (e.g.,
"hubspot" hash ID), the Aden-synced credential is loaded first.
This ensures fresh OAuth tokens from Aden take priority over stale
local credentials (env vars, old encrypted files).
Args:
credential_id: The credential identifier.
credential_id: The credential identifier or provider name.
Returns:
CredentialObject if found, None otherwise.
"""
# Check provider index first — Aden-synced credentials take priority
resolved_id = self._provider_index.get(credential_id)
if resolved_id and resolved_id != credential_id:
result = self._load_by_id(resolved_id)
if result is not None:
logger.info(
f"Loaded credential '{credential_id}' via provider index (id='{resolved_id}')"
)
return result
# Direct lookup (exact credential_id match)
return self._load_by_id(credential_id)
def _load_by_id(self, credential_id: str) -> CredentialObject | None:
"""
Load credential by exact ID from cache, with Aden fallback.
Args:
credential_id: The exact credential identifier.
Returns:
CredentialObject if found, None otherwise.
@@ -200,15 +235,21 @@ class AdenCachedStorage(CredentialStorage):
def exists(self, credential_id: str) -> bool:
"""
Check if credential exists in local cache.
Check if credential exists in local cache (by ID or provider name).
Args:
credential_id: The credential identifier.
credential_id: The credential identifier or provider name.
Returns:
True if credential exists locally.
"""
return self._local.exists(credential_id)
if self._local.exists(credential_id):
return True
# Check provider index
resolved_id = self._provider_index.get(credential_id)
if resolved_id and resolved_id != credential_id:
return self._local.exists(resolved_id)
return False
def _is_cache_fresh(self, credential_id: str) -> bool:
"""
@@ -242,6 +283,47 @@ class AdenCachedStorage(CredentialStorage):
self._cache_timestamps.clear()
logger.debug("Invalidated all cache entries")
def _index_provider(self, credential: CredentialObject) -> None:
"""
Index a credential by its provider/integration type.
Aden credentials carry an ``_integration_type`` key whose value is
the provider name (e.g., ``hubspot``). This method maps that
provider name to the credential's hash ID so that subsequent
``load("hubspot")`` calls resolve to the correct credential.
Args:
credential: The credential to index.
"""
integration_type_key = credential.keys.get("_integration_type")
if integration_type_key is None:
return
provider_name = integration_type_key.value.get_secret_value()
if provider_name:
self._provider_index[provider_name] = credential.id
logger.debug(f"Indexed provider '{provider_name}' -> '{credential.id}'")
def rebuild_provider_index(self) -> int:
"""
Rebuild the provider index from all locally cached credentials.
Useful after loading from disk when the in-memory index is empty.
Returns:
Number of provider mappings indexed.
"""
self._provider_index.clear()
indexed = 0
for cred_id in self._local.list_all():
cred = self._local.load(cred_id)
if cred:
before = len(self._provider_index)
self._index_provider(cred)
if len(self._provider_index) > before:
indexed += 1
logger.debug(f"Rebuilt provider index with {indexed} mappings")
return indexed
def sync_all_from_aden(self) -> int:
"""
Sync all credentials from Aden server to local cache.
@@ -589,6 +589,149 @@ class TestAdenCachedStorage:
assert info["stale"]["is_fresh"] is False
assert info["stale"]["ttl_remaining_seconds"] == 0
def test_save_indexes_provider(self, cached_storage):
"""Test save builds the provider index from _integration_type key."""
cred = CredentialObject(
id="aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1",
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr("token-value"),
),
"_integration_type": CredentialKey(
name="_integration_type",
value=SecretStr("hubspot"),
),
},
)
cached_storage.save(cred)
assert cached_storage._provider_index["hubspot"] == "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
def test_load_by_provider_name(self, cached_storage):
"""Test load resolves provider name to hash-based credential ID."""
hash_id = "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
cred = CredentialObject(
id=hash_id,
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr("hubspot-token"),
),
"_integration_type": CredentialKey(
name="_integration_type",
value=SecretStr("hubspot"),
),
},
)
# Save builds the index
cached_storage.save(cred)
# Load by provider name should resolve to the hash ID
loaded = cached_storage.load("hubspot")
assert loaded is not None
assert loaded.id == hash_id
assert loaded.keys["access_token"].value.get_secret_value() == "hubspot-token"
def test_load_by_direct_id_still_works(self, cached_storage):
"""Test load by direct hash ID still works as before."""
hash_id = "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
cred = CredentialObject(
id=hash_id,
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr("token"),
),
"_integration_type": CredentialKey(
name="_integration_type",
value=SecretStr("hubspot"),
),
},
)
cached_storage.save(cred)
# Direct ID lookup should still work
loaded = cached_storage.load(hash_id)
assert loaded is not None
assert loaded.id == hash_id
def test_exists_by_provider_name(self, cached_storage):
"""Test exists resolves provider name to hash-based credential ID."""
hash_id = "c2xhY2s6dGVzdDo5OTk="
cred = CredentialObject(
id=hash_id,
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr("slack-token"),
),
"_integration_type": CredentialKey(
name="_integration_type",
value=SecretStr("slack"),
),
},
)
cached_storage.save(cred)
assert cached_storage.exists("slack") is True
assert cached_storage.exists(hash_id) is True
assert cached_storage.exists("nonexistent") is False
def test_rebuild_provider_index(self, cached_storage, local_storage):
"""Test rebuild_provider_index reconstructs from local storage."""
# Manually save credentials to local storage (bypassing cached_storage.save)
for provider_name, hash_id in [("hubspot", "hash_hub"), ("slack", "hash_slack")]:
cred = CredentialObject(
id=hash_id,
credential_type=CredentialType.OAUTH2,
keys={
"_integration_type": CredentialKey(
name="_integration_type",
value=SecretStr(provider_name),
),
},
)
local_storage.save(cred)
# Index should be empty (we bypassed save)
assert len(cached_storage._provider_index) == 0
# Rebuild
indexed = cached_storage.rebuild_provider_index()
assert indexed == 2
assert cached_storage._provider_index["hubspot"] == "hash_hub"
assert cached_storage._provider_index["slack"] == "hash_slack"
def test_save_without_integration_type_no_index(self, cached_storage):
"""Test save does not index credentials without _integration_type key."""
cred = CredentialObject(
id="plain-cred",
credential_type=CredentialType.API_KEY,
keys={
"api_key": CredentialKey(
name="api_key",
value=SecretStr("key-value"),
),
},
)
cached_storage.save(cred)
assert "plain-cred" not in cached_storage._provider_index
assert len(cached_storage._provider_index) == 0
# =============================================================================
# Integration Tests
+28
View File
@@ -1,8 +1,22 @@
"""Graph structures: Goals, Nodes, Edges, and Flexible Execution."""
from framework.graph.client_io import (
ActiveNodeClientIO,
ClientIOGateway,
InertNodeClientIO,
NodeClientIO,
)
from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec
from framework.graph.context_handoff import ContextHandoff, HandoffContext
from framework.graph.conversation import ConversationStore, Message, NodeConversation
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.event_loop_node import (
EventLoopNode,
JudgeProtocol,
JudgeVerdict,
LoopConfig,
OutputAccumulator,
)
from framework.graph.executor import GraphExecutor
from framework.graph.flexible_executor import ExecutorConfig, FlexibleGraphExecutor
from framework.graph.goal import Constraint, Goal, GoalStatus, SuccessCriterion
@@ -77,4 +91,18 @@ __all__ = [
"NodeConversation",
"ConversationStore",
"Message",
# Event Loop
"EventLoopNode",
"LoopConfig",
"OutputAccumulator",
"JudgeProtocol",
"JudgeVerdict",
# Context Handoff
"ContextHandoff",
"HandoffContext",
# Client I/O
"NodeClientIO",
"ActiveNodeClientIO",
"InertNodeClientIO",
"ClientIOGateway",
]
+170
View File
@@ -0,0 +1,170 @@
"""
Client I/O gateway for graph nodes.
Provides the bridge between node code and external clients:
- ActiveNodeClientIO: for client_facing=True nodes (streams output, accepts input)
- InertNodeClientIO: for client_facing=False nodes (logs internally, redirects input)
- ClientIOGateway: factory that creates the right variant per node
"""
from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from framework.runtime.event_bus import EventBus
logger = logging.getLogger(__name__)
class NodeClientIO(ABC):
"""Abstract base for node client I/O."""
@abstractmethod
async def emit_output(self, content: str, is_final: bool = False) -> None:
"""Emit output content. If is_final=True, signal end of stream."""
@abstractmethod
async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
"""Request input. Behavior depends on whether the node is client-facing."""
class ActiveNodeClientIO(NodeClientIO):
"""
Client I/O for client_facing=True nodes.
- emit_output() queues content and publishes CLIENT_OUTPUT_DELTA.
- request_input() publishes CLIENT_INPUT_REQUESTED, then awaits provide_input().
- output_stream() yields queued content until the final sentinel.
"""
def __init__(
self,
node_id: str,
event_bus: EventBus | None = None,
) -> None:
self.node_id = node_id
self._event_bus = event_bus
self._output_queue: asyncio.Queue[str | None] = asyncio.Queue()
self._output_snapshot = ""
self._input_event: asyncio.Event | None = None
self._input_result: str | None = None
async def emit_output(self, content: str, is_final: bool = False) -> None:
self._output_snapshot += content
await self._output_queue.put(content)
if self._event_bus is not None:
await self._event_bus.emit_client_output_delta(
stream_id=self.node_id,
node_id=self.node_id,
content=content,
snapshot=self._output_snapshot,
)
if is_final:
await self._output_queue.put(None)
async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
if self._input_event is not None:
raise RuntimeError("request_input already pending for this node")
self._input_event = asyncio.Event()
self._input_result = None
if self._event_bus is not None:
await self._event_bus.emit_client_input_requested(
stream_id=self.node_id,
node_id=self.node_id,
prompt=prompt,
)
try:
if timeout is not None:
await asyncio.wait_for(self._input_event.wait(), timeout=timeout)
else:
await self._input_event.wait()
finally:
self._input_event = None
if self._input_result is None:
raise RuntimeError("input event was set but no input was provided")
result = self._input_result
self._input_result = None
return result
async def provide_input(self, content: str) -> None:
"""Called externally to fulfill a pending request_input()."""
if self._input_event is None:
raise RuntimeError("no pending request_input to fulfill")
self._input_result = content
self._input_event.set()
async def output_stream(self) -> AsyncIterator[str]:
"""Async iterator that yields output chunks until the final sentinel."""
while True:
chunk = await self._output_queue.get()
if chunk is None:
break
yield chunk
class InertNodeClientIO(NodeClientIO):
"""
Client I/O for client_facing=False nodes.
- emit_output() publishes NODE_INTERNAL_OUTPUT (content is not discarded).
- request_input() publishes NODE_INPUT_BLOCKED and returns a redirect string.
"""
def __init__(
self,
node_id: str,
event_bus: EventBus | None = None,
) -> None:
self.node_id = node_id
self._event_bus = event_bus
async def emit_output(self, content: str, is_final: bool = False) -> None:
if self._event_bus is not None:
await self._event_bus.emit_node_internal_output(
stream_id=self.node_id,
node_id=self.node_id,
content=content,
)
async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
if self._event_bus is not None:
await self._event_bus.emit_node_input_blocked(
stream_id=self.node_id,
node_id=self.node_id,
prompt=prompt,
)
return (
"You are an internal processing node. There is no user to interact with."
" Work with the data provided in your inputs to complete your task."
)
class ClientIOGateway:
    """Factory that creates the appropriate NodeClientIO for a node."""

    def __init__(self, event_bus: EventBus | None = None) -> None:
        self._event_bus = event_bus

    def create_io(self, node_id: str, client_facing: bool) -> NodeClientIO:
        # Client-facing nodes get the interactive implementation; all
        # others get the inert one that redirects input requests.
        io_cls = ActiveNodeClientIO if client_facing else InertNodeClientIO
        return io_cls(node_id=node_id, event_bus=self._event_bus)
+191
View File
@@ -0,0 +1,191 @@
"""Context handoff: summarize a completed NodeConversation for the next graph node."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from framework.graph.conversation import _try_extract_key
if TYPE_CHECKING:
from framework.graph.conversation import NodeConversation
from framework.llm.provider import LLMProvider
logger = logging.getLogger(__name__)
_TRUNCATE_CHARS = 500
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
@dataclass
class HandoffContext:
    """Structured summary of a completed node conversation."""

    # ID of the graph node whose conversation was summarized.
    source_node_id: str
    # Abstractive (LLM) or extractive text summary.
    summary: str
    # Values recovered for the node's requested output keys.
    key_outputs: dict[str, Any]
    # Number of turns in the source conversation.
    turn_count: int
    # Token estimate for the source conversation.
    total_tokens_used: int
# ---------------------------------------------------------------------------
# ContextHandoff
# ---------------------------------------------------------------------------
class ContextHandoff:
    """Summarize a completed NodeConversation into a HandoffContext.

    Parameters
    ----------
    llm : LLMProvider | None
        Optional LLM provider for abstractive summarization.
        When *None*, all summarization uses the extractive fallback.
    """

    def __init__(self, llm: LLMProvider | None = None) -> None:
        self.llm = llm

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def summarize_conversation(
        self,
        conversation: NodeConversation,
        node_id: str,
        output_keys: list[str] | None = None,
    ) -> HandoffContext:
        """Produce a HandoffContext from *conversation*.

        1. Extracts turn_count & total_tokens_used (sync properties).
        2. Extracts key_outputs by scanning assistant messages most-recent-first.
        3. Builds a summary via the LLM (if available) or extractive fallback.
        """
        turn_count = conversation.turn_count
        total_tokens_used = conversation.estimate_tokens()
        messages = conversation.messages  # defensive copy
        # --- key outputs ---------------------------------------------------
        # Newest-first scan so the most recent value for each key wins.
        # Fix: stop scanning as soon as every requested key has been found
        # (the previous `continue`-based loop walked the entire history).
        key_outputs: dict[str, Any] = {}
        if output_keys:
            remaining = set(output_keys)
            for msg in reversed(messages):
                if not remaining:
                    break
                if msg.role != "assistant":
                    continue
                for key in list(remaining):
                    value = _try_extract_key(msg.content, key)
                    if value is not None:
                        key_outputs[key] = value
                        remaining.discard(key)
        # --- summary -------------------------------------------------------
        if self.llm is not None:
            try:
                summary = self._llm_summary(messages, output_keys or [])
            except Exception:
                # Best-effort: any LLM failure degrades to the extractive path.
                logger.warning(
                    "LLM summarization failed; falling back to extractive.",
                    exc_info=True,
                )
                summary = self._extractive_summary(messages)
        else:
            summary = self._extractive_summary(messages)
        return HandoffContext(
            source_node_id=node_id,
            summary=summary,
            key_outputs=key_outputs,
            turn_count=turn_count,
            total_tokens_used=total_tokens_used,
        )

    @staticmethod
    def format_as_input(handoff: HandoffContext) -> str:
        """Render *handoff* as structured plain text for the next node's input."""
        header = (
            f"--- CONTEXT FROM: {handoff.source_node_id} "
            f"({handoff.turn_count} turns, ~{handoff.total_tokens_used} tokens) ---"
        )
        sections: list[str] = [header, ""]
        if handoff.key_outputs:
            sections.append("KEY OUTPUTS:")
            for k, v in handoff.key_outputs.items():
                sections.append(f"- {k}: {v}")
            sections.append("")
        summary_text = handoff.summary or "No summary available."
        sections.append("SUMMARY:")
        sections.append(summary_text)
        sections.append("")
        sections.append("--- END CONTEXT ---")
        return "\n".join(sections)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extractive_summary(messages: list) -> str:
        """Build a summary from key assistant messages without an LLM.

        Strategy:
        - Include the first assistant message (initial assessment).
        - Include the last assistant message (final conclusion).
        - Truncate each to ~500 chars.
        """
        if not messages:
            return "Empty conversation."
        assistant_msgs = [m for m in messages if m.role == "assistant"]
        if not assistant_msgs:
            return "No assistant responses."
        parts: list[str] = []
        first = assistant_msgs[0].content
        parts.append(first[:_TRUNCATE_CHARS])
        if len(assistant_msgs) > 1:
            last = assistant_msgs[-1].content
            parts.append(last[:_TRUNCATE_CHARS])
        return "\n\n".join(parts)

    def _llm_summary(self, messages: list, output_keys: list[str]) -> str:
        """Produce a summary by calling the LLM provider.

        Raises ValueError when called without a provider (programming error).
        """
        if self.llm is None:
            raise ValueError("_llm_summary called without an LLM provider")
        conversation_text = "\n".join(f"[{m.role}]: {m.content}" for m in messages)
        key_hint = ""
        if output_keys:
            key_hint = (
                "\nThe following output keys are especially important: "
                + ", ".join(output_keys)
                + ".\n"
            )
        system_prompt = (
            "You are a concise summarizer. Given the conversation below, "
            "produce a brief summary (at most ~500 tokens) that captures the "
            "key decisions, findings, and outcomes. Focus on what was concluded "
            "rather than the back-and-forth process." + key_hint
        )
        response = self.llm.complete(
            messages=[{"role": "user", "content": conversation_text}],
            system=system_prompt,
            max_tokens=500,
        )
        return response.content.strip()
+115 -36
View File
@@ -108,6 +108,50 @@ class ConversationStore(Protocol):
# ---------------------------------------------------------------------------
def _try_extract_key(content: str, key: str) -> str | None:
    """Try 4 strategies to extract a *key*'s value from message content.

    Strategies (in order):
    1. Whole message is JSON ``json.loads``, check for key.
    2. Embedded JSON via ``find_json_object`` helper.
    3. Colon format: ``key: value``.
    4. Equals format: ``key = value``.
    """
    from framework.graph.node import find_json_object

    def _from_parsed(candidate):
        # Normalize a JSON hit: strings pass through untouched, any
        # other value is re-serialized so callers always get a string.
        if isinstance(candidate, dict) and key in candidate:
            val = candidate[key]
            return val if isinstance(val, str) else json.dumps(val)
        return None

    # 1. The entire message parses as a JSON object.
    try:
        hit = _from_parsed(json.loads(content))
    except (json.JSONDecodeError, TypeError):
        hit = None
    if hit is not None:
        return hit
    # 2. A JSON object embedded somewhere inside the message.
    embedded = find_json_object(content)
    if embedded:
        try:
            hit = _from_parsed(json.loads(embedded))
        except (json.JSONDecodeError, TypeError):
            hit = None
        if hit is not None:
            return hit
    # 3 & 4. Plain-text "key: value" then "key = value" formats.
    for sep in (":", "="):
        m = re.search(rf"\b{re.escape(key)}\s*{re.escape(sep)}\s*(.+)", content)
        if m:
            return m.group(1).strip()
    return None
class NodeConversation:
"""Message history for a graph node with optional write-through persistence.
@@ -133,6 +177,7 @@ class NodeConversation:
self._messages: list[Message] = []
self._next_seq: int = 0
self._meta_persisted: bool = False
self._last_api_input_tokens: int | None = None
# --- Properties --------------------------------------------------------
@@ -205,14 +250,78 @@ class NodeConversation:
# --- Query -------------------------------------------------------------
def to_llm_messages(self) -> list[dict[str, Any]]:
"""Return messages as OpenAI-format dicts (system prompt excluded)."""
return [m.to_llm_dict() for m in self._messages]
"""Return messages as OpenAI-format dicts (system prompt excluded).
Automatically repairs orphaned tool_use blocks (assistant messages
with tool_calls that lack corresponding tool-result messages). This
can happen when a loop is cancelled mid-tool-execution.
"""
msgs = [m.to_llm_dict() for m in self._messages]
return self._repair_orphaned_tool_calls(msgs)
@staticmethod
def _repair_orphaned_tool_calls(
msgs: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Ensure every tool_call has a matching tool-result message."""
repaired: list[dict[str, Any]] = []
for i, m in enumerate(msgs):
repaired.append(m)
tool_calls = m.get("tool_calls")
if m.get("role") != "assistant" or not tool_calls:
continue
# Collect IDs of tool results that follow this assistant message
answered: set[str] = set()
for j in range(i + 1, len(msgs)):
if msgs[j].get("role") == "tool":
tid = msgs[j].get("tool_call_id")
if tid:
answered.add(tid)
else:
break # stop at first non-tool message
# Patch any missing results
for tc in tool_calls:
tc_id = tc.get("id")
if tc_id and tc_id not in answered:
repaired.append(
{
"role": "tool",
"tool_call_id": tc_id,
"content": "ERROR: Tool execution was interrupted.",
}
)
return repaired
def estimate_tokens(self) -> int:
"""Rough token estimate: total characters / 4."""
"""Best available token estimate.
Uses actual API input token count when available (set via
:meth:`update_token_count`), otherwise falls back to the rough
``total_chars / 4`` heuristic.
"""
if self._last_api_input_tokens is not None:
return self._last_api_input_tokens
total_chars = sum(len(m.content) for m in self._messages)
return total_chars // 4
def update_token_count(self, actual_input_tokens: int) -> None:
"""Store actual API input token count for more accurate compaction.
Called by EventLoopNode after each LLM call with the ``input_tokens``
value from the API response. This value includes system prompt and
tool definitions, so it may be higher than a message-only estimate.
"""
self._last_api_input_tokens = actual_input_tokens
def usage_ratio(self) -> float:
"""Current token usage as a fraction of *max_history_tokens*.
Returns 0.0 when ``max_history_tokens`` is zero (unlimited).
"""
if self._max_history_tokens <= 0:
return 0.0
return self.estimate_tokens() / self._max_history_tokens
def needs_compaction(self) -> bool:
return self.estimate_tokens() >= self._max_history_tokens * self._compaction_threshold
@@ -244,39 +353,7 @@ class NodeConversation:
def _try_extract_key(self, content: str, key: str) -> str | None:
"""Try 4 strategies to extract a key's value from message content."""
from framework.graph.node import find_json_object
# 1. Whole message is JSON
try:
parsed = json.loads(content)
if isinstance(parsed, dict) and key in parsed:
val = parsed[key]
return json.dumps(val) if not isinstance(val, str) else val
except (json.JSONDecodeError, TypeError):
pass
# 2. Embedded JSON via find_json_object
json_str = find_json_object(content)
if json_str:
try:
parsed = json.loads(json_str)
if isinstance(parsed, dict) and key in parsed:
val = parsed[key]
return json.dumps(val) if not isinstance(val, str) else val
except (json.JSONDecodeError, TypeError):
pass
# 3. Colon format: key: value
match = re.search(rf"\b{re.escape(key)}\s*:\s*(.+)", content)
if match:
return match.group(1).strip()
# 4. Equals format: key = value
match = re.search(rf"\b{re.escape(key)}\s*=\s*(.+)", content)
if match:
return match.group(1).strip()
return None
return _try_extract_key(content, key)
# --- Lifecycle ---------------------------------------------------------
@@ -330,6 +407,7 @@ class NodeConversation:
await self._store.write_cursor({"next_seq": self._next_seq})
self._messages = [summary_msg] + recent_messages
self._last_api_input_tokens = None # reset; next LLM call will recalibrate
async def clear(self) -> None:
"""Remove all messages, keep system prompt, preserve ``_next_seq``."""
@@ -337,6 +415,7 @@ class NodeConversation:
await self._store.delete_parts_before(self._next_seq)
await self._store.write_cursor({"next_seq": self._next_seq})
self._messages.clear()
self._last_api_input_tokens = None
def export_summary(self) -> str:
"""Structured summary with [STATS], [CONFIG], [RECENT_MESSAGES] sections."""
+36 -1
View File
@@ -11,7 +11,6 @@ our edges can be created dynamically by a Builder agent based on the goal.
Edge Types:
- always: Always traverse after source completes
- on_success: Traverse only if source succeeds
- on_failure: Traverse only if source fails
- conditional: Traverse based on expression evaluation (SAFE SUBSET ONLY)
@@ -609,4 +608,40 @@ class GraphSpec(BaseModel):
continue
errors.append(f"Node '{node.id}' is unreachable from entry")
# Client-facing fan-out validation
fan_outs = self.detect_fan_out_nodes()
for source_id, targets in fan_outs.items():
client_facing_targets = [
t
for t in targets
if self.get_node(t) and getattr(self.get_node(t), "client_facing", False)
]
if len(client_facing_targets) > 1:
errors.append(
f"Fan-out from '{source_id}' has multiple client-facing nodes: "
f"{client_facing_targets}. Only one branch may be client-facing."
)
# Output key overlap on parallel event_loop nodes
for source_id, targets in fan_outs.items():
event_loop_targets = [
t
for t in targets
if self.get_node(t) and getattr(self.get_node(t), "node_type", "") == "event_loop"
]
if len(event_loop_targets) > 1:
seen_keys: dict[str, str] = {}
for node_id in event_loop_targets:
node = self.get_node(node_id)
for key in getattr(node, "output_keys", []):
if key in seen_keys:
errors.append(
f"Fan-out from '{source_id}': event_loop nodes "
f"'{seen_keys[key]}' and '{node_id}' both write to "
f"output_key '{key}'. Parallel event_loop nodes must "
f"have disjoint output_keys to prevent last-wins data loss."
)
else:
seen_keys[key] = node_id
return errors
+879
View File
@@ -0,0 +1,879 @@
"""EventLoopNode: Multi-turn LLM streaming loop with tool execution and judge evaluation.
Implements NodeProtocol and runs a streaming event loop:
1. Calls LLMProvider.stream() to get streaming events
2. Processes text deltas, tool calls, and finish events
3. Executes tools and feeds results back to the conversation
4. Uses judge evaluation (or implicit stop-reason) to decide loop termination
5. Publishes lifecycle events to EventBus
6. Persists conversation and outputs via write-through to ConversationStore
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol, runtime_checkable
from framework.graph.conversation import ConversationStore, NodeConversation
from framework.graph.node import NodeContext, NodeProtocol, NodeResult
from framework.llm.provider import Tool, ToolResult, ToolUse
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
ToolCallEvent,
)
from framework.runtime.event_bus import EventBus
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Judge protocol (simple 3-action interface for event loop evaluation)
# ---------------------------------------------------------------------------
@dataclass
class JudgeVerdict:
    """Outcome of a judge pass over the event loop's current state."""

    # What the loop should do next: accept the output, retry the turn
    # with feedback, or escalate out of the loop.
    action: Literal["ACCEPT", "RETRY", "ESCALATE"]
    # Optional guidance fed back into the conversation on RETRY/ESCALATE.
    feedback: str = ""
@runtime_checkable
class JudgeProtocol(Protocol):
    """Structural interface for event-loop judges.

    A judge inspects a snapshot of the loop's state and returns a
    verdict choosing between accepting the output, retrying with
    feedback, or escalating.
    """

    async def evaluate(self, context: dict[str, Any]) -> JudgeVerdict: ...
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class LoopConfig:
    """Tunable limits for the event loop."""

    # Hard cap on loop iterations before the node gives up.
    max_iterations: int = 50
    # Upper bound on tool invocations within one LLM turn.
    max_tool_calls_per_turn: int = 10
    # Judge cadence: evaluate after every N completed turns.
    judge_every_n_turns: int = 1
    # Consecutive identical responses that count as a stall.
    stall_detection_threshold: int = 3
    # Token budget for conversation history before compaction.
    max_history_tokens: int = 32_000
    # Key prefix for the conversation store — presumably for namespacing;
    # its consumer is not visible here.
    store_prefix: str = ""
# ---------------------------------------------------------------------------
# Output accumulator with write-through persistence
# ---------------------------------------------------------------------------
@dataclass
class OutputAccumulator:
    """Accumulates output key-value pairs with optional write-through persistence.

    Values live in memory; when *store* is configured, every ``set`` is
    also written through to the store's cursor data so a crashed run can
    be resumed via :meth:`restore`.
    """

    values: dict[str, Any] = field(default_factory=dict)
    store: ConversationStore | None = None

    async def set(self, key: str, value: Any) -> None:
        """Record *key* -> *value*, persisting immediately when a store exists."""
        self.values[key] = value
        if not self.store:
            return
        # Read-modify-write the cursor's "outputs" mapping.
        cursor = await self.store.read_cursor() or {}
        outputs = cursor.get("outputs", {})
        outputs[key] = value
        cursor["outputs"] = outputs
        await self.store.write_cursor(cursor)

    def get(self, key: str) -> Any | None:
        """Return the value for *key*, or None when unset."""
        return self.values.get(key)

    def to_dict(self) -> dict[str, Any]:
        """Shallow copy of all accumulated values."""
        return dict(self.values)

    def has_all_keys(self, required: list[str]) -> bool:
        """True when every *required* key is present and non-None."""
        return all(self.values.get(key) is not None for key in required)

    @classmethod
    async def restore(cls, store: ConversationStore) -> OutputAccumulator:
        """Rebuild an accumulator from the store's persisted cursor data."""
        cursor = await store.read_cursor()
        outputs = cursor.get("outputs", {}) if cursor else {}
        return cls(values=outputs, store=store)
# ---------------------------------------------------------------------------
# EventLoopNode
# ---------------------------------------------------------------------------
class EventLoopNode(NodeProtocol):
"""Multi-turn LLM streaming loop with tool execution and judge evaluation.
Lifecycle:
1. Try to restore from durable state (crash recovery)
2. If no prior state, init from NodeSpec.system_prompt + input_keys
3. Loop: drain injection queue -> stream LLM -> execute tools -> judge
(each add_* and set_output writes through to store immediately)
4. Publish events to EventBus at each stage
5. Write cursor after each iteration
6. Terminate when judge returns ACCEPT (or max iterations)
7. Build output dict from OutputAccumulator
Always returns NodeResult with retryable=False semantics. The executor
must NOT retry event loop nodes -- retry is handled internally by the
judge (RETRY action continues the loop). See WP-7 enforcement.
"""
    def __init__(
        self,
        event_bus: EventBus | None = None,
        judge: JudgeProtocol | None = None,
        config: LoopConfig | None = None,
        tool_executor: Callable[[ToolUse], ToolResult | Awaitable[ToolResult]] | None = None,
        conversation_store: ConversationStore | None = None,
    ) -> None:
        """Wire up optional collaborators.

        Args:
            event_bus: Target for lifecycle/streaming events; None disables publishing.
            judge: External judge; None falls back to the implicit judge in _evaluate().
            config: Loop limits; None uses LoopConfig defaults.
            tool_executor: Sync or async callable that runs a ToolUse; when None,
                every real tool call yields an error ToolResult (see _execute_tool).
            conversation_store: Write-through persistence for conversation + outputs.
        """
        self._event_bus = event_bus
        self._judge = judge
        self._config = config or LoopConfig()
        self._tool_executor = tool_executor
        self._conversation_store = conversation_store
        # Events injected via inject_event(); drained at the top of each iteration.
        self._injection_queue: asyncio.Queue[str] = asyncio.Queue()
def validate_input(self, ctx: NodeContext) -> list[str]:
"""Validate hard requirements only.
Event loop nodes are LLM-powered and can reason about flexible input,
so input_keys are treated as hints not strict requirements.
Only the LLM provider is a hard dependency.
"""
errors = []
if ctx.llm is None:
errors.append("LLM provider is required for EventLoopNode")
return errors
# -------------------------------------------------------------------
# Public API
# -------------------------------------------------------------------
    async def execute(self, ctx: NodeContext) -> NodeResult:
        """Run the event loop.

        Restores prior durable state when available, then iterates:
        drain injected events -> stream one LLM turn (with tools) ->
        stall check -> checkpoint -> judge. Terminates on judge ACCEPT
        (success), ESCALATE (failure), stall (failure), pause (success
        with partial outputs), or max_iterations (failure).

        Returns a NodeResult whose ``output`` is the accumulated
        set_output values and whose ``tokens_used`` is the input+output
        total across all turns.
        """
        start_time = time.time()
        total_input_tokens = 0
        total_output_tokens = 0
        stream_id = ctx.node_id
        node_id = ctx.node_id
        # 1. Guard: LLM required
        if ctx.llm is None:
            return NodeResult(success=False, error="LLM provider not available")
        # 2. Restore or create new conversation + accumulator
        # _restore is defined outside this view; presumably it returns
        # (None, None, 0) when there is no durable state — TODO confirm.
        conversation, accumulator, start_iteration = await self._restore(ctx)
        if conversation is None:
            conversation = NodeConversation(
                system_prompt=ctx.node_spec.system_prompt or "",
                max_history_tokens=self._config.max_history_tokens,
                output_keys=ctx.node_spec.output_keys or None,
                store=self._conversation_store,
            )
            accumulator = OutputAccumulator(store=self._conversation_store)
            start_iteration = 0
            # Add initial user message from input data
            initial_message = self._build_initial_message(ctx)
            if initial_message:
                await conversation.add_user_message(initial_message)
        # 3. Build tool list: node tools + synthetic set_output tool
        tools = list(ctx.available_tools)
        set_output_tool = self._build_set_output_tool(ctx.node_spec.output_keys)
        if set_output_tool:
            tools.append(set_output_tool)
        # 4. Publish loop started
        await self._publish_loop_started(stream_id, node_id)
        # 5. Stall detection state
        recent_responses: list[str] = []
        # 6. Main loop
        for iteration in range(start_iteration, self._config.max_iterations):
            # 6a. Check pause — a pause ends the run successfully with
            # whatever outputs have accumulated so far.
            if await self._check_pause(ctx, conversation, iteration):
                latency_ms = int((time.time() - start_time) * 1000)
                return NodeResult(
                    success=True,
                    output=accumulator.to_dict(),
                    tokens_used=total_input_tokens + total_output_tokens,
                    latency_ms=latency_ms,
                )
            # 6b. Drain injection queue
            await self._drain_injection_queue(conversation)
            # 6c. Publish iteration event
            await self._publish_iteration(stream_id, node_id, iteration)
            # 6d. Pre-turn compaction check (tiered)
            if conversation.needs_compaction():
                await self._compact_tiered(ctx, conversation)
            # 6e. Run single LLM turn
            assistant_text, tool_results_list, turn_tokens = await self._run_single_turn(
                ctx, conversation, tools, iteration, accumulator
            )
            total_input_tokens += turn_tokens.get("input", 0)
            total_output_tokens += turn_tokens.get("output", 0)
            # 6e'. Feed actual API token count back for accurate estimation
            turn_input = turn_tokens.get("input", 0)
            if turn_input > 0:
                conversation.update_token_count(turn_input)
            # 6e''. Post-turn compaction check (catches tool-result bloat)
            if conversation.needs_compaction():
                await self._compact_tiered(ctx, conversation)
            # 6f. Stall detection — keep a sliding window of the last
            # threshold+1 responses, oldest dropped first.
            recent_responses.append(assistant_text)
            if len(recent_responses) > self._config.stall_detection_threshold:
                recent_responses.pop(0)
            if self._is_stalled(recent_responses):
                await self._publish_stalled(stream_id, node_id)
                latency_ms = int((time.time() - start_time) * 1000)
                return NodeResult(
                    success=False,
                    error=(
                        f"Node stalled: {self._config.stall_detection_threshold} "
                        "consecutive identical responses"
                    ),
                    output=accumulator.to_dict(),
                    tokens_used=total_input_tokens + total_output_tokens,
                    latency_ms=latency_ms,
                )
            # 6g. Write cursor checkpoint
            await self._write_cursor(ctx, conversation, accumulator, iteration)
            # 6h. Judge evaluation
            # NOTE(review): _run_single_turn always returns [] for
            # tool_results_list (see that method), so the second operand
            # is always True — confirm the intended judge cadence.
            should_judge = (
                (iteration + 1) % self._config.judge_every_n_turns == 0
                or not tool_results_list  # no tool calls = natural stop
            )
            if should_judge:
                verdict = await self._evaluate(
                    ctx,
                    conversation,
                    accumulator,
                    assistant_text,
                    tool_results_list,
                    iteration,
                )
                if verdict.action == "ACCEPT":
                    # Check for missing output keys — an external judge may
                    # accept before all declared keys are set; nudge the LLM.
                    missing = self._get_missing_output_keys(accumulator, ctx.node_spec.output_keys)
                    if missing and self._judge is not None:
                        hint = (
                            f"Missing required output keys: {missing}. "
                            "Use set_output to provide them."
                        )
                        await conversation.add_user_message(hint)
                        continue
                    # Write outputs to shared memory
                    for key, value in accumulator.to_dict().items():
                        ctx.memory.write(key, value, validate=False)
                    await self._publish_loop_completed(stream_id, node_id, iteration + 1)
                    latency_ms = int((time.time() - start_time) * 1000)
                    return NodeResult(
                        success=True,
                        output=accumulator.to_dict(),
                        tokens_used=total_input_tokens + total_output_tokens,
                        latency_ms=latency_ms,
                    )
                elif verdict.action == "ESCALATE":
                    await self._publish_loop_completed(stream_id, node_id, iteration + 1)
                    latency_ms = int((time.time() - start_time) * 1000)
                    return NodeResult(
                        success=False,
                        error=f"Judge escalated: {verdict.feedback}",
                        output=accumulator.to_dict(),
                        tokens_used=total_input_tokens + total_output_tokens,
                        latency_ms=latency_ms,
                    )
                elif verdict.action == "RETRY":
                    if verdict.feedback:
                        await conversation.add_user_message(f"[Judge feedback]: {verdict.feedback}")
                    continue
        # 7. Max iterations exhausted
        await self._publish_loop_completed(stream_id, node_id, self._config.max_iterations)
        latency_ms = int((time.time() - start_time) * 1000)
        return NodeResult(
            success=False,
            error=(f"Max iterations ({self._config.max_iterations}) reached without acceptance"),
            output=accumulator.to_dict(),
            tokens_used=total_input_tokens + total_output_tokens,
            latency_ms=latency_ms,
        )
async def inject_event(self, content: str) -> None:
"""Inject an external event into the running loop.
The content becomes a user message prepended to the next iteration.
Thread-safe via asyncio.Queue.
"""
await self._injection_queue.put(content)
# -------------------------------------------------------------------
# Single LLM turn with caller-managed tool orchestration
# -------------------------------------------------------------------
    async def _run_single_turn(
        self,
        ctx: NodeContext,
        conversation: NodeConversation,
        tools: list[Tool],
        iteration: int,
        accumulator: OutputAccumulator,
    ) -> tuple[str, list[dict], dict[str, int]]:
        """Run a single LLM turn with streaming and tool execution.

        A "turn" spans as many stream/tool rounds as the model requests;
        it ends when a round produces no tool calls. Returns
        (assistant_text, tool_results, token_counts).

        NOTE(review): the only return path is the no-tool-calls round, so
        the returned tool_results list is always [] — per-round results
        collected below never reach the caller. Confirm intended.
        """
        stream_id = ctx.node_id
        node_id = ctx.node_id
        token_counts: dict[str, int] = {"input": 0, "output": 0}
        # Counts tool calls across ALL rounds of this turn, not per round.
        tool_call_count = 0
        final_text = ""
        # Inner tool loop: stream may produce tool calls requiring re-invocation
        while True:
            # Pre-send guard: if context is at or over budget, compact before
            # calling the LLM — prevents API context-length errors.
            if conversation.usage_ratio() >= 1.0:
                logger.warning(
                    "Pre-send guard: context at %.0f%% of budget, compacting",
                    conversation.usage_ratio() * 100,
                )
                await self._compact_tiered(ctx, conversation)
            messages = conversation.to_llm_messages()
            accumulated_text = ""
            tool_calls: list[ToolCallEvent] = []
            # Stream LLM response
            async for event in ctx.llm.stream(
                messages=messages,
                system=conversation.system_prompt,
                tools=tools if tools else None,
                max_tokens=ctx.max_tokens,
            ):
                if isinstance(event, TextDeltaEvent):
                    # snapshot is the full text so far; content is the delta.
                    accumulated_text = event.snapshot
                    await self._publish_text_delta(
                        stream_id, node_id, event.content, event.snapshot, ctx
                    )
                elif isinstance(event, ToolCallEvent):
                    tool_calls.append(event)
                elif isinstance(event, FinishEvent):
                    token_counts["input"] += event.input_tokens
                    token_counts["output"] += event.output_tokens
                elif isinstance(event, StreamErrorEvent):
                    if not event.recoverable:
                        raise RuntimeError(f"Stream error: {event.error}")
                    logger.warning(f"Recoverable stream error: {event.error}")
            final_text = accumulated_text
            # Record assistant message (write-through via conversation store);
            # tool calls are serialized in OpenAI "function" format.
            tc_dicts = None
            if tool_calls:
                tc_dicts = [
                    {
                        "id": tc.tool_use_id,
                        "type": "function",
                        "function": {
                            "name": tc.tool_name,
                            "arguments": json.dumps(tc.tool_input),
                        },
                    }
                    for tc in tool_calls
                ]
            await conversation.add_assistant_message(
                content=accumulated_text,
                tool_calls=tc_dicts,
            )
            # If no tool calls, turn is complete
            if not tool_calls:
                return final_text, [], token_counts
            # Execute tool calls
            tool_results: list[dict] = []
            for tc in tool_calls:
                tool_call_count += 1
                if tool_call_count > self._config.max_tool_calls_per_turn:
                    # Remaining calls get no results here; the orphaned
                    # tool_use entries are patched by NodeConversation's
                    # to_llm_messages() repair on the next round.
                    logger.warning(
                        f"Max tool calls per turn ({self._config.max_tool_calls_per_turn}) exceeded"
                    )
                    break
                # Publish tool call started
                await self._publish_tool_started(
                    stream_id, node_id, tc.tool_use_id, tc.tool_name, tc.tool_input
                )
                # Handle set_output synthetic tool
                if tc.tool_name == "set_output":
                    # Validate the key, then re-stamp the result with the
                    # real tool_use_id (the handler returns a blank one).
                    result = self._handle_set_output(tc.tool_input, ctx.node_spec.output_keys)
                    result = ToolResult(
                        tool_use_id=tc.tool_use_id,
                        content=result.content,
                        is_error=result.is_error,
                    )
                    # Async write-through for set_output
                    if not result.is_error:
                        await accumulator.set(tc.tool_input["key"], tc.tool_input["value"])
                else:
                    # Execute real tool
                    result = await self._execute_tool(tc)
                # Record tool result in conversation (write-through)
                await conversation.add_tool_result(
                    tool_use_id=tc.tool_use_id,
                    content=result.content,
                    is_error=result.is_error,
                )
                tool_results.append(
                    {
                        "tool_use_id": tc.tool_use_id,
                        "tool_name": tc.tool_name,
                        "content": result.content,
                        "is_error": result.is_error,
                    }
                )
                # Publish tool call completed
                await self._publish_tool_completed(
                    stream_id,
                    node_id,
                    tc.tool_use_id,
                    tc.tool_name,
                    result.content,
                    result.is_error,
                )
            # Tool calls processed -- loop back to stream with updated conversation
# -------------------------------------------------------------------
# set_output synthetic tool
# -------------------------------------------------------------------
def _build_set_output_tool(self, output_keys: list[str] | None) -> Tool | None:
"""Build the synthetic set_output tool for explicit output declaration."""
if not output_keys:
return None
return Tool(
name="set_output",
description=(
"Set an output value for this node. Call once per output key. "
f"Valid keys: {output_keys}"
),
parameters={
"type": "object",
"properties": {
"key": {
"type": "string",
"description": f"Output key. Must be one of: {output_keys}",
"enum": output_keys,
},
"value": {
"type": "string",
"description": "The output value to store.",
},
},
"required": ["key", "value"],
},
)
def _handle_set_output(
self,
tool_input: dict[str, Any],
output_keys: list[str] | None,
) -> ToolResult:
"""Handle set_output tool call. Returns ToolResult (sync)."""
key = tool_input.get("key", "")
valid_keys = output_keys or []
if key not in valid_keys:
return ToolResult(
tool_use_id="",
content=f"Invalid output key '{key}'. Valid keys: {valid_keys}",
is_error=True,
)
return ToolResult(
tool_use_id="",
content=f"Output '{key}' set successfully.",
is_error=False,
)
# -------------------------------------------------------------------
# Judge evaluation
# -------------------------------------------------------------------
async def _evaluate(
self,
ctx: NodeContext,
conversation: NodeConversation,
accumulator: OutputAccumulator,
assistant_text: str,
tool_results: list[dict],
iteration: int,
) -> JudgeVerdict:
"""Evaluate the current state using judge or implicit logic."""
if self._judge is not None:
context = {
"assistant_text": assistant_text,
"tool_calls": tool_results,
"output_accumulator": accumulator.to_dict(),
"iteration": iteration,
"conversation_summary": conversation.export_summary(),
"output_keys": ctx.node_spec.output_keys,
"missing_keys": self._get_missing_output_keys(
accumulator, ctx.node_spec.output_keys
),
}
return await self._judge.evaluate(context)
# Implicit judge: accept when no tool calls and all output keys present
if not tool_results:
missing = self._get_missing_output_keys(accumulator, ctx.node_spec.output_keys)
if not missing:
return JudgeVerdict(action="ACCEPT")
else:
return JudgeVerdict(
action="RETRY",
feedback=(
f"Missing output keys: {missing}. Use set_output tool to provide them."
),
)
# Tool calls were made -- continue loop
return JudgeVerdict(action="RETRY", feedback="")
# -------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------
def _build_initial_message(self, ctx: NodeContext) -> str:
"""Build the initial user message from input data and memory.
Includes ALL input_data (not just declared input_keys) so that
upstream handoff data flows through regardless of key naming.
Declared input_keys are also checked in shared memory as fallback.
"""
parts = []
seen: set[str] = set()
# Include everything from input_data (flexible handoff)
for key, value in ctx.input_data.items():
if value is not None:
parts.append(f"{key}: {value}")
seen.add(key)
# Fallback: check memory for declared input_keys not already covered
for key in ctx.node_spec.input_keys:
if key not in seen:
value = ctx.memory.read(key)
if value is not None:
parts.append(f"{key}: {value}")
if ctx.goal_context:
parts.append(f"\nGoal: {ctx.goal_context}")
return "\n".join(parts) if parts else "Begin."
def _get_missing_output_keys(
self,
accumulator: OutputAccumulator,
output_keys: list[str] | None,
) -> list[str]:
"""Return output keys that have not been set yet."""
if not output_keys:
return []
return [k for k in output_keys if accumulator.get(k) is None]
def _is_stalled(self, recent_responses: list[str]) -> bool:
"""Detect stall: N consecutive identical non-empty responses."""
if len(recent_responses) < self._config.stall_detection_threshold:
return False
if not recent_responses[0]:
return False
return all(r == recent_responses[0] for r in recent_responses)
async def _execute_tool(self, tc: ToolCallEvent) -> ToolResult:
    """Run a single tool call, awaiting the result if the executor is async."""
    if self._tool_executor is None:
        # No executor wired up -- surface an error result instead of raising.
        return ToolResult(
            tool_use_id=tc.tool_use_id,
            content=f"No tool executor configured for '{tc.tool_name}'",
            is_error=True,
        )
    request = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
    outcome = self._tool_executor(request)
    # The executor may be sync (returns the result directly) or async
    # (returns a coroutine/future that must be awaited).
    if asyncio.iscoroutine(outcome) or asyncio.isfuture(outcome):
        outcome = await outcome
    return outcome
async def _compact_tiered(
self,
ctx: NodeContext,
conversation: NodeConversation,
) -> None:
"""Run compaction with aggressiveness scaled to usage level.
| Usage | Strategy |
|----------------|---------------------------------------------|
| 80-100% | Normal: LLM summary, keep 4 recent messages |
| 100-120% | Aggressive: LLM summary, keep 2 recent |
| >= 120% | Emergency: static summary, keep 1 recent |
"""
ratio = conversation.usage_ratio()
if ratio >= 1.2:
# Emergency -- don't risk another LLM call on a bloated context
logger.warning("Emergency compaction triggered (usage %.0f%%)", ratio * 100)
await conversation.compact(
"Previous conversation context (emergency compaction).",
keep_recent=1,
)
elif ratio >= 1.0:
logger.info("Aggressive compaction triggered (usage %.0f%%)", ratio * 100)
summary = await self._generate_compaction_summary(ctx, conversation)
await conversation.compact(summary, keep_recent=2)
else:
summary = await self._generate_compaction_summary(ctx, conversation)
await conversation.compact(summary, keep_recent=4)
async def _generate_compaction_summary(
self,
ctx: NodeContext,
conversation: NodeConversation,
) -> str:
"""Use LLM to generate a conversation summary for compaction."""
messages_text = "\n".join(
f"[{m.role}]: {m.content[:200]}" for m in conversation.messages[-10:]
)
prompt = (
"Summarize this conversation so far in 2-3 sentences, "
"preserving key decisions and results:\n\n"
f"{messages_text}"
)
try:
response = ctx.llm.complete(
messages=[{"role": "user", "content": prompt}],
system="Summarize conversations concisely.",
max_tokens=300,
)
return response.content
except Exception as e:
logger.warning(f"Compaction summary generation failed: {e}")
return "Previous conversation context (summary unavailable)."
# -------------------------------------------------------------------
# Persistence: restore, cursor, injection, pause
# -------------------------------------------------------------------
async def _restore(
self,
ctx: NodeContext,
) -> tuple[NodeConversation | None, OutputAccumulator | None, int]:
"""Attempt to restore from a previous checkpoint."""
if self._conversation_store is None:
return None, None, 0
conversation = await NodeConversation.restore(self._conversation_store)
if conversation is None:
return None, None, 0
accumulator = await OutputAccumulator.restore(self._conversation_store)
cursor = await self._conversation_store.read_cursor()
start_iteration = cursor.get("iteration", 0) + 1 if cursor else 0
logger.info(
f"Restored event loop: iteration={start_iteration}, "
f"messages={conversation.message_count}, "
f"outputs={list(accumulator.values.keys())}"
)
return conversation, accumulator, start_iteration
async def _write_cursor(
self,
ctx: NodeContext,
conversation: NodeConversation,
accumulator: OutputAccumulator,
iteration: int,
) -> None:
"""Write checkpoint cursor for crash recovery."""
if self._conversation_store:
cursor = await self._conversation_store.read_cursor() or {}
cursor.update(
{
"iteration": iteration,
"node_id": ctx.node_id,
"next_seq": conversation.next_seq,
"outputs": accumulator.to_dict(),
}
)
await self._conversation_store.write_cursor(cursor)
async def _drain_injection_queue(self, conversation: NodeConversation) -> int:
"""Drain all pending injected events as user messages. Returns count."""
count = 0
while not self._injection_queue.empty():
try:
content = self._injection_queue.get_nowait()
await conversation.add_user_message(f"[External event]: {content}")
count += 1
except asyncio.QueueEmpty:
break
return count
async def _check_pause(
self,
ctx: NodeContext,
conversation: NodeConversation,
iteration: int,
) -> bool:
"""Check if pause has been requested. Returns True if paused."""
pause_requested = ctx.input_data.get("pause_requested", False)
if not pause_requested:
pause_requested = ctx.memory.read("pause_requested") or False
if pause_requested:
logger.info(f"Pause requested at iteration {iteration}")
return True
return False
# -------------------------------------------------------------------
# EventBus publishing helpers
# -------------------------------------------------------------------
async def _publish_loop_started(self, stream_id: str, node_id: str) -> None:
if self._event_bus:
await self._event_bus.emit_node_loop_started(
stream_id=stream_id,
node_id=node_id,
max_iterations=self._config.max_iterations,
)
async def _publish_iteration(self, stream_id: str, node_id: str, iteration: int) -> None:
if self._event_bus:
await self._event_bus.emit_node_loop_iteration(
stream_id=stream_id,
node_id=node_id,
iteration=iteration,
)
async def _publish_loop_completed(self, stream_id: str, node_id: str, iterations: int) -> None:
if self._event_bus:
await self._event_bus.emit_node_loop_completed(
stream_id=stream_id,
node_id=node_id,
iterations=iterations,
)
async def _publish_stalled(self, stream_id: str, node_id: str) -> None:
if self._event_bus:
await self._event_bus.emit_node_stalled(
stream_id=stream_id,
node_id=node_id,
reason="Consecutive identical responses detected",
)
async def _publish_text_delta(
self,
stream_id: str,
node_id: str,
content: str,
snapshot: str,
ctx: NodeContext,
) -> None:
if self._event_bus:
if ctx.node_spec.client_facing:
await self._event_bus.emit_client_output_delta(
stream_id=stream_id,
node_id=node_id,
content=content,
snapshot=snapshot,
)
else:
await self._event_bus.emit_llm_text_delta(
stream_id=stream_id,
node_id=node_id,
content=content,
snapshot=snapshot,
)
async def _publish_tool_started(
self,
stream_id: str,
node_id: str,
tool_use_id: str,
tool_name: str,
tool_input: dict,
) -> None:
if self._event_bus:
await self._event_bus.emit_tool_call_started(
stream_id=stream_id,
node_id=node_id,
tool_use_id=tool_use_id,
tool_name=tool_name,
tool_input=tool_input,
)
async def _publish_tool_completed(
self,
stream_id: str,
node_id: str,
tool_use_id: str,
tool_name: str,
result: str,
is_error: bool,
) -> None:
if self._event_bus:
await self._event_bus.emit_tool_call_completed(
stream_id=stream_id,
node_id=node_id,
tool_use_id=tool_use_id,
tool_name=tool_name,
result=result,
is_error=is_error,
)
+51 -4
View File
@@ -11,6 +11,7 @@ The executor:
import asyncio
import logging
import warnings
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
@@ -380,6 +381,15 @@ class GraphExecutor:
# [CORRECTED] Use node_spec.max_retries instead of hardcoded 3
max_retries = getattr(node_spec, "max_retries", 3)
# Event loop nodes handle retry internally via judge —
# executor retry is catastrophic (retry multiplication)
if node_spec.node_type == "event_loop" and max_retries > 0:
self.logger.warning(
f"EventLoopNode '{node_spec.id}' has max_retries={max_retries}. "
"Overriding to 0 — event loop nodes handle retry internally via judge."
)
max_retries = 0
if node_retry_counts[current_node_id] < max_retries:
# Retry - don't increment steps for retries
steps -= 1
@@ -658,7 +668,15 @@ class GraphExecutor:
)
# Valid node types - no ambiguous "llm" type allowed
VALID_NODE_TYPES = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
VALID_NODE_TYPES = {
"llm_tool_use",
"llm_generate",
"router",
"function",
"human_input",
"event_loop",
}
DEPRECATED_NODE_TYPES = {"llm_tool_use": "event_loop", "llm_generate": "event_loop"}
def _get_node_implementation(
self, node_spec: NodeSpec, cleanup_llm_model: str | None = None
@@ -676,6 +694,17 @@ class GraphExecutor:
f"Use 'llm_tool_use' for nodes that call tools, 'llm_generate' for text generation."
)
# Warn on deprecated node types
if node_spec.node_type in self.DEPRECATED_NODE_TYPES:
replacement = self.DEPRECATED_NODE_TYPES[node_spec.node_type]
warnings.warn(
f"Node type '{node_spec.node_type}' is deprecated. "
f"Use '{replacement}' instead. "
f"Node: '{node_spec.id}'",
DeprecationWarning,
stacklevel=2,
)
# Create based on type
if node_spec.node_type == "llm_tool_use":
if not node_spec.tools:
@@ -713,6 +742,13 @@ class GraphExecutor:
cleanup_llm_model=cleanup_llm_model,
)
if node_spec.node_type == "event_loop":
# Event loop nodes must be pre-registered (like function nodes)
raise RuntimeError(
f"EventLoopNode '{node_spec.id}' not found in registry. "
"Register it with executor.register_node() before execution."
)
# Should never reach here due to validation above
raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")
@@ -909,6 +945,17 @@ class GraphExecutor:
branch.status = "failed"
branch.error = f"Node {branch.node_id} not found in graph"
return branch, RuntimeError(branch.error)
effective_max_retries = node_spec.max_retries
if node_spec.node_type == "event_loop":
if effective_max_retries > 1:
self.logger.warning(
f"EventLoopNode '{node_spec.id}' has "
f"max_retries={effective_max_retries}. Overriding "
"to 1 — event loop nodes handle retry internally."
)
effective_max_retries = 1
branch.status = "running"
try:
@@ -942,7 +989,7 @@ class GraphExecutor:
# Execute with retries
last_result = None
for attempt in range(node_spec.max_retries):
for attempt in range(effective_max_retries):
branch.retry_count = attempt
# Build context for this branch
@@ -970,7 +1017,7 @@ class GraphExecutor:
self.logger.warning(
f" ↻ Branch {node_spec.name}: "
f"retry {attempt + 1}/{node_spec.max_retries}"
f"retry {attempt + 1}/{effective_max_retries}"
)
# All retries exhausted
@@ -979,7 +1026,7 @@ class GraphExecutor:
branch.result = last_result
self.logger.error(
f" ✗ Branch {node_spec.name}: "
f"failed after {node_spec.max_retries} attempts"
f"failed after {effective_max_retries} attempts"
)
return branch, last_result
+21 -2
View File
@@ -153,7 +153,10 @@ class NodeSpec(BaseModel):
# Node behavior type
node_type: str = Field(
default="llm_tool_use",
description="Type: 'llm_tool_use', 'llm_generate', 'function', 'router', 'human_input'",
description=(
"Type: 'event_loop', 'function', 'router', 'human_input'. "
"Deprecated: 'llm_tool_use', 'llm_generate' (use 'event_loop' instead)."
),
)
# Data flow
@@ -218,6 +221,12 @@ class NodeSpec(BaseModel):
description="Maximum retries when Pydantic validation fails (with feedback to LLM)",
)
# Client-facing behavior
client_facing: bool = Field(
default=False,
description="If True, this node streams output to the end user and can request input.",
)
model_config = {"extra": "allow", "arbitrary_types_allowed": True}
@@ -1348,7 +1357,9 @@ Expected output keys: {output_keys}
LLM Response:
{raw_response}
Output ONLY the JSON object, nothing else."""
Output ONLY the JSON object, nothing else.
If no valid JSON object exists in the response, output exactly: {{"error": "NO_JSON_FOUND"}}
Do NOT fabricate data or return empty objects."""
try:
result = cleaner_llm.complete(
@@ -1395,6 +1406,14 @@ Output ONLY the JSON object, nothing else."""
parsed = json.loads(cleaned)
except json.JSONDecodeError:
parsed = json.loads(_fix_unescaped_newlines_in_json(cleaned))
# Validate LLM didn't return empty or fabricated data
if parsed.get("error") == "NO_JSON_FOUND":
raise ValueError("Cannot parse JSON from response")
if not parsed or parsed == {}:
raise ValueError("Cannot parse JSON from response")
if all(v is None for v in parsed.values()):
raise ValueError("Cannot parse JSON from response")
logger.info(" ✓ LLM cleaned JSON output")
return parsed
+24 -1
View File
@@ -1,8 +1,31 @@
"""LLM provider abstraction."""
from framework.llm.provider import LLMProvider, LLMResponse
from framework.llm.stream_events import (
FinishEvent,
ReasoningDeltaEvent,
ReasoningStartEvent,
StreamErrorEvent,
StreamEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
ToolResultEvent,
)
__all__ = ["LLMProvider", "LLMResponse"]
__all__ = [
"LLMProvider",
"LLMResponse",
"StreamEvent",
"TextDeltaEvent",
"TextEndEvent",
"ToolCallEvent",
"ToolResultEvent",
"ReasoningStartEvent",
"ReasoningDeltaEvent",
"FinishEvent",
"StreamErrorEvent",
]
try:
from framework.llm.anthropic import AnthropicProvider # noqa: F401
+167 -1
View File
@@ -7,10 +7,11 @@ Groq, and local models.
See: https://docs.litellm.ai/docs/providers
"""
import asyncio
import json
import logging
import time
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from datetime import datetime
from pathlib import Path
from typing import Any
@@ -23,6 +24,7 @@ except ImportError:
RateLimitError = Exception # type: ignore[assignment, misc]
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import StreamEvent
logger = logging.getLogger(__name__)
@@ -425,3 +427,167 @@ class LiteLLMProvider(LLMProvider):
},
},
}
async def stream(
    self,
    messages: list[dict[str, Any]],
    system: str = "",
    tools: list[Tool] | None = None,
    max_tokens: int = 4096,
) -> AsyncIterator[StreamEvent]:
    """Stream a completion via litellm.acompletion(stream=True).

    Yields StreamEvent objects as chunks arrive from the provider.
    Tool call arguments are accumulated across chunks and yielded as
    a single ToolCallEvent with fully parsed JSON when complete.

    Empty responses (e.g. Gemini stealth rate-limits that return 200
    with no content) are retried with exponential backoff, mirroring
    the retry behaviour of ``_completion_with_rate_limit_retry``.
    """
    from framework.llm.stream_events import (
        FinishEvent,
        StreamErrorEvent,
        TextDeltaEvent,
        TextEndEvent,
        ToolCallEvent,
    )

    # System prompt is passed as a leading "system" role message.
    full_messages: list[dict[str, Any]] = []
    if system:
        full_messages.append({"role": "system", "content": system})
    full_messages.extend(messages)

    kwargs: dict[str, Any] = {
        "model": self.model,
        "messages": full_messages,
        "max_tokens": max_tokens,
        "stream": True,
        # Ask the provider to include token usage on the final chunk.
        "stream_options": {"include_usage": True},
        **self.extra_kwargs,
    }
    if self.api_key:
        kwargs["api_key"] = self.api_key
    if self.api_base:
        kwargs["api_base"] = self.api_base
    if tools:
        kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]

    for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
        # Events are buffered per attempt so an empty stream can be retried
        # without having emitted partial output to the caller.
        buffered_events: list[StreamEvent] = []
        accumulated_text = ""
        # keyed by tool-call index: {"id", "name", "arguments"} fragments
        tool_calls_acc: dict[int, dict[str, str]] = {}
        input_tokens = 0
        output_tokens = 0
        try:
            response = await litellm.acompletion(**kwargs)  # type: ignore[union-attr]
            async for chunk in response:
                choice = chunk.choices[0] if chunk.choices else None
                if not choice:
                    continue
                delta = choice.delta
                # --- Text content ---
                if delta and delta.content:
                    accumulated_text += delta.content
                    buffered_events.append(
                        TextDeltaEvent(
                            content=delta.content,
                            snapshot=accumulated_text,
                        )
                    )
                # --- Tool calls (accumulate across chunks) ---
                if delta and delta.tool_calls:
                    for tc in delta.tool_calls:
                        # Some providers omit index; treat missing as slot 0.
                        idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
                        if idx not in tool_calls_acc:
                            tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
                        if tc.id:
                            tool_calls_acc[idx]["id"] = tc.id
                        if tc.function:
                            if tc.function.name:
                                tool_calls_acc[idx]["name"] = tc.function.name
                            if tc.function.arguments:
                                # Argument JSON arrives in fragments; concatenate.
                                tool_calls_acc[idx]["arguments"] += tc.function.arguments
                # --- Finish ---
                if choice.finish_reason:
                    for _idx, tc_data in sorted(tool_calls_acc.items()):
                        try:
                            parsed_args = json.loads(tc_data["arguments"])
                        except (json.JSONDecodeError, KeyError):
                            # Unparseable argument JSON is preserved raw for the caller.
                            parsed_args = {"_raw": tc_data.get("arguments", "")}
                        buffered_events.append(
                            ToolCallEvent(
                                tool_use_id=tc_data["id"],
                                tool_name=tc_data["name"],
                                tool_input=parsed_args,
                            )
                        )
                    if accumulated_text:
                        buffered_events.append(TextEndEvent(full_text=accumulated_text))
                    usage = getattr(chunk, "usage", None)
                    if usage:
                        input_tokens = getattr(usage, "prompt_tokens", 0) or 0
                        output_tokens = getattr(usage, "completion_tokens", 0) or 0
                    buffered_events.append(
                        FinishEvent(
                            stop_reason=choice.finish_reason,
                            input_tokens=input_tokens,
                            output_tokens=output_tokens,
                            model=self.model,
                        )
                    )
            # Check whether the stream produced any real content.
            has_content = accumulated_text or tool_calls_acc
            if not has_content and attempt < RATE_LIMIT_MAX_RETRIES:
                wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                token_count, token_method = _estimate_tokens(
                    self.model,
                    full_messages,
                )
                dump_path = _dump_failed_request(
                    model=self.model,
                    kwargs=kwargs,
                    error_type="empty_stream",
                    attempt=attempt,
                )
                logger.warning(
                    f"[stream-retry] {self.model} returned empty stream — "
                    f"~{token_count} tokens ({token_method}). "
                    f"Request dumped to: {dump_path}. "
                    f"Retrying in {wait}s "
                    f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                )
                await asyncio.sleep(wait)
                continue
            # Success (or final attempt) — flush buffered events.
            for event in buffered_events:
                yield event
            return
        except RateLimitError as e:
            if attempt < RATE_LIMIT_MAX_RETRIES:
                wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                logger.warning(
                    f"[stream-retry] {self.model} rate limited (429): {e!s}. "
                    f"Retrying in {wait}s "
                    f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                )
                await asyncio.sleep(wait)
                continue
            yield StreamErrorEvent(error=str(e), recoverable=False)
            return
        except Exception as e:
            # Any other failure ends the stream with a non-recoverable error.
            yield StreamErrorEvent(error=str(e), recoverable=False)
            return
+32 -1
View File
@@ -2,10 +2,16 @@
import json
import re
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from typing import Any
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import (
FinishEvent,
StreamEvent,
TextDeltaEvent,
TextEndEvent,
)
class MockLLMProvider(LLMProvider):
@@ -175,3 +181,28 @@ class MockLLMProvider(LLMProvider):
output_tokens=0,
stop_reason="mock_complete",
)
async def stream(
    self,
    messages: list[dict[str, Any]],
    system: str = "",
    tools: list[Tool] | None = None,
    max_tokens: int = 4096,
) -> AsyncIterator[StreamEvent]:
    """Stream a mock completion as word-level TextDeltaEvents.

    Splits the mock response into words and yields each as a separate
    TextDeltaEvent with an accumulating snapshot, exercising the full
    streaming pipeline without any API calls.
    """
    text = self._generate_mock_response(system=system, json_mode=False)
    snapshot = ""
    for index, token in enumerate(text.split(" ")):
        # Re-insert the separating space for every word after the first.
        piece = token if index == 0 else f" {token}"
        snapshot += piece
        yield TextDeltaEvent(content=piece, snapshot=snapshot)
    yield TextEndEvent(full_text=snapshot)
    yield FinishEvent(stop_reason="mock_complete", model=self.model)
+43 -1
View File
@@ -1,7 +1,7 @@
"""LLM Provider abstraction for pluggable LLM backends."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from dataclasses import dataclass, field
from typing import Any
@@ -108,3 +108,45 @@ class LLMProvider(ABC):
Final LLMResponse after tool use completes
"""
pass
async def stream(
    self,
    messages: list[dict[str, Any]],
    system: str = "",
    tools: list[Tool] | None = None,
    max_tokens: int = 4096,
) -> AsyncIterator["StreamEvent"]:
    """
    Stream a completion as an async iterator of StreamEvents.

    Default implementation wraps complete() with synthetic events:
    the whole response text arrives as a single delta, followed by
    TextEnd and Finish. Subclasses SHOULD override for true streaming.

    Tool orchestration is the CALLER's responsibility:
    - Caller detects ToolCallEvent, executes tool, adds result
      to messages, calls stream() again.
    """
    from framework.llm.stream_events import (
        FinishEvent,
        TextDeltaEvent,
        TextEndEvent,
    )

    result = self.complete(
        messages=messages,
        system=system,
        tools=tools,
        max_tokens=max_tokens,
    )
    text = result.content
    # Entire completion emitted as one synthetic delta.
    yield TextDeltaEvent(content=text, snapshot=text)
    yield TextEndEvent(full_text=text)
    yield FinishEvent(
        stop_reason=result.stop_reason,
        input_tokens=result.input_tokens,
        output_tokens=result.output_tokens,
        model=result.model,
    )
# Deferred import target for type annotation
from framework.llm.stream_events import StreamEvent as StreamEvent # noqa: E402, F401
+96
View File
@@ -0,0 +1,96 @@
"""Stream event types for LLM streaming responses.
Defines a discriminated union of frozen dataclasses representing every event
a streaming LLM call can produce. These types form the contract between the
LLM provider layer, EventLoopNode, event bus, persistence, and monitoring.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal
@dataclass(frozen=True)
class TextDeltaEvent:
    """A chunk of text produced by the LLM."""

    type: Literal["text_delta"] = "text_delta"  # union discriminator
    content: str = ""  # this chunk's text only
    snapshot: str = ""  # accumulated text so far (including this chunk)
@dataclass(frozen=True)
class TextEndEvent:
    """Signals that text generation is complete."""

    type: Literal["text_end"] = "text_end"  # union discriminator
    full_text: str = ""  # complete text of the finished response
@dataclass(frozen=True)
class ToolCallEvent:
    """The LLM has requested a tool call."""

    type: Literal["tool_call"] = "tool_call"  # union discriminator
    tool_use_id: str = ""  # provider-assigned id pairing call and result
    tool_name: str = ""  # name of the requested tool
    tool_input: dict[str, Any] = field(default_factory=dict)  # parsed arguments
@dataclass(frozen=True)
class ToolResultEvent:
    """Result of executing a tool call."""

    type: Literal["tool_result"] = "tool_result"  # union discriminator
    tool_use_id: str = ""  # matches the originating ToolCallEvent
    content: str = ""  # tool output (or error text)
    is_error: bool = False  # True when the tool execution failed
@dataclass(frozen=True)
class ReasoningStartEvent:
    """The LLM has started a reasoning/thinking block."""

    type: Literal["reasoning_start"] = "reasoning_start"  # union discriminator
@dataclass(frozen=True)
class ReasoningDeltaEvent:
    """A chunk of reasoning/thinking content."""

    type: Literal["reasoning_delta"] = "reasoning_delta"  # union discriminator
    content: str = ""  # this chunk's reasoning text
@dataclass(frozen=True)
class FinishEvent:
    """The LLM has finished generating."""

    type: Literal["finish"] = "finish"  # union discriminator
    stop_reason: str = ""  # provider stop/finish reason
    input_tokens: int = 0  # prompt token usage (0 when unreported)
    output_tokens: int = 0  # completion token usage (0 when unreported)
    model: str = ""  # model identifier that produced the response
@dataclass(frozen=True)
class StreamErrorEvent:
    """An error occurred during streaming."""

    type: Literal["error"] = "error"  # union discriminator
    error: str = ""  # human-readable error description
    recoverable: bool = False  # True when the caller may retry the stream
# Discriminated union of all stream event types. Each member carries a
# distinct Literal `type` field, so consumers can dispatch on `event.type`
# (or isinstance) without ambiguity.
StreamEvent = (
    TextDeltaEvent
    | TextEndEvent
    | ToolCallEvent
    | ToolResultEvent
    | ReasoningStartEvent
    | ReasoningDeltaEvent
    | FinishEvent
    | StreamErrorEvent
)
+277
View File
@@ -41,6 +41,28 @@ class EventType(str, Enum):
STREAM_STARTED = "stream_started"
STREAM_STOPPED = "stream_stopped"
# Node event-loop lifecycle
NODE_LOOP_STARTED = "node_loop_started"
NODE_LOOP_ITERATION = "node_loop_iteration"
NODE_LOOP_COMPLETED = "node_loop_completed"
# LLM streaming observability
LLM_TEXT_DELTA = "llm_text_delta"
LLM_REASONING_DELTA = "llm_reasoning_delta"
# Tool lifecycle
TOOL_CALL_STARTED = "tool_call_started"
TOOL_CALL_COMPLETED = "tool_call_completed"
# Client I/O (client_facing=True nodes only)
CLIENT_OUTPUT_DELTA = "client_output_delta"
CLIENT_INPUT_REQUESTED = "client_input_requested"
# Internal node observability (client_facing=False nodes)
NODE_INTERNAL_OUTPUT = "node_internal_output"
NODE_INPUT_BLOCKED = "node_input_blocked"
NODE_STALLED = "node_stalled"
# Custom events
CUSTOM = "custom"
@@ -51,6 +73,7 @@ class AgentEvent:
type: EventType
stream_id: str
node_id: str | None = None # Which node emitted this event
execution_id: str | None = None
data: dict[str, Any] = field(default_factory=dict)
timestamp: datetime = field(default_factory=datetime.now)
@@ -61,6 +84,7 @@ class AgentEvent:
return {
"type": self.type.value,
"stream_id": self.stream_id,
"node_id": self.node_id,
"execution_id": self.execution_id,
"data": self.data,
"timestamp": self.timestamp.isoformat(),
@@ -80,6 +104,7 @@ class Subscription:
event_types: set[EventType]
handler: EventHandler
filter_stream: str | None = None # Only receive events from this stream
filter_node: str | None = None # Only receive events from this node
filter_execution: str | None = None # Only receive events from this execution
@@ -138,6 +163,7 @@ class EventBus:
event_types: list[EventType],
handler: EventHandler,
filter_stream: str | None = None,
filter_node: str | None = None,
filter_execution: str | None = None,
) -> str:
"""
@@ -147,6 +173,7 @@ class EventBus:
event_types: Types of events to receive
handler: Async function to call when event occurs
filter_stream: Only receive events from this stream
filter_node: Only receive events from this node
filter_execution: Only receive events from this execution
Returns:
@@ -160,6 +187,7 @@ class EventBus:
event_types=set(event_types),
handler=handler,
filter_stream=filter_stream,
filter_node=filter_node,
filter_execution=filter_execution,
)
@@ -218,6 +246,10 @@ class EventBus:
if subscription.filter_stream and subscription.filter_stream != event.stream_id:
return False
# Check node filter
if subscription.filter_node and subscription.filter_node != event.node_id:
return False
# Check execution filter
if subscription.filter_execution and subscription.filter_execution != event.execution_id:
return False
@@ -359,6 +391,248 @@ class EventBus:
)
)
# === NODE EVENT-LOOP PUBLISHERS ===
async def emit_node_loop_started(
self,
stream_id: str,
node_id: str,
execution_id: str | None = None,
max_iterations: int | None = None,
) -> None:
"""Emit node loop started event."""
await self.publish(
AgentEvent(
type=EventType.NODE_LOOP_STARTED,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={"max_iterations": max_iterations},
)
)
async def emit_node_loop_iteration(
self,
stream_id: str,
node_id: str,
iteration: int,
execution_id: str | None = None,
) -> None:
"""Emit node loop iteration event."""
await self.publish(
AgentEvent(
type=EventType.NODE_LOOP_ITERATION,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={"iteration": iteration},
)
)
async def emit_node_loop_completed(
self,
stream_id: str,
node_id: str,
iterations: int,
execution_id: str | None = None,
) -> None:
"""Emit node loop completed event."""
await self.publish(
AgentEvent(
type=EventType.NODE_LOOP_COMPLETED,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={"iterations": iterations},
)
)
# === LLM STREAMING PUBLISHERS ===
async def emit_llm_text_delta(
self,
stream_id: str,
node_id: str,
content: str,
snapshot: str,
execution_id: str | None = None,
) -> None:
"""Emit LLM text delta event."""
await self.publish(
AgentEvent(
type=EventType.LLM_TEXT_DELTA,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={"content": content, "snapshot": snapshot},
)
)
async def emit_llm_reasoning_delta(
self,
stream_id: str,
node_id: str,
content: str,
execution_id: str | None = None,
) -> None:
"""Emit LLM reasoning delta event."""
await self.publish(
AgentEvent(
type=EventType.LLM_REASONING_DELTA,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={"content": content},
)
)
# === TOOL LIFECYCLE PUBLISHERS ===
async def emit_tool_call_started(
self,
stream_id: str,
node_id: str,
tool_use_id: str,
tool_name: str,
tool_input: dict[str, Any] | None = None,
execution_id: str | None = None,
) -> None:
"""Emit tool call started event."""
await self.publish(
AgentEvent(
type=EventType.TOOL_CALL_STARTED,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={
"tool_use_id": tool_use_id,
"tool_name": tool_name,
"tool_input": tool_input or {},
},
)
)
async def emit_tool_call_completed(
self,
stream_id: str,
node_id: str,
tool_use_id: str,
tool_name: str,
result: str = "",
is_error: bool = False,
execution_id: str | None = None,
) -> None:
"""Emit tool call completed event."""
await self.publish(
AgentEvent(
type=EventType.TOOL_CALL_COMPLETED,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={
"tool_use_id": tool_use_id,
"tool_name": tool_name,
"result": result,
"is_error": is_error,
},
)
)
# === CLIENT I/O PUBLISHERS ===
async def emit_client_output_delta(
self,
stream_id: str,
node_id: str,
content: str,
snapshot: str,
execution_id: str | None = None,
) -> None:
"""Emit client output delta event (client_facing=True nodes)."""
await self.publish(
AgentEvent(
type=EventType.CLIENT_OUTPUT_DELTA,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={"content": content, "snapshot": snapshot},
)
)
async def emit_client_input_requested(
self,
stream_id: str,
node_id: str,
prompt: str = "",
execution_id: str | None = None,
) -> None:
"""Emit client input requested event (client_facing=True nodes)."""
await self.publish(
AgentEvent(
type=EventType.CLIENT_INPUT_REQUESTED,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={"prompt": prompt},
)
)
# === INTERNAL NODE PUBLISHERS ===
async def emit_node_internal_output(
self,
stream_id: str,
node_id: str,
content: str,
execution_id: str | None = None,
) -> None:
"""Emit node internal output event (client_facing=False nodes)."""
await self.publish(
AgentEvent(
type=EventType.NODE_INTERNAL_OUTPUT,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={"content": content},
)
)
async def emit_node_stalled(
self,
stream_id: str,
node_id: str,
reason: str = "",
execution_id: str | None = None,
) -> None:
"""Emit node stalled event."""
await self.publish(
AgentEvent(
type=EventType.NODE_STALLED,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={"reason": reason},
)
)
async def emit_node_input_blocked(
self,
stream_id: str,
node_id: str,
prompt: str = "",
execution_id: str | None = None,
) -> None:
"""Emit node input blocked event."""
await self.publish(
AgentEvent(
type=EventType.NODE_INPUT_BLOCKED,
stream_id=stream_id,
node_id=node_id,
execution_id=execution_id,
data={"prompt": prompt},
)
)
# === QUERY OPERATIONS ===
def get_history(
@@ -410,6 +684,7 @@ class EventBus:
self,
event_type: EventType,
stream_id: str | None = None,
node_id: str | None = None,
execution_id: str | None = None,
timeout: float | None = None,
) -> AgentEvent | None:
@@ -419,6 +694,7 @@ class EventBus:
Args:
event_type: Type of event to wait for
stream_id: Filter by stream
node_id: Filter by node
execution_id: Filter by execution
timeout: Maximum time to wait (seconds)
@@ -438,6 +714,7 @@ class EventBus:
event_types=[event_type],
handler=handler,
filter_stream=stream_id,
filter_node=node_id,
filter_execution=execution_id,
)
+237
View File
@@ -0,0 +1,237 @@
"""
Tests for client-facing fan-out and event_loop output_key overlap validation.
Validates two rules added to GraphSpec.validate():
1. Fan-out must not have multiple client_facing=True targets.
2. Parallel event_loop nodes must have disjoint output_keys.
"""
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.node import NodeSpec
# ---------------------------------------------------------------------------
# Rule 1: client_facing fan-out
# ---------------------------------------------------------------------------
class TestClientFacingFanOut:
    """Fan-out to multiple client_facing=True targets must be rejected."""

    @staticmethod
    def _fan_out_graph(a_facing: bool, b_facing: bool) -> GraphSpec:
        """Build a graph where 'src' fans out to 'a' and 'b' with the given flags."""
        return GraphSpec(
            id="g1",
            goal_id="goal1",
            entry_node="src",
            nodes=[
                NodeSpec(id="src", name="src", description="Source node"),
                NodeSpec(id="a", name="a", description="Node a", client_facing=a_facing),
                NodeSpec(id="b", name="b", description="Node b", client_facing=b_facing),
            ],
            edges=[
                EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
                EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
            ],
        )

    def test_fan_out_two_client_facing_fails(self):
        """Two client-facing targets on the same fan-out -> error."""
        spec = self._fan_out_graph(True, True)
        matches = [msg for msg in spec.validate() if "multiple client-facing" in msg]
        assert len(matches) == 1
        # The error message names the fan-out source node.
        assert "'src'" in matches[0]

    def test_fan_out_one_client_facing_passes(self):
        """Only one client-facing target -> no error."""
        spec = self._fan_out_graph(True, False)
        matches = [msg for msg in spec.validate() if "multiple client-facing" in msg]
        assert len(matches) == 0

    def test_fan_out_zero_client_facing_passes(self):
        """No client-facing targets at all -> no error."""
        spec = self._fan_out_graph(False, False)
        matches = [msg for msg in spec.validate() if "multiple client-facing" in msg]
        assert len(matches) == 0
# ---------------------------------------------------------------------------
# Rule 2: event_loop output_key overlap
# ---------------------------------------------------------------------------
class TestEventLoopOutputKeyOverlap:
    """Parallel event_loop nodes with overlapping output_keys must be rejected."""

    @staticmethod
    def _parallel_graph(node_type: str, keys_a: list, keys_b: list) -> GraphSpec:
        """Build a graph where 'src' fans out to two *node_type* nodes with the given keys."""
        return GraphSpec(
            id="g1",
            goal_id="goal1",
            entry_node="src",
            nodes=[
                NodeSpec(id="src", name="src", description="Source node"),
                NodeSpec(
                    id="a",
                    name="a",
                    description="Node a",
                    node_type=node_type,
                    output_keys=keys_a,
                ),
                NodeSpec(
                    id="b",
                    name="b",
                    description="Node b",
                    node_type=node_type,
                    output_keys=keys_b,
                ),
            ],
            edges=[
                EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
                EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
            ],
        )

    def test_overlapping_output_keys_event_loop_fails(self):
        """Two event_loop nodes sharing an output_key -> error."""
        spec = self._parallel_graph("event_loop", ["status", "shared"], ["result", "shared"])
        matches = [msg for msg in spec.validate() if "output_key" in msg]
        assert len(matches) == 1
        # The overlapping key is called out by name.
        assert "'shared'" in matches[0]

    def test_disjoint_output_keys_event_loop_passes(self):
        """Two event_loop nodes with disjoint output_keys -> no error."""
        spec = self._parallel_graph("event_loop", ["status"], ["result"])
        matches = [msg for msg in spec.validate() if "output_key" in msg]
        assert len(matches) == 0

    def test_overlapping_keys_non_event_loop_no_error(self):
        """Non-event_loop nodes with overlapping keys -> no error (last-wins OK)."""
        spec = self._parallel_graph("llm_generate", ["shared"], ["shared"])
        matches = [msg for msg in spec.validate() if "output_key" in msg]
        assert len(matches) == 0
# ---------------------------------------------------------------------------
# Baseline: no fan-out -> no errors from these rules
# ---------------------------------------------------------------------------
class TestNoFanOutUnaffected:
    """Linear graphs should not trigger either validation rule."""

    def test_no_fan_out_unaffected(self):
        """Linear chain with client_facing and event_loop nodes -> no errors."""
        chain = GraphSpec(
            id="g1",
            goal_id="goal1",
            entry_node="a",
            terminal_nodes=["c"],
            nodes=[
                NodeSpec(id="a", name="a", description="Node a", client_facing=True),
                NodeSpec(
                    id="b",
                    name="b",
                    description="Node b",
                    node_type="event_loop",
                    output_keys=["x"],
                ),
                NodeSpec(
                    id="c",
                    name="c",
                    description="Node c",
                    client_facing=True,
                    node_type="event_loop",
                    output_keys=["x"],
                ),
            ],
            edges=[
                EdgeSpec(id="a->b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
                EdgeSpec(id="b->c", source="b", target="c", condition=EdgeCondition.ON_SUCCESS),
            ],
        )
        findings = chain.validate()
        # Neither rule fires on a purely linear chain (no fan-out source).
        assert len([e for e in findings if "multiple client-facing" in e]) == 0
        assert len([e for e in findings if "output_key" in e]) == 0
+150
View File
@@ -0,0 +1,150 @@
"""
Tests for ClientIO gateway (WP-9).
Covers:
- ActiveNodeClientIO: emit_output output_stream round-trip, request_input, timeout
- InertNodeClientIO: emit_output publishes NODE_INTERNAL_OUTPUT, request_input returns redirect
- ClientIOGateway: factory creates correct variant
"""
import asyncio
import pytest
from framework.graph.client_io import (
ActiveNodeClientIO,
ClientIOGateway,
InertNodeClientIO,
NodeClientIO,
)
from framework.runtime.event_bus import AgentEvent, EventType
# Kwargs routed onto the AgentEvent envelope itself; everything else goes into `data`.
_AGENT_EVENT_FIELDS = {"stream_id", "node_id", "execution_id", "correlation_id"}


class MockEventBus:
    """Lightweight stand-in for EventBus that records published events."""

    def __init__(self) -> None:
        self.events: list[AgentEvent] = []

    async def _record(self, event_type: EventType, **kwargs) -> None:
        """Split kwargs into envelope fields vs. payload and store the event."""
        envelope: dict = {}
        payload: dict = {}
        for key, value in kwargs.items():
            (envelope if key in _AGENT_EVENT_FIELDS else payload)[key] = value
        self.events.append(AgentEvent(type=event_type, **envelope, data=payload))

    async def emit_client_output_delta(self, **kwargs) -> None:
        await self._record(EventType.CLIENT_OUTPUT_DELTA, **kwargs)

    async def emit_client_input_requested(self, **kwargs) -> None:
        await self._record(EventType.CLIENT_INPUT_REQUESTED, **kwargs)

    async def emit_node_internal_output(self, **kwargs) -> None:
        await self._record(EventType.NODE_INTERNAL_OUTPUT, **kwargs)

    async def emit_node_input_blocked(self, **kwargs) -> None:
        await self._record(EventType.NODE_INPUT_BLOCKED, **kwargs)
# --- ActiveNodeClientIO tests ---
@pytest.mark.asyncio
async def test_active_emit_and_consume():
    """emit_output → output_stream round-trip works correctly."""
    bus = MockEventBus()
    active_io = ActiveNodeClientIO(node_id="n1", event_bus=bus)
    await active_io.emit_output("Hello ")
    await active_io.emit_output("World", is_final=True)
    received = [chunk async for chunk in active_io.output_stream()]
    assert received == ["Hello ", "World"]
    # Each chunk was also published as a CLIENT_OUTPUT_DELTA event.
    assert [e.type for e in bus.events] == [EventType.CLIENT_OUTPUT_DELTA] * 2
    # The snapshot field accumulates across deltas.
    assert bus.events[0].data["snapshot"] == "Hello "
    assert bus.events[1].data["snapshot"] == "Hello World"
@pytest.mark.asyncio
async def test_active_request_input():
    """request_input blocks until provide_input is called."""
    bus = MockEventBus()
    active_io = ActiveNodeClientIO(node_id="n1", event_bus=bus)

    async def fulfill_later():
        # Let request_input park first, then answer it.
        await asyncio.sleep(0.01)
        await active_io.provide_input("user says hi")

    pending = asyncio.create_task(fulfill_later())
    answer = await active_io.request_input(prompt="What?")
    await pending
    assert answer == "user says hi"
    # Exactly one CLIENT_INPUT_REQUESTED event carrying the prompt was published.
    assert len(bus.events) == 1
    event = bus.events[0]
    assert event.type == EventType.CLIENT_INPUT_REQUESTED
    assert event.data["prompt"] == "What?"
@pytest.mark.asyncio
async def test_active_request_input_timeout():
    """request_input raises TimeoutError when timeout expires."""
    # No event bus required: the timeout fires with no input ever provided.
    active_io = ActiveNodeClientIO(node_id="n1")
    with pytest.raises(TimeoutError):
        await active_io.request_input(prompt="waiting", timeout=0.01)
# --- InertNodeClientIO tests ---
@pytest.mark.asyncio
async def test_inert_emit_publishes_internal():
    """InertNodeClientIO.emit_output publishes NODE_INTERNAL_OUTPUT."""
    bus = MockEventBus()
    inert_io = InertNodeClientIO(node_id="n2", event_bus=bus)
    await inert_io.emit_output("internal log")
    # Exactly one internal-output event with the emitted content.
    (event,) = bus.events
    assert event.type == EventType.NODE_INTERNAL_OUTPUT
    assert event.data["content"] == "internal log"
@pytest.mark.asyncio
async def test_inert_request_input_returns_redirect():
    """request_input returns a redirect string and publishes NODE_INPUT_BLOCKED."""
    bus = MockEventBus()
    inert_io = InertNodeClientIO(node_id="n2", event_bus=bus)
    redirect = await inert_io.request_input(prompt="need data")
    assert "internal processing node" in redirect
    # Exactly one blocked-input event with the original prompt.
    (event,) = bus.events
    assert event.type == EventType.NODE_INPUT_BLOCKED
    assert event.data["prompt"] == "need data"
# --- ClientIOGateway tests ---
def test_gateway_creates_active_for_client_facing():
    """ClientIOGateway.create_io returns ActiveNodeClientIO when client_facing=True."""
    created = ClientIOGateway().create_io(node_id="n1", client_facing=True)
    # Concrete active variant that still satisfies the common protocol.
    assert isinstance(created, ActiveNodeClientIO)
    assert isinstance(created, NodeClientIO)
def test_gateway_creates_inert_for_internal():
    """ClientIOGateway.create_io returns InertNodeClientIO when client_facing=False."""
    created = ClientIOGateway().create_io(node_id="n2", client_facing=False)
    # Concrete inert variant that still satisfies the common protocol.
    assert isinstance(created, InertNodeClientIO)
    assert isinstance(created, NodeClientIO)
+326
View File
@@ -0,0 +1,326 @@
"""Tests for ContextHandoff and HandoffContext."""
from __future__ import annotations
from typing import Any
import pytest
from framework.graph.context_handoff import ContextHandoff, HandoffContext
from framework.graph.conversation import NodeConversation
from framework.llm.mock import MockLLMProvider
from framework.llm.provider import LLMProvider, LLMResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class SpyLLMProvider(MockLLMProvider):
    """MockLLMProvider that records whether complete() was called."""

    def __init__(self) -> None:
        super().__init__()
        # Flipped to True on the first complete() call; args kept for inspection.
        self.complete_called = False
        self.complete_call_args: dict[str, Any] | None = None

    def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
        """Record the call, then defer to the mock's canned response."""
        self.complete_called = True
        self.complete_call_args = dict(messages=messages, **kwargs)
        return super().complete(messages, **kwargs)
class FailingLLMProvider(LLMProvider):
    """LLM provider that always raises."""

    @staticmethod
    def _unavailable() -> LLMResponse:
        """Simulate a hard provider outage."""
        raise RuntimeError("LLM unavailable")

    def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
        return self._unavailable()

    def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list,
        tool_executor: Any,
        max_iterations: int = 10,
    ) -> LLMResponse:
        return self._unavailable()
async def _build_conversation(*pairs: tuple[str, str]) -> NodeConversation:
    """Build a NodeConversation from (user, assistant) message pairs."""
    conversation = NodeConversation()
    # Append each exchange in order: user turn first, then the reply.
    for prompt, reply in pairs:
        await conversation.add_user_message(prompt)
        await conversation.add_assistant_message(reply)
    return conversation
# ---------------------------------------------------------------------------
# TestHandoffContext
# ---------------------------------------------------------------------------
class TestHandoffContext:
    def test_instantiation(self) -> None:
        """All constructor fields land on the matching attributes."""
        fields = {
            "source_node_id": "node_A",
            "summary": "Summary text",
            "key_outputs": {"result": "42"},
            "turn_count": 3,
            "total_tokens_used": 1200,
        }
        ctx = HandoffContext(**fields)
        for name, expected in fields.items():
            assert getattr(ctx, name) == expected

    def test_field_access(self) -> None:
        """An empty key_outputs mapping survives construction untouched."""
        ctx = HandoffContext(
            source_node_id="n1",
            summary="s",
            key_outputs={},
            turn_count=0,
            total_tokens_used=0,
        )
        assert ctx.key_outputs == {}
# ---------------------------------------------------------------------------
# TestExtractiveSummary
# ---------------------------------------------------------------------------
class TestExtractiveSummary:
    @pytest.mark.asyncio
    async def test_extractive_summary_includes_first_last(self) -> None:
        """First and last assistant messages both survive into the summary."""
        dialogue = await _build_conversation(
            ("hello", "First response here."),
            ("continue", "Middle response."),
            ("finish", "Final conclusion."),
        )
        result = ContextHandoff().summarize_conversation(dialogue, node_id="test_node")
        assert "First response here." in result.summary
        assert "Final conclusion." in result.summary

    @pytest.mark.asyncio
    async def test_extractive_summary_metadata(self) -> None:
        """Node id, turn count and token estimate are populated."""
        dialogue = await _build_conversation(("hi", "hello"), ("bye", "goodbye"))
        result = ContextHandoff().summarize_conversation(dialogue, node_id="node_42")
        assert result.source_node_id == "node_42"
        assert result.turn_count == 2
        assert result.total_tokens_used > 0

    @pytest.mark.asyncio
    async def test_extractive_with_output_keys_colon(self) -> None:
        """'key: value' style answers are extracted into key_outputs."""
        dialogue = await _build_conversation(("what is the answer?", "answer: 42"))
        result = ContextHandoff().summarize_conversation(
            dialogue, node_id="n", output_keys=["answer"]
        )
        assert result.key_outputs["answer"] == "42"

    @pytest.mark.asyncio
    async def test_extractive_with_output_keys_equals(self) -> None:
        """'key = value' style answers are extracted into key_outputs."""
        dialogue = await _build_conversation(("compute", "result = success"))
        result = ContextHandoff().summarize_conversation(
            dialogue, node_id="n", output_keys=["result"]
        )
        assert result.key_outputs["result"] == "success"

    @pytest.mark.asyncio
    async def test_extractive_json_output_keys(self) -> None:
        """JSON assistant replies yield stringified values per requested key."""
        dialogue = await _build_conversation(("give me json", '{"score": 95, "grade": "A"}'))
        result = ContextHandoff().summarize_conversation(
            dialogue, node_id="n", output_keys=["score", "grade"]
        )
        assert result.key_outputs["score"] == "95"
        assert result.key_outputs["grade"] == "A"

    @pytest.mark.asyncio
    async def test_extractive_empty_conversation(self) -> None:
        """An empty conversation produces the canned empty summary."""
        result = ContextHandoff().summarize_conversation(NodeConversation(), node_id="empty")
        assert result.summary == "Empty conversation."
        assert result.turn_count == 0
        assert result.key_outputs == {}

    @pytest.mark.asyncio
    async def test_extractive_no_assistant_messages(self) -> None:
        """User-only conversations produce the 'no responses' summary."""
        dialogue = NodeConversation()
        await dialogue.add_user_message("hello?")
        await dialogue.add_user_message("anyone there?")
        result = ContextHandoff().summarize_conversation(dialogue, node_id="silent")
        assert result.summary == "No assistant responses."

    @pytest.mark.asyncio
    async def test_extractive_most_recent_wins(self) -> None:
        """When a key appears twice, the later assistant message wins."""
        dialogue = await _build_conversation(
            ("first", "status: old_value"),
            ("second", "status: new_value"),
        )
        result = ContextHandoff().summarize_conversation(
            dialogue, node_id="n", output_keys=["status"]
        )
        assert result.key_outputs["status"] == "new_value"

    @pytest.mark.asyncio
    async def test_extractive_truncation(self) -> None:
        """Very long responses are truncated; summary stays within ~500 chars."""
        dialogue = await _build_conversation(("go", "x" * 1000))
        result = ContextHandoff().summarize_conversation(dialogue, node_id="n")
        assert len(result.summary) <= 500
# ---------------------------------------------------------------------------
# TestLLMSummary
# ---------------------------------------------------------------------------
class TestLLMSummary:
    @pytest.mark.asyncio
    async def test_llm_summary_calls_provider(self) -> None:
        """When an LLM is supplied, summarization goes through complete()."""
        spy = SpyLLMProvider()
        dialogue = await _build_conversation(("hi", "hello back"), ("what now?", "we are done"))
        result = ContextHandoff(llm=spy).summarize_conversation(dialogue, node_id="llm_node")
        assert spy.complete_called, "LLM complete() was never invoked"
        assert result.summary == "This is a mock response for testing purposes."

    @pytest.mark.asyncio
    async def test_llm_summary_includes_output_key_hint(self) -> None:
        """Requested output keys are surfaced in the system prompt."""
        spy = SpyLLMProvider()
        dialogue = await _build_conversation(("compute", '{"score": 95}'))
        ContextHandoff(llm=spy).summarize_conversation(
            dialogue, node_id="n", output_keys=["score", "grade"]
        )
        assert spy.complete_call_args is not None
        system_prompt = spy.complete_call_args.get("system", "")
        assert "score" in system_prompt
        assert "grade" in system_prompt

    @pytest.mark.asyncio
    async def test_llm_fallback_on_error(self) -> None:
        """A raising provider falls back to the extractive summary path."""
        dialogue = await _build_conversation(
            ("start", "First assistant message."),
            ("end", "Last assistant message."),
        )
        result = ContextHandoff(llm=FailingLLMProvider()).summarize_conversation(
            dialogue, node_id="fallback_node"
        )
        # Extractive fallback keeps first + last assistant messages.
        assert "First assistant message." in result.summary
        assert "Last assistant message." in result.summary
# ---------------------------------------------------------------------------
# TestFormatAsInput
# ---------------------------------------------------------------------------
class TestFormatAsInput:
    @staticmethod
    def _context(**overrides: Any) -> HandoffContext:
        """Build a HandoffContext with neutral defaults, overridable per test."""
        base: dict[str, Any] = {
            "source_node_id": "n",
            "summary": "",
            "key_outputs": {},
            "turn_count": 0,
            "total_tokens_used": 0,
        }
        base.update(overrides)
        return HandoffContext(**base)

    def test_format_structure(self) -> None:
        """All structural section markers are present."""
        text = ContextHandoff.format_as_input(
            self._context(
                source_node_id="analyzer",
                summary="Analysis complete.",
                key_outputs={"score": "95"},
                turn_count=5,
                total_tokens_used=2000,
            )
        )
        markers = (
            "--- CONTEXT FROM: analyzer",
            "KEY OUTPUTS:",
            "SUMMARY:",
            "--- END CONTEXT ---",
        )
        for marker in markers:
            assert marker in text

    def test_format_no_key_outputs(self) -> None:
        """The KEY OUTPUTS section is omitted when there are no outputs."""
        text = ContextHandoff.format_as_input(
            self._context(
                source_node_id="simple",
                summary="Done.",
                turn_count=1,
                total_tokens_used=100,
            )
        )
        assert "KEY OUTPUTS:" not in text
        assert "SUMMARY:" in text

    def test_format_content_values(self) -> None:
        """Node id, turn/token stats, key-value lines and summary all appear."""
        text = ContextHandoff.format_as_input(
            self._context(
                source_node_id="node_X",
                summary="Found 3 bugs.",
                key_outputs={"bugs": "3", "severity": "high"},
                turn_count=7,
                total_tokens_used=5000,
            )
        )
        fragments = (
            "node_X",
            "7 turns",
            "~5000 tokens",
            "- bugs: 3",
            "- severity: high",
            "Found 3 bugs.",
        )
        for fragment in fragments:
            assert fragment in text

    def test_format_empty_summary(self) -> None:
        """An empty summary falls back to the placeholder text."""
        text = ContextHandoff.format_as_input(self._context())
        assert "No summary available." in text

    @pytest.mark.asyncio
    async def test_format_as_input_usable_as_message(self) -> None:
        """Formatted output can be fed into a NodeConversation as a user message."""
        text = ContextHandoff.format_as_input(
            self._context(
                source_node_id="prev_node",
                summary="Completed analysis.",
                key_outputs={"result": "42"},
                turn_count=3,
                total_tokens_used=900,
            )
        )
        conv = NodeConversation()
        msg = await conv.add_user_message(text)
        assert msg.role == "user"
        assert "CONTEXT FROM: prev_node" in msg.content
        assert conv.turn_count == 1
File diff suppressed because it is too large Load Diff
+746
View File
@@ -0,0 +1,746 @@
"""WP-8: Tests for EventLoopNode, OutputAccumulator, LoopConfig, JudgeProtocol.
Uses real FileConversationStore (no mocks for storage) and a MockStreamingLLM
that yields pre-programmed StreamEvents to control the loop deterministically.
"""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from framework.graph.conversation import NodeConversation
from framework.graph.event_loop_node import (
EventLoopNode,
JudgeProtocol,
JudgeVerdict,
LoopConfig,
OutputAccumulator,
)
from framework.graph.node import NodeContext, NodeProtocol, NodeSpec, SharedMemory
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
ToolCallEvent,
)
from framework.runtime.core import Runtime
from framework.runtime.event_bus import EventBus, EventType
from framework.storage.conversation_store import FileConversationStore
# ---------------------------------------------------------------------------
# Mock LLM that yields pre-programmed stream events
# ---------------------------------------------------------------------------
class MockStreamingLLM(LLMProvider):
    """Mock LLM that yields pre-programmed StreamEvent sequences.

    Each call to stream() consumes the next scenario from the list.
    Cycles back to the beginning if more calls are made than scenarios.
    """

    def __init__(self, scenarios: list[list] | None = None):
        # Each scenario is the full event sequence for one stream() call.
        self.scenarios = scenarios or []
        self._call_index = 0
        self.stream_calls: list[dict] = []

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator:
        # Record every call so tests can assert on what the node sent.
        self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
        if not self.scenarios:
            return
        scenario = self.scenarios[self._call_index % len(self.scenarios)]
        self._call_index += 1
        for stream_event in scenario:
            yield stream_event

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(content="Summary of conversation.", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")
# ---------------------------------------------------------------------------
# Helper: build a simple text-only scenario
# ---------------------------------------------------------------------------
def text_scenario(text: str, input_tokens: int = 10, output_tokens: int = 5) -> list:
    """Build a stream scenario that produces text and finishes."""
    delta = TextDeltaEvent(content=text, snapshot=text)
    finish = FinishEvent(
        stop_reason="stop",
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        model="mock",
    )
    return [delta, finish]
def tool_call_scenario(
    tool_name: str,
    tool_input: dict,
    tool_use_id: str = "call_1",
    text: str = "",
) -> list:
    """Build a stream scenario that produces a tool call."""
    # Optional leading text delta, then the tool call, then a tool_calls finish.
    prefix = [TextDeltaEvent(content=text, snapshot=text)] if text else []
    return prefix + [
        ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input),
        FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
    ]
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def runtime():
    """MagicMock Runtime with canned decision-tracking return values."""
    rt = MagicMock(spec=Runtime)
    rt.configure_mock(
        start_run=MagicMock(return_value="run_1"),
        decide=MagicMock(return_value="dec_1"),
        record_outcome=MagicMock(),
        end_run=MagicMock(),
        report_problem=MagicMock(),
        set_node=MagicMock(),
    )
    return rt
@pytest.fixture
def node_spec():
    """Default event_loop NodeSpec expecting a single 'result' output key."""
    return NodeSpec(
        id="test_loop",
        name="Test Loop",
        description="A test event loop node",
        node_type="event_loop",
        system_prompt="You are a test assistant.",
        output_keys=["result"],
    )
@pytest.fixture
def memory():
    # Fresh, empty shared memory per test so state cannot leak between cases.
    return SharedMemory()
def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal_context=""):
    """Build a NodeContext for testing."""
    # Assemble kwargs first so the construction reads as one unit.
    ctx_kwargs = dict(
        runtime=runtime,
        node_id=node_spec.id,
        node_spec=node_spec,
        memory=memory,
        input_data=input_data or {},
        llm=llm,
        available_tools=tools or [],
        goal_context=goal_context,
    )
    return NodeContext(**ctx_kwargs)
# ===========================================================================
# NodeProtocol conformance
# ===========================================================================
class TestNodeProtocolConformance:
    def test_subclasses_node_protocol(self):
        """EventLoopNode must be a subclass of NodeProtocol."""
        assert issubclass(EventLoopNode, NodeProtocol)

    def test_has_execute_method(self):
        """execute() exists and is a coroutine function."""
        execute = getattr(EventLoopNode(), "execute", None)
        assert execute is not None
        assert asyncio.iscoroutinefunction(execute)

    def test_has_validate_input(self):
        """validate_input() is part of the node's surface."""
        assert hasattr(EventLoopNode(), "validate_input")
# ===========================================================================
# Basic loop execution
# ===========================================================================
class TestBasicLoop:
    @pytest.mark.asyncio
    async def test_basic_text_only_implicit_accept(self, runtime, node_spec, memory):
        """No tools, no judge. LLM produces text, implicit accept on stop."""
        # With no required output_keys the implicit judge accepts immediately.
        node_spec.output_keys = []
        scripted = MockStreamingLLM(scenarios=[text_scenario("Hello world")])
        ctx = build_ctx(runtime, node_spec, memory, scripted)
        outcome = await EventLoopNode(config=LoopConfig(max_iterations=5)).execute(ctx)
        assert outcome.success is True
        assert outcome.tokens_used > 0

    @pytest.mark.asyncio
    async def test_no_llm_returns_failure(self, runtime, node_spec, memory):
        """ctx.llm=None should return failure immediately."""
        ctx = build_ctx(runtime, node_spec, memory, llm=None)
        outcome = await EventLoopNode().execute(ctx)
        assert outcome.success is False
        assert "LLM" in outcome.error

    @pytest.mark.asyncio
    async def test_max_iterations_failure(self, runtime, node_spec, memory):
        """When max_iterations is reached without acceptance, should fail."""
        # The LLM never calls set_output, so the implicit judge keeps
        # retrying for the missing 'result' key until the cap is hit.
        scripted = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
        ctx = build_ctx(runtime, node_spec, memory, scripted)
        outcome = await EventLoopNode(config=LoopConfig(max_iterations=2)).execute(ctx)
        assert outcome.success is False
        assert "Max iterations" in outcome.error
# ===========================================================================
# Judge integration
# ===========================================================================
class TestJudgeIntegration:
    @pytest.mark.asyncio
    async def test_judge_accept(self, runtime, node_spec, memory):
        """Mock judge ACCEPT -> success."""
        node_spec.output_keys = []
        mock_judge = AsyncMock(spec=JudgeProtocol)
        mock_judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))
        scripted = MockStreamingLLM(scenarios=[text_scenario("Done!")])
        ctx = build_ctx(runtime, node_spec, memory, scripted)
        looper = EventLoopNode(judge=mock_judge, config=LoopConfig(max_iterations=5))
        outcome = await looper.execute(ctx)
        assert outcome.success is True
        mock_judge.evaluate.assert_called_once()

    @pytest.mark.asyncio
    async def test_judge_escalate(self, runtime, node_spec, memory):
        """Mock judge ESCALATE -> failure."""
        node_spec.output_keys = []
        mock_judge = AsyncMock(spec=JudgeProtocol)
        mock_judge.evaluate = AsyncMock(
            return_value=JudgeVerdict(action="ESCALATE", feedback="Tone violation")
        )
        scripted = MockStreamingLLM(scenarios=[text_scenario("Attempt")])
        ctx = build_ctx(runtime, node_spec, memory, scripted)
        looper = EventLoopNode(judge=mock_judge, config=LoopConfig(max_iterations=5))
        outcome = await looper.execute(ctx)
        assert outcome.success is False
        # The escalation reason must be surfaced in the error message.
        assert "escalated" in outcome.error.lower()
        assert "Tone violation" in outcome.error

    @pytest.mark.asyncio
    async def test_judge_retry_then_accept(self, runtime, node_spec, memory):
        """RETRY twice, then ACCEPT. Should run 3 iterations."""
        node_spec.output_keys = []
        scripted = MockStreamingLLM(
            scenarios=[
                text_scenario("attempt 1"),
                text_scenario("attempt 2"),
                text_scenario("attempt 3"),
            ]
        )
        calls = 0

        async def evaluate_fn(context):
            nonlocal calls
            calls += 1
            # First two verdicts ask for another attempt; third accepts.
            if calls >= 3:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY", feedback="Try harder")

        mock_judge = AsyncMock(spec=JudgeProtocol)
        mock_judge.evaluate = AsyncMock(side_effect=evaluate_fn)
        ctx = build_ctx(runtime, node_spec, memory, scripted)
        looper = EventLoopNode(judge=mock_judge, config=LoopConfig(max_iterations=10))
        outcome = await looper.execute(ctx)
        assert outcome.success is True
        assert calls == 3
# ===========================================================================
# set_output tool
# ===========================================================================
class TestSetOutput:
    @pytest.mark.asyncio
    async def test_set_output_accumulates(self, runtime, node_spec, memory):
        """LLM calls set_output -> values appear in NodeResult.output."""
        scripted = MockStreamingLLM(
            scenarios=[
                # Turn 1: store the value via set_output.
                tool_call_scenario("set_output", {"key": "result", "value": "42"}),
                # Turn 2: plain text response triggers the implicit judge.
                text_scenario("Done, result is 42"),
            ]
        )
        ctx = build_ctx(runtime, node_spec, memory, scripted)
        outcome = await EventLoopNode(config=LoopConfig(max_iterations=5)).execute(ctx)
        assert outcome.success is True
        assert outcome.output["result"] == "42"

    @pytest.mark.asyncio
    async def test_set_output_rejects_invalid_key(self, runtime, node_spec, memory):
        """set_output with key not in output_keys -> is_error=True."""
        scripted = MockStreamingLLM(
            scenarios=[
                # Turn 1: undeclared key must be rejected.
                tool_call_scenario("set_output", {"key": "bad_key", "value": "x"}),
                # Turn 2: declared key is accepted.
                tool_call_scenario("set_output", {"key": "result", "value": "ok"}),
                # Turn 3: text finishes the loop.
                text_scenario("Done"),
            ]
        )
        ctx = build_ctx(runtime, node_spec, memory, scripted)
        outcome = await EventLoopNode(config=LoopConfig(max_iterations=5)).execute(ctx)
        assert outcome.success is True
        assert outcome.output["result"] == "ok"
        assert "bad_key" not in outcome.output

    @pytest.mark.asyncio
    async def test_missing_keys_triggers_retry(self, runtime, node_spec, memory):
        """Judge accepts but output keys are missing -> retry with hint."""
        lenient_judge = AsyncMock(spec=JudgeProtocol)
        lenient_judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))
        scripted = MockStreamingLLM(
            scenarios=[
                # Turn 1: judge accepts, but 'result' is missing -> forced retry.
                text_scenario("I'll get to it"),
                # Turn 2: the key finally gets set.
                tool_call_scenario("set_output", {"key": "result", "value": "done"}),
                # Turn 3: judge accepts and keys are present -> success.
                text_scenario("All done"),
            ]
        )
        ctx = build_ctx(runtime, node_spec, memory, scripted)
        looper = EventLoopNode(judge=lenient_judge, config=LoopConfig(max_iterations=5))
        outcome = await looper.execute(ctx)
        assert outcome.success is True
        assert outcome.output["result"] == "done"
# ===========================================================================
# Stall detection
# ===========================================================================
class TestStallDetection:
    @pytest.mark.asyncio
    async def test_stall_detection(self, runtime, node_spec, memory):
        """3 identical responses should trigger stall detection."""
        node_spec.output_keys = []  # implicit judge would otherwise accept
        # Force RETRY every round so the loop actually produces three
        # identical responses in a row before the threshold fires.
        retry_judge = AsyncMock(spec=JudgeProtocol)
        retry_judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY"))
        scripted = MockStreamingLLM(scenarios=[text_scenario("same answer")])
        ctx = build_ctx(runtime, node_spec, memory, scripted)
        looper = EventLoopNode(
            judge=retry_judge,
            config=LoopConfig(max_iterations=10, stall_detection_threshold=3),
        )
        outcome = await looper.execute(ctx)
        assert outcome.success is False
        assert "stalled" in outcome.error.lower()
# ===========================================================================
# EventBus lifecycle events
# ===========================================================================
class TestEventBusLifecycle:
    @pytest.mark.asyncio
    async def test_lifecycle_events_published(self, runtime, node_spec, memory):
        """NODE_LOOP_STARTED, NODE_LOOP_ITERATION, NODE_LOOP_COMPLETED should be published."""
        node_spec.output_keys = []
        lifecycle = [
            EventType.NODE_LOOP_STARTED,
            EventType.NODE_LOOP_ITERATION,
            EventType.NODE_LOOP_COMPLETED,
        ]
        bus = EventBus()
        seen: list = []
        bus.subscribe(event_types=lifecycle, handler=lambda e: seen.append(e.type))
        scripted = MockStreamingLLM(scenarios=[text_scenario("ok")])
        ctx = build_ctx(runtime, node_spec, memory, scripted)
        outcome = await EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5)).execute(ctx)
        assert outcome.success is True
        for expected in lifecycle:
            assert expected in seen

    @pytest.mark.asyncio
    async def test_client_facing_uses_client_output_delta(self, runtime, memory):
        """client_facing=True should emit CLIENT_OUTPUT_DELTA instead of LLM_TEXT_DELTA."""
        ui_spec = NodeSpec(
            id="ui_node",
            name="UI Node",
            description="Streams to user",
            node_type="event_loop",
            output_keys=[],
            client_facing=True,
        )
        bus = EventBus()
        seen: list = []
        bus.subscribe(
            event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.LLM_TEXT_DELTA],
            handler=lambda e: seen.append(e.type),
        )
        scripted = MockStreamingLLM(scenarios=[text_scenario("visible to user")])
        ctx = build_ctx(runtime, ui_spec, memory, scripted)
        await EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5)).execute(ctx)
        # Only the client-facing delta variant may be emitted.
        assert EventType.CLIENT_OUTPUT_DELTA in seen
        assert EventType.LLM_TEXT_DELTA not in seen
# ===========================================================================
# Tool execution
# ===========================================================================
class TestToolExecution:
    """Tool-use round trip: call, execute, feed result back to the LLM."""

    @pytest.mark.asyncio
    async def test_tool_execution_feedback(self, runtime, node_spec, memory):
        """Tool call -> result fed back to conversation via stream loop."""
        node_spec.output_keys = []

        def my_tool_executor(tool_use: ToolUse) -> ToolResult:
            # Deterministic stub: echoes the tool name back as the result content.
            return ToolResult(
                tool_use_id=tool_use.id,
                content=f"Result for {tool_use.name}",
                is_error=False,
            )

        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call a tool
                tool_call_scenario("search", {"query": "test"}, tool_use_id="call_search"),
                # Turn 2: text response after seeing tool result
                text_scenario("Found the answer"),
            ]
        )
        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            tools=[Tool(name="search", description="Search", parameters={})],
        )
        node = EventLoopNode(
            tool_executor=my_tool_executor,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)
        assert result.success is True
        # stream() should have been called twice (tool call turn + final text turn)
        # NOTE(review): relies on MockStreamingLLM's private _call_index counter.
        assert llm._call_index >= 2
# ===========================================================================
# Write-through persistence with real FileConversationStore
# ===========================================================================
class TestWriteThroughPersistence:
    """Write-through persistence against a real FileConversationStore (no mocks)."""

    @pytest.mark.asyncio
    async def test_messages_written_to_store(self, tmp_path, runtime, node_spec, memory):
        """Messages should be persisted immediately via write-through."""
        store = FileConversationStore(tmp_path / "conv")
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello")])
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)
        assert result.success is True
        # Verify parts were written to disk
        parts = await store.read_parts()
        assert len(parts) >= 2  # at least initial user msg + assistant msg

    @pytest.mark.asyncio
    async def test_output_accumulator_write_through(self, tmp_path, runtime, node_spec, memory):
        """set_output values should be persisted in cursor immediately."""
        store = FileConversationStore(tmp_path / "conv")
        # Turn 1 sets the required output via the set_output tool; turn 2 finishes.
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("set_output", {"key": "result", "value": "persisted_value"}),
                text_scenario("Done"),
            ]
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)
        assert result.success is True
        assert result.output["result"] == "persisted_value"
        # Verify output was written to cursor on disk
        cursor = await store.read_cursor()
        assert cursor is not None
        assert cursor["outputs"]["result"] == "persisted_value"
# ===========================================================================
# Crash recovery (restore from real FileConversationStore)
# ===========================================================================
class TestCrashRecovery:
    """Restore loop state from a pre-populated FileConversationStore."""

    @pytest.mark.asyncio
    async def test_restore_from_checkpoint(self, tmp_path, runtime, node_spec, memory):
        """Populate a store with state, then verify EventLoopNode restores from it."""
        store = FileConversationStore(tmp_path / "conv")
        # Simulate a previous run that wrote conversation + cursor
        conv = NodeConversation(
            system_prompt="You are a test assistant.",
            output_keys=["result"],
            store=store,
        )
        await conv.add_user_message("Initial input")
        await conv.add_assistant_message("Working on it...")
        # Write cursor with iteration and outputs
        await store.write_cursor(
            {
                "iteration": 1,
                "next_seq": conv.next_seq,
                "outputs": {"result": "partial_value"},
            }
        )
        # Now create a new EventLoopNode and execute -- it should restore
        node_spec.output_keys = []  # no required keys so implicit accept works
        llm = MockStreamingLLM(scenarios=[text_scenario("Continuing...")])
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)
        assert result.success is True
        # Should have the restored output
        assert result.output.get("result") == "partial_value"
# ===========================================================================
# External event injection
# ===========================================================================
class TestEventInjection:
    """External events injected into the loop via inject_event()."""

    @pytest.mark.asyncio
    async def test_inject_event(self, runtime, node_spec, memory):
        """inject_event() content should appear as user message in next iteration."""
        node_spec.output_keys = []
        judge_calls = []

        # Custom judge: force one RETRY so a second iteration happens, then ACCEPT.
        async def evaluate_fn(context):
            judge_calls.append(context)
            if len(judge_calls) >= 2:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY")

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(side_effect=evaluate_fn)
        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("iteration 1"),
                text_scenario("iteration 2"),
            ]
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            judge=judge,
            config=LoopConfig(max_iterations=5),
        )
        # Pre-inject an event before execute runs
        await node.inject_event("Priority: CEO wants meeting rescheduled")
        result = await node.execute(ctx)
        assert result.success is True
        # Verify the injected content made it into the LLM messages
        all_messages = []
        for call in llm.stream_calls:
            all_messages.extend(call["messages"])
        # Assumes the loop wraps injected content with an "[External event]"
        # marker when turning it into a user message — TODO confirm in EventLoopNode.
        injected_found = any("[External event]" in str(m.get("content", "")) for m in all_messages)
        assert injected_found
# ===========================================================================
# Pause/resume
# ===========================================================================
class TestPauseResume:
    """Pause request handling before the first LLM turn."""

    @pytest.mark.asyncio
    async def test_pause_returns_early(self, runtime, node_spec, memory):
        """pause_requested in input_data should trigger early return."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("should not run")])
        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            input_data={"pause_requested": True},
        )
        node = EventLoopNode(config=LoopConfig(max_iterations=10))
        result = await node.execute(ctx)
        # Should return success (paused, not failed)
        assert result.success is True
        # LLM should not have been called (paused before first turn)
        assert llm._call_index == 0
# ===========================================================================
# Stream errors
# ===========================================================================
class TestStreamErrors:
    """Stream-level error propagation out of the loop."""

    @pytest.mark.asyncio
    async def test_non_recoverable_stream_error_raises(self, runtime, node_spec, memory):
        """Non-recoverable StreamErrorEvent should raise RuntimeError."""
        node_spec.output_keys = []
        # Scenario is a raw event list whose only event is a fatal stream error.
        llm = MockStreamingLLM(
            scenarios=[
                [StreamErrorEvent(error="Connection lost", recoverable=False)],
            ]
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        with pytest.raises(RuntimeError, match="Stream error"):
            await node.execute(ctx)
# ===========================================================================
# OutputAccumulator unit tests
# ===========================================================================
class TestOutputAccumulator:
    """Unit tests for OutputAccumulator, with and without a backing store."""

    @pytest.mark.asyncio
    async def test_set_and_get(self):
        acc = OutputAccumulator()
        await acc.set("key1", "value1")
        assert acc.get("key1") == "value1"
        # Missing keys read back as None rather than raising.
        assert acc.get("nonexistent") is None

    @pytest.mark.asyncio
    async def test_to_dict(self):
        acc = OutputAccumulator()
        await acc.set("a", 1)
        await acc.set("b", 2)
        assert acc.to_dict() == {"a": 1, "b": 2}

    @pytest.mark.asyncio
    async def test_has_all_keys(self):
        acc = OutputAccumulator()
        # Vacuously true for an empty requirement list.
        assert acc.has_all_keys([]) is True
        assert acc.has_all_keys(["x"]) is False
        await acc.set("x", "val")
        assert acc.has_all_keys(["x"]) is True

    @pytest.mark.asyncio
    async def test_write_through_to_real_store(self, tmp_path):
        """OutputAccumulator should write through to FileConversationStore cursor."""
        store = FileConversationStore(tmp_path / "acc_test")
        acc = OutputAccumulator(store=store)
        await acc.set("result", "hello")
        # set() must persist synchronously with the call — read back from disk.
        cursor = await store.read_cursor()
        assert cursor["outputs"]["result"] == "hello"

    @pytest.mark.asyncio
    async def test_restore_from_real_store(self, tmp_path):
        """OutputAccumulator.restore() should rebuild from FileConversationStore."""
        store = FileConversationStore(tmp_path / "acc_restore")
        await store.write_cursor({"outputs": {"key1": "val1", "key2": "val2"}})
        acc = await OutputAccumulator.restore(store)
        assert acc.get("key1") == "val1"
        assert acc.get("key2") == "val2"
        assert acc.has_all_keys(["key1", "key2"]) is True
+265
View File
@@ -0,0 +1,265 @@
"""
Tests for event_loop node type wiring (Issue #2513).
Covers:
- NodeSpec.client_facing field
- event_loop in VALID_NODE_TYPES
- _get_node_implementation() event_loop branch
- no-retry enforcement in serial execution path
"""
from unittest.mock import AsyncMock, MagicMock
import pytest
from framework.graph.edge import GraphSpec
from framework.graph.executor import GraphExecutor
from framework.graph.goal import Goal
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
from framework.runtime.core import Runtime
class AlwaysFailsNode(NodeProtocol):
    """Test double whose execute() never succeeds, counting invocations."""

    def __init__(self):
        # How many times execute() has been called; asserted on by retry tests.
        self.attempt_count = 0

    async def execute(self, ctx: NodeContext) -> NodeResult:
        self.attempt_count += 1
        message = f"Permanent error (attempt {self.attempt_count})"
        return NodeResult(success=False, error=message)
class SucceedsOnceNode(NodeProtocol):
    """Test double whose execute() always succeeds with a fixed payload."""

    async def execute(self, ctx: NodeContext) -> NodeResult:
        # Callers only inspect success and the "result" key.
        payload = {"result": "ok"}
        return NodeResult(success=True, output=payload)
@pytest.fixture(autouse=True)
def fast_sleep(monkeypatch):
    """Mock asyncio.sleep to avoid real delays from exponential backoff.

    autouse: applies to every test in this module, so retry/backoff code paths
    never sleep for real. AsyncMock() is awaitable and returns immediately.
    """
    monkeypatch.setattr("asyncio.sleep", AsyncMock())
@pytest.fixture
def runtime():
    """Create a mock Runtime for testing."""
    rt = MagicMock(spec=Runtime)
    # Methods with meaningful return values used by the executor.
    rt.start_run = MagicMock(return_value="test_run_id")
    rt.decide = MagicMock(return_value="test_decision_id")
    # Fire-and-forget hooks the executor calls during a run.
    for hook in ("record_outcome", "end_run", "report_problem", "set_node"):
        setattr(rt, hook, MagicMock())
    return rt
# --- NodeSpec.client_facing tests ---
def test_client_facing_defaults_false():
    """NodeSpec without client_facing should default to False."""
    # Deliberately omit client_facing from the constructor arguments.
    fields = {"id": "n1", "name": "Node 1", "description": "test", "node_type": "llm_generate"}
    spec = NodeSpec(**fields)
    assert spec.client_facing is False
def test_client_facing_explicit_true():
    """NodeSpec with client_facing=True should retain the value."""
    fields = {
        "id": "n1",
        "name": "Node 1",
        "description": "test",
        "node_type": "event_loop",
        "client_facing": True,
    }
    spec = NodeSpec(**fields)
    assert spec.client_facing is True
# --- VALID_NODE_TYPES tests ---
def test_event_loop_in_valid_node_types():
    """'event_loop' must be in GraphExecutor.VALID_NODE_TYPES."""
    valid_types = GraphExecutor.VALID_NODE_TYPES
    assert "event_loop" in valid_types
def test_event_loop_node_spec_accepted():
    """Creating a NodeSpec with node_type='event_loop' should not raise."""
    spec = NodeSpec(
        node_type="event_loop",
        id="el1",
        name="Event Loop",
        description="test",
    )
    assert spec.node_type == "event_loop"
# --- _get_node_implementation() tests ---
def test_unregistered_event_loop_raises(runtime):
    """An event_loop node not in the registry should raise RuntimeError."""
    executor = GraphExecutor(runtime=runtime)
    unregistered = NodeSpec(
        id="el1",
        name="Event Loop",
        description="test",
        node_type="event_loop",
    )
    # No register_node() call was made for "el1", so the lookup must fail.
    with pytest.raises(RuntimeError, match="not found in registry"):
        executor._get_node_implementation(unregistered)
def test_registered_event_loop_returns_impl(runtime):
    """A registered event_loop node should be returned from the registry."""
    implementation = SucceedsOnceNode()
    executor = GraphExecutor(runtime=runtime)
    executor.register_node("el1", implementation)
    spec = NodeSpec(
        id="el1",
        name="Event Loop",
        description="test",
        node_type="event_loop",
    )
    # Identity, not equality: the exact registered object must come back.
    resolved = executor._get_node_implementation(spec)
    assert resolved is implementation
# --- No-retry enforcement (serial path) ---
@pytest.mark.asyncio
async def test_event_loop_max_retries_forced_zero(runtime):
    """An event_loop node with max_retries=3 should only execute once (no retry)."""
    node_spec = NodeSpec(
        id="el_fail",
        name="Failing Event Loop",
        description="event loop that fails",
        node_type="event_loop",
        max_retries=3,
        output_keys=["result"],
    )
    # Minimal single-node graph: entry and terminal are the same failing node.
    graph = GraphSpec(
        id="test_graph",
        goal_id="test_goal",
        name="Test Graph",
        entry_node="el_fail",
        nodes=[node_spec],
        edges=[],
        terminal_nodes=["el_fail"],
    )
    goal = Goal(id="test_goal", name="Test", description="test")
    executor = GraphExecutor(runtime=runtime)
    failing_node = AlwaysFailsNode()
    executor.register_node("el_fail", failing_node)
    result = await executor.execute(graph, goal, {})
    # Event loop nodes get max_retries overridden to 0, meaning execute once then fail
    assert not result.success
    assert failing_node.attempt_count == 1
@pytest.mark.asyncio
async def test_event_loop_max_retries_zero_no_warning(runtime, caplog):
    """An event_loop node with max_retries=0 should not log a warning."""
    node_spec = NodeSpec(
        id="el_zero",
        name="Zero Retry Event Loop",
        description="event loop with 0 retries",
        node_type="event_loop",
        max_retries=0,
        output_keys=["result"],
    )
    graph = GraphSpec(
        id="test_graph",
        goal_id="test_goal",
        name="Test Graph",
        entry_node="el_zero",
        nodes=[node_spec],
        edges=[],
        terminal_nodes=["el_zero"],
    )
    goal = Goal(id="test_goal", name="Test", description="test")
    executor = GraphExecutor(runtime=runtime)
    failing_node = AlwaysFailsNode()
    executor.register_node("el_zero", failing_node)
    import logging

    # Capture WARNING-level records emitted during execution.
    with caplog.at_level(logging.WARNING):
        await executor.execute(graph, goal, {})
    # max_retries=0 should not trigger the override warning
    assert "Overriding to 0" not in caplog.text
@pytest.mark.asyncio
async def test_event_loop_max_retries_positive_logs_warning(runtime, caplog):
    """An event_loop node with max_retries=3 should log a warning about override."""
    node_spec = NodeSpec(
        id="el_warn",
        name="Warning Event Loop",
        description="event loop with retries",
        node_type="event_loop",
        max_retries=3,
        output_keys=["result"],
    )
    graph = GraphSpec(
        id="test_graph",
        goal_id="test_goal",
        name="Test Graph",
        entry_node="el_warn",
        nodes=[node_spec],
        edges=[],
        terminal_nodes=["el_warn"],
    )
    goal = Goal(id="test_goal", name="Test", description="test")
    executor = GraphExecutor(runtime=runtime)
    failing_node = AlwaysFailsNode()
    executor.register_node("el_warn", failing_node)
    import logging

    with caplog.at_level(logging.WARNING):
        await executor.execute(graph, goal, {})
    # The warning must both mention the override and identify the offending node.
    assert "Overriding to 0" in caplog.text
    assert "el_warn" in caplog.text
# --- Existing node types unaffected ---
def test_existing_node_types_unchanged():
    """All pre-existing node types must still be in VALID_NODE_TYPES with defaults preserved."""
    legacy_types = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
    assert legacy_types.issubset(GraphExecutor.VALID_NODE_TYPES)
    # A bare spec must keep all historical defaults for backward compatibility.
    spec = NodeSpec(id="x", name="X", description="x")
    assert spec.node_type == "llm_tool_use"
    assert spec.max_retries == 3
    assert spec.client_facing is False
+978
View File
@@ -0,0 +1,978 @@
"""Tests for extending the stream event type system.
Validates that the StreamEvent discriminated union pattern supports:
- Type-based dispatch (matching on event.type)
- Pattern matching / isinstance branching
- Custom event subclasses following the same frozen-dataclass convention
- Serialization of mixed event sequences
WP-2 tests validate EventType enum extension and node-level event routing:
- All 12 new EventType enum members with correct string values
- node_id routing on AgentEvent
- filter_node on Subscription
- Backward compatibility with existing enum members
"""
import asyncio
from dataclasses import FrozenInstanceError, asdict, dataclass, field
from typing import Any, Literal
import pytest
from framework.llm.stream_events import (
FinishEvent,
ReasoningDeltaEvent,
ReasoningStartEvent,
StreamErrorEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
ToolResultEvent,
)
from framework.runtime.event_bus import AgentEvent, EventBus, EventType, Subscription
# ---------------------------------------------------------------------------
# Helpers: type-based dispatch
# ---------------------------------------------------------------------------
def dispatch_event(event) -> str:
    """Dispatch an event by its type field, returning a label."""
    label_makers = {
        "text_delta": lambda e: f"text:{e.content}",
        "text_end": lambda e: f"end:{len(e.full_text)}chars",
        "tool_call": lambda e: f"call:{e.tool_name}",
        "tool_result": lambda e: f"result:{e.tool_use_id}",
        "reasoning_start": lambda _: "reasoning:start",
        "reasoning_delta": lambda e: f"reasoning:{e.content[:20]}",
        "finish": lambda e: f"finish:{e.stop_reason}",
        "error": lambda e: f"error:{e.error}",
    }
    try:
        make_label = label_makers[event.type]
    except KeyError:
        # Unrecognized discriminator: fall back to a generic label.
        return f"unknown:{event.type}"
    return make_label(event)
def collect_text(events: list) -> str:
    """Accumulate full text from a stream of events."""
    # Scan newest-to-oldest: a TextEndEvent carries the final text, while a
    # TextDeltaEvent's snapshot is the best running accumulation seen so far.
    for ev in reversed(events):
        if isinstance(ev, (TextEndEvent, TextDeltaEvent)):
            return ev.full_text if isinstance(ev, TextEndEvent) else ev.snapshot
    return ""
def extract_tool_calls(events: list) -> list[dict[str, Any]]:
    """Extract tool call info from a stream of events."""
    calls: list[dict[str, Any]] = []
    for ev in events:
        if not isinstance(ev, ToolCallEvent):
            continue
        calls.append({"id": ev.tool_use_id, "name": ev.tool_name, "input": ev.tool_input})
    return calls
# ---------------------------------------------------------------------------
# Type-based dispatch tests
# ---------------------------------------------------------------------------
class TestTypeDispatch:
    """Dispatch on event.type string for handler routing."""

    def test_dispatch_text_delta(self):
        ev = TextDeltaEvent(content="hello")
        assert dispatch_event(ev) == "text:hello"

    def test_dispatch_text_end(self):
        ev = TextEndEvent(full_text="hello world")
        assert dispatch_event(ev) == "end:11chars"

    def test_dispatch_tool_call(self):
        ev = ToolCallEvent(tool_name="web_search")
        assert dispatch_event(ev) == "call:web_search"

    def test_dispatch_tool_result(self):
        ev = ToolResultEvent(tool_use_id="abc")
        assert dispatch_event(ev) == "result:abc"

    def test_dispatch_reasoning_start(self):
        ev = ReasoningStartEvent()
        assert dispatch_event(ev) == "reasoning:start"

    def test_dispatch_reasoning_delta(self):
        # Label truncates reasoning content to its first 20 characters.
        ev = ReasoningDeltaEvent(content="Let me think step by step")
        assert dispatch_event(ev) == "reasoning:Let me think step by"

    def test_dispatch_finish(self):
        ev = FinishEvent(stop_reason="end_turn")
        assert dispatch_event(ev) == "finish:end_turn"

    def test_dispatch_error(self):
        ev = StreamErrorEvent(error="timeout")
        assert dispatch_event(ev) == "error:timeout"
# ---------------------------------------------------------------------------
# isinstance-based filtering
# ---------------------------------------------------------------------------
class TestInstanceFiltering:
    """Filter event streams using isinstance for each event type."""

    @pytest.fixture
    def text_stream(self) -> list:
        """Simulate a text-only stream."""
        return [
            TextDeltaEvent(content="Hello", snapshot="Hello"),
            TextDeltaEvent(content=" world", snapshot="Hello world"),
            TextDeltaEvent(content="!", snapshot="Hello world!"),
            TextEndEvent(full_text="Hello world!"),
            FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=3, model="test"),
        ]

    @pytest.fixture
    def tool_stream(self) -> list:
        """Simulate a tool call stream."""
        return [
            ToolCallEvent(
                tool_use_id="call_1",
                tool_name="get_weather",
                tool_input={"city": "London"},
            ),
            ToolCallEvent(
                tool_use_id="call_2",
                tool_name="calculator",
                tool_input={"expression": "2+2"},
            ),
            FinishEvent(stop_reason="tool_calls"),
        ]

    @pytest.fixture
    def reasoning_stream(self) -> list:
        """Simulate a stream with reasoning blocks."""
        return [
            ReasoningStartEvent(),
            ReasoningDeltaEvent(content="Let me analyze this..."),
            ReasoningDeltaEvent(content="The answer is 42."),
            TextDeltaEvent(content="The answer is 42.", snapshot="The answer is 42."),
            TextEndEvent(full_text="The answer is 42."),
            FinishEvent(stop_reason="end_turn"),
        ]

    def test_collect_text(self, text_stream):
        assert collect_text(text_stream) == "Hello world!"

    def test_collect_text_from_tool_stream(self, tool_stream):
        # A stream with no text events collapses to the empty string.
        assert collect_text(tool_stream) == ""

    def test_extract_tool_calls(self, tool_stream):
        calls = extract_tool_calls(tool_stream)
        assert len(calls) == 2
        assert calls[0]["name"] == "get_weather"
        assert calls[1]["name"] == "calculator"

    def test_extract_tool_calls_from_text_stream(self, text_stream):
        assert extract_tool_calls(text_stream) == []

    def test_filter_text_deltas(self, text_stream):
        deltas = [e for e in text_stream if isinstance(e, TextDeltaEvent)]
        assert len(deltas) == 3

    def test_filter_finish(self, text_stream):
        finishes = [e for e in text_stream if isinstance(e, FinishEvent)]
        assert len(finishes) == 1
        assert finishes[0].stop_reason == "stop"

    def test_reasoning_then_text(self, reasoning_stream):
        # Reasoning deltas and visible text can be separated from the same stream.
        reasoning = [e for e in reasoning_stream if isinstance(e, ReasoningDeltaEvent)]
        text = collect_text(reasoning_stream)
        assert len(reasoning) == 2
        assert text == "The answer is 42."

    def test_mixed_stream_type_counts(self, reasoning_stream):
        # Tally events by their `type` discriminator string.
        type_counts = {}
        for e in reasoning_stream:
            type_counts[e.type] = type_counts.get(e.type, 0) + 1
        assert type_counts == {
            "reasoning_start": 1,
            "reasoning_delta": 2,
            "text_delta": 1,
            "text_end": 1,
            "finish": 1,
        }
# ---------------------------------------------------------------------------
# Custom event extension pattern
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class CustomMetricsEvent:
    """Example custom event following the same pattern."""

    # Literal discriminator mirrors the stream-event convention (frozen dataclass + type tag).
    type: Literal["custom_metrics"] = "custom_metrics"
    latency_ms: float = 0.0
    tokens_per_second: float = 0.0
    # default_factory avoids a shared mutable default across instances.
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class CustomCitationEvent:
    """Example citation event extending the pattern."""

    # Discriminator first, payload fields after — same shape as the stdlib stream events.
    type: Literal["citation"] = "citation"
    source_url: str = ""
    quote: str = ""
    confidence: float = 0.0
class TestCustomEventExtension:
    """Custom events should follow the same frozen-dataclass convention."""

    def test_custom_event_construction(self):
        e = CustomMetricsEvent(latency_ms=150.5, tokens_per_second=42.3)
        assert e.type == "custom_metrics"
        assert e.latency_ms == 150.5

    def test_custom_event_frozen(self):
        # frozen=True makes attribute assignment raise FrozenInstanceError.
        e = CustomMetricsEvent()
        with pytest.raises(FrozenInstanceError):
            e.type = "modified"

    def test_custom_event_serialization(self):
        e = CustomMetricsEvent(
            latency_ms=100.0,
            tokens_per_second=50.0,
            metadata={"provider": "anthropic"},
        )
        d = asdict(e)
        assert d["type"] == "custom_metrics"
        assert d["metadata"] == {"provider": "anthropic"}

    def test_custom_event_dispatch(self):
        """Custom events can extend the dispatch map."""
        e = CustomMetricsEvent(latency_ms=200.0)
        # Falls through to "unknown" in our dispatch_event
        assert dispatch_event(e) == "unknown:custom_metrics"

    def test_custom_event_in_mixed_stream(self):
        """Custom events can coexist with standard events in a list."""
        stream = [
            TextDeltaEvent(content="hi", snapshot="hi"),
            CustomMetricsEvent(latency_ms=50.0),
            TextEndEvent(full_text="hi"),
            CustomCitationEvent(source_url="https://example.com", quote="hi"),
            FinishEvent(stop_reason="stop"),
        ]
        # Partition by the discriminator set of the standard union.
        standard = [
            e
            for e in stream
            if hasattr(e, "type")
            and e.type
            in {
                "text_delta",
                "text_end",
                "tool_call",
                "tool_result",
                "reasoning_start",
                "reasoning_delta",
                "finish",
                "error",
            }
        ]
        custom = [
            e
            for e in stream
            if e.type
            not in {
                "text_delta",
                "text_end",
                "tool_call",
                "tool_result",
                "reasoning_start",
                "reasoning_delta",
                "finish",
                "error",
            }
        ]
        assert len(standard) == 3
        assert len(custom) == 2
# ---------------------------------------------------------------------------
# Serialization of full event sequences
# ---------------------------------------------------------------------------
class TestSequenceSerialization:
    """Serialize entire event sequences, as done by the dump tests."""

    def test_serialize_text_sequence(self):
        events = [
            TextDeltaEvent(content="Hello", snapshot="Hello"),
            TextDeltaEvent(content=" world", snapshot="Hello world"),
            TextEndEvent(full_text="Hello world"),
            FinishEvent(stop_reason="stop", model="test-model"),
        ]
        # Each record carries its position plus all dataclass fields.
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert len(serialized) == 4
        assert serialized[0]["index"] == 0
        assert serialized[0]["type"] == "text_delta"
        assert serialized[-1]["type"] == "finish"
        assert serialized[-1]["model"] == "test-model"

    def test_serialize_tool_sequence(self):
        events = [
            ToolCallEvent(
                tool_use_id="call_1",
                tool_name="search",
                tool_input={"query": "test"},
            ),
            FinishEvent(stop_reason="tool_calls"),
        ]
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert serialized[0]["tool_input"] == {"query": "test"}
        assert serialized[1]["stop_reason"] == "tool_calls"

    def test_serialize_error_sequence(self):
        events = [
            TextDeltaEvent(content="partial"),
            StreamErrorEvent(error="connection reset", recoverable=True),
            FinishEvent(stop_reason="error"),
        ]
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert serialized[1]["type"] == "error"
        assert serialized[1]["recoverable"] is True

    def test_roundtrip_snapshot_accumulation(self):
        """Verify snapshot grows monotonically through serialization."""
        chunks = ["Hello", " beautiful", " world", "!"]
        events = []
        snapshot = ""
        for chunk in chunks:
            snapshot += chunk
            events.append(TextDeltaEvent(content=chunk, snapshot=snapshot))
        serialized = [asdict(e) for e in events]
        # Strictly increasing snapshot lengths prove the accumulation survived asdict().
        for i in range(1, len(serialized)):
            assert len(serialized[i]["snapshot"]) > len(serialized[i - 1]["snapshot"])
        assert serialized[-1]["snapshot"] == "Hello beautiful world!"
# ===========================================================================
# WP-2: EventType Enum Extension + Node-Level Event Routing
# ===========================================================================
# The 12 new EventType members added by WP-2, mapped to their expected
# string values. Used to parametrize the enum tests below.
WP2_EVENT_TYPES = {
    # Node event-loop lifecycle
    EventType.NODE_LOOP_STARTED: "node_loop_started",
    EventType.NODE_LOOP_ITERATION: "node_loop_iteration",
    EventType.NODE_LOOP_COMPLETED: "node_loop_completed",
    # LLM streaming observability
    EventType.LLM_TEXT_DELTA: "llm_text_delta",
    EventType.LLM_REASONING_DELTA: "llm_reasoning_delta",
    # Tool lifecycle
    EventType.TOOL_CALL_STARTED: "tool_call_started",
    EventType.TOOL_CALL_COMPLETED: "tool_call_completed",
    # Client I/O
    EventType.CLIENT_OUTPUT_DELTA: "client_output_delta",
    EventType.CLIENT_INPUT_REQUESTED: "client_input_requested",
    # Internal node observability
    EventType.NODE_INTERNAL_OUTPUT: "node_internal_output",
    EventType.NODE_INPUT_BLOCKED: "node_input_blocked",
    EventType.NODE_STALLED: "node_stalled",
}
# Pre-existing enum members that must remain unchanged — guards backward
# compatibility of EventType for existing subscribers.
ORIGINAL_EVENT_TYPES = {
    EventType.EXECUTION_STARTED: "execution_started",
    EventType.EXECUTION_COMPLETED: "execution_completed",
    EventType.EXECUTION_FAILED: "execution_failed",
    EventType.EXECUTION_PAUSED: "execution_paused",
    EventType.EXECUTION_RESUMED: "execution_resumed",
    EventType.STATE_CHANGED: "state_changed",
    EventType.STATE_CONFLICT: "state_conflict",
    EventType.GOAL_PROGRESS: "goal_progress",
    EventType.GOAL_ACHIEVED: "goal_achieved",
    EventType.CONSTRAINT_VIOLATION: "constraint_violation",
    EventType.STREAM_STARTED: "stream_started",
    EventType.STREAM_STOPPED: "stream_stopped",
    EventType.CUSTOM: "custom",
}
# ---------------------------------------------------------------------------
# WP-2 Part A: EventType enum members
# ---------------------------------------------------------------------------
class TestWP2EventTypeEnumMembers:
    """All 12 new EventType members exist with correct string values."""

    # Parametrized over the (member, value) pairs; ids lambda handles both
    # the EventType member and its string value.
    @pytest.mark.parametrize(
        "member,expected_value",
        WP2_EVENT_TYPES.items(),
        ids=lambda x: x.name if isinstance(x, EventType) else x,
    )
    def test_new_member_value(self, member, expected_value):
        assert member.value == expected_value

    def test_all_12_new_members_exist(self):
        assert len(WP2_EVENT_TYPES) == 12

    def test_new_member_string_values_are_unique(self):
        values = list(WP2_EVENT_TYPES.values())
        assert len(values) == len(set(values))

    def test_no_collision_with_original_members(self):
        new_values = set(WP2_EVENT_TYPES.values())
        old_values = set(ORIGINAL_EVENT_TYPES.values())
        overlap = new_values & old_values
        assert overlap == set(), f"Colliding values: {overlap}"

    @pytest.mark.parametrize(
        "member,expected_value",
        ORIGINAL_EVENT_TYPES.items(),
        ids=lambda x: x.name if isinstance(x, EventType) else x,
    )
    def test_original_members_unchanged(self, member, expected_value):
        assert member.value == expected_value

    def test_event_type_is_str_enum(self):
        """EventType members compare equal to their string values."""
        assert EventType.NODE_LOOP_STARTED == "node_loop_started"
        assert EventType.LLM_TEXT_DELTA == "llm_text_delta"
        assert EventType.LLM_TEXT_DELTA.value == "llm_text_delta"

    def test_event_type_accessible_by_name(self):
        # Name lookup (EventType["..."]) must return the identical member object.
        assert EventType["NODE_LOOP_STARTED"] is EventType.NODE_LOOP_STARTED
        assert EventType["TOOL_CALL_COMPLETED"] is EventType.TOOL_CALL_COMPLETED

    def test_event_type_accessible_by_value(self):
        # Value lookup (EventType("...")) must also return the identical member.
        assert EventType("node_loop_started") is EventType.NODE_LOOP_STARTED
        assert EventType("tool_call_completed") is EventType.TOOL_CALL_COMPLETED
# ---------------------------------------------------------------------------
# WP-2 Part B: AgentEvent.node_id and Subscription.filter_node
# ---------------------------------------------------------------------------
class TestWP2AgentEventNodeId:
    """AgentEvent supports node_id as a first-class field."""

    def test_node_id_defaults_to_none(self):
        event = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="stream-1",
        )
        assert event.node_id is None

    def test_node_id_can_be_set(self):
        event = AgentEvent(
            type=EventType.LLM_TEXT_DELTA,
            stream_id="stream-1",
            node_id="email_composer",
        )
        assert event.node_id == "email_composer"

    def test_node_id_in_to_dict(self):
        event = AgentEvent(
            type=EventType.TOOL_CALL_STARTED,
            stream_id="stream-1",
            node_id="search_node",
        )
        d = event.to_dict()
        assert d["node_id"] == "search_node"

    def test_node_id_none_in_to_dict(self):
        # Even an unset node_id must be serialized explicitly (as None),
        # not omitted from the dict.
        event = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="stream-1",
        )
        d = event.to_dict()
        assert "node_id" in d
        assert d["node_id"] is None
class TestWP2SubscriptionFilterNode:
    """Subscription supports filter_node for node-level routing."""

    @staticmethod
    async def _noop_handler(event: AgentEvent) -> None:
        # Placeholder async handler; Subscription construction needs a callable.
        pass

    def test_filter_node_defaults_to_none(self):
        # self._noop_handler resolves to the plain function via the staticmethod.
        sub = Subscription(
            id="sub_1",
            event_types={EventType.LLM_TEXT_DELTA},
            handler=self._noop_handler,
        )
        assert sub.filter_node is None

    def test_filter_node_can_be_set(self):
        sub = Subscription(
            id="sub_1",
            event_types={EventType.LLM_TEXT_DELTA},
            handler=self._noop_handler,
            filter_node="email_composer",
        )
        assert sub.filter_node == "email_composer"
# ---------------------------------------------------------------------------
# WP-2 Part B: Node-level event routing integration tests
# ---------------------------------------------------------------------------
class TestWP2NodeLevelRouting:
    """EventBus routes events by node_id using filter_node."""

    @pytest.fixture
    def bus(self):
        # Fresh bus per test so subscriptions never leak between cases.
        return EventBus()

    @pytest.mark.asyncio
    async def test_filter_node_receives_matching_events(self, bus):
        """Subscriber with filter_node='node-A' receives events from node-A."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
            filter_node="node-A",
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
                data={"content": "hello"},
            )
        )
        assert len(received) == 1
        assert received[0].node_id == "node-A"
        assert received[0].data["content"] == "hello"

    @pytest.mark.asyncio
    async def test_filter_node_rejects_non_matching_events(self, bus):
        """Subscriber with filter_node='node-B' does NOT receive node-A events."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
            filter_node="node-B",
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
                data={"content": "hello"},
            )
        )
        # Node mismatch: the handler must never fire.
        assert len(received) == 0

    @pytest.mark.asyncio
    async def test_no_filter_node_receives_all_events(self, bus):
        """Subscriber with no filter_node receives events from all nodes."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-B",
            )
        )
        # Events with node_id=None must also reach unfiltered subscribers.
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id=None,
            )
        )
        assert len(received) == 3

    @pytest.mark.asyncio
    async def test_interleaved_nodes_separated_by_filter(self, bus):
        """Two subscribers on different nodes get only their node's events."""
        node_a_events = []
        node_b_events = []

        async def handler_a(event):
            node_a_events.append(event)

        async def handler_b(event):
            node_b_events.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler_a,
            filter_node="email_sender",
        )
        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler_b,
            filter_node="inbox_scanner",
        )
        # Interleaved events from both nodes; each subscriber must see
        # its own node's deltas, in publish order.
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="email_sender",
                data={"content": "Dear Jo"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="inbox_scanner",
                data={"content": "RE: Meeting conf"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="email_sender",
                data={"content": "hn, Thank you for"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="inbox_scanner",
                data={"content": "irmed for Thursday"},
            )
        )
        assert len(node_a_events) == 2
        assert len(node_b_events) == 2
        assert node_a_events[0].data["content"] == "Dear Jo"
        assert node_a_events[1].data["content"] == "hn, Thank you for"
        assert node_b_events[0].data["content"] == "RE: Meeting conf"
        assert node_b_events[1].data["content"] == "irmed for Thursday"

    @pytest.mark.asyncio
    async def test_filter_node_combined_with_filter_stream(self, bus):
        """filter_node and filter_stream work together."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.TOOL_CALL_STARTED],
            handler=handler,
            filter_stream="webhook",
            filter_node="search_node",
        )
        # Matching both filters
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="webhook",
                node_id="search_node",
            )
        )
        # Wrong stream
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="api",
                node_id="search_node",
            )
        )
        # Wrong node
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="webhook",
                node_id="other_node",
            )
        )
        # Only the event matching BOTH filters is delivered.
        assert len(received) == 1
        assert received[0].stream_id == "webhook"
        assert received[0].node_id == "search_node"

    @pytest.mark.asyncio
    async def test_wait_for_with_node_id(self, bus):
        """wait_for() accepts node_id parameter for filtering."""

        async def publish_later():
            # Small delay so wait_for() is already pending when we publish.
            await asyncio.sleep(0.01)
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="target_node",
                    data={"iterations": 3},
                )
            )

        task = asyncio.create_task(publish_later())
        event = await bus.wait_for(
            event_type=EventType.NODE_LOOP_COMPLETED,
            node_id="target_node",
            timeout=2.0,
        )
        await task
        assert event is not None
        assert event.node_id == "target_node"
        assert event.data["iterations"] == 3

    @pytest.mark.asyncio
    async def test_wait_for_ignores_wrong_node(self, bus):
        """wait_for() with node_id ignores events from other nodes."""

        async def publish_wrong_then_right():
            await asyncio.sleep(0.01)
            # Wrong node — should be ignored
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="wrong_node",
                )
            )
            await asyncio.sleep(0.01)
            # Right node
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="target_node",
                    data={"iterations": 5},
                )
            )

        task = asyncio.create_task(publish_wrong_then_right())
        event = await bus.wait_for(
            event_type=EventType.NODE_LOOP_COMPLETED,
            node_id="target_node",
            timeout=2.0,
        )
        await task
        # The resolved event must be the second publish, not the first.
        assert event is not None
        assert event.node_id == "target_node"
        assert event.data["iterations"] == 5
# ---------------------------------------------------------------------------
# WP-2: Convenience publisher methods
# ---------------------------------------------------------------------------
class TestWP2ConveniencePublishers:
    """EventBus convenience methods for new WP-2 event types."""

    @pytest.fixture
    def bus(self):
        # Fresh bus per test so subscriptions never leak between cases.
        return EventBus()

    @pytest.mark.asyncio
    async def test_emit_node_loop_started(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_STARTED], handler=handler)
        await bus.emit_node_loop_started(
            stream_id="s1",
            node_id="n1",
            max_iterations=10,
        )
        # Keyword args must land in node_id and data respectively.
        assert len(received) == 1
        assert received[0].node_id == "n1"
        assert received[0].data["max_iterations"] == 10

    @pytest.mark.asyncio
    async def test_emit_node_loop_iteration(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_ITERATION], handler=handler)
        await bus.emit_node_loop_iteration(
            stream_id="s1",
            node_id="n1",
            iteration=3,
        )
        assert len(received) == 1
        assert received[0].data["iteration"] == 3

    @pytest.mark.asyncio
    async def test_emit_node_loop_completed(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_COMPLETED], handler=handler)
        await bus.emit_node_loop_completed(
            stream_id="s1",
            node_id="n1",
            iterations=5,
        )
        assert len(received) == 1
        assert received[0].data["iterations"] == 5

    @pytest.mark.asyncio
    async def test_emit_llm_text_delta(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.LLM_TEXT_DELTA], handler=handler)
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="n1",
            content="hello",
            snapshot="hello world",
        )
        # Both the incremental content and the accumulated snapshot are carried.
        assert len(received) == 1
        assert received[0].data["content"] == "hello"
        assert received[0].data["snapshot"] == "hello world"

    @pytest.mark.asyncio
    async def test_emit_tool_call_started(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.TOOL_CALL_STARTED], handler=handler)
        await bus.emit_tool_call_started(
            stream_id="s1",
            node_id="n1",
            tool_use_id="call_1",
            tool_name="web_search",
            tool_input={"query": "test"},
        )
        assert len(received) == 1
        assert received[0].data["tool_name"] == "web_search"
        assert received[0].data["tool_input"] == {"query": "test"}

    @pytest.mark.asyncio
    async def test_emit_tool_call_completed(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.TOOL_CALL_COMPLETED], handler=handler)
        await bus.emit_tool_call_completed(
            stream_id="s1",
            node_id="n1",
            tool_use_id="call_1",
            tool_name="web_search",
            result="3 results found",
        )
        assert len(received) == 1
        assert received[0].data["result"] == "3 results found"
        # is_error defaults to False when not supplied.
        assert received[0].data["is_error"] is False

    @pytest.mark.asyncio
    async def test_emit_client_output_delta(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.CLIENT_OUTPUT_DELTA], handler=handler)
        await bus.emit_client_output_delta(
            stream_id="s1",
            node_id="n1",
            content="chunk",
            snapshot="full chunk",
        )
        assert len(received) == 1
        assert received[0].data["content"] == "chunk"

    @pytest.mark.asyncio
    async def test_emit_node_stalled(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_STALLED], handler=handler)
        await bus.emit_node_stalled(
            stream_id="s1",
            node_id="n1",
            reason="no progress after 10 iterations",
        )
        assert len(received) == 1
        assert received[0].data["reason"] == "no progress after 10 iterations"

    @pytest.mark.asyncio
    async def test_convenience_publishers_set_node_id(self, bus):
        """All WP-2 convenience publishers set node_id on the emitted event."""
        received = []

        async def handler(event):
            received.append(event)

        # filter_node subscription proves node_id is set on emitted events:
        # if a publisher dropped node_id, its event would be filtered out.
        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA, EventType.TOOL_CALL_STARTED],
            handler=handler,
            filter_node="my_node",
        )
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="my_node",
            content="hi",
            snapshot="hi",
        )
        await bus.emit_tool_call_started(
            stream_id="s1",
            node_id="my_node",
            tool_use_id="c1",
            tool_name="calc",
        )
        # Wrong node — should not be received
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="other_node",
            content="bye",
            snapshot="bye",
        )
        assert len(received) == 2
        assert all(e.node_id == "my_node" for e in received)
+389
View File
@@ -0,0 +1,389 @@
"""Real-API streaming tests for LiteLLM provider.
Calls live LLM APIs and dumps stream events to JSON files for review.
Results are saved to core/tests/stream_event_dumps/{provider}_{model}_{scenario}.json
Run with:
cd core && python -m pytest tests/test_litellm_streaming.py -v -s -k "RealAPI"
Requires API keys set in environment:
ANTHROPIC_API_KEY, OPENAI_API_KEY, GEMINI_API_KEY (or via credential store)
"""
import asyncio
import json
import logging
import os
from dataclasses import asdict
from pathlib import Path
import pytest
from framework.llm.litellm import LiteLLMProvider
from framework.llm.provider import Tool
from framework.llm.stream_events import (
FinishEvent,
StreamEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
)
# Module logger; streaming tests log each received event at DEBUG level.
logger = logging.getLogger(__name__)
# All stream-event dumps are written below this directory (created on demand).
DUMP_DIR = Path(__file__).parent / "stream_event_dumps"
def _serialize_event(index: int, event: StreamEvent) -> dict:
"""Serialize a StreamEvent to a JSON-safe dict."""
d = asdict(event) # type: ignore[arg-type]
d["index"] = index
# Move index to front for readability
return {"index": index, **{k: v for k, v in d.items() if k != "index"}}
def _dump_events(events: list[StreamEvent], filename: str) -> Path:
    """Write stream events to a JSON file in the dump directory."""
    # Ensure the dump directory exists before writing.
    DUMP_DIR.mkdir(parents=True, exist_ok=True)
    target = DUMP_DIR / filename
    payload = [_serialize_event(i, evt) for i, evt in enumerate(events)]
    target.write_text(json.dumps(payload, indent=2) + "\n")
    logger.info(f"Dumped {len(events)} events to {target}")
    return target
async def _collect_stream(provider: LiteLLMProvider, **kwargs) -> list[StreamEvent]:
    """Collect all stream events from a provider.stream() call."""
    collected: list[StreamEvent] = []
    async for evt in provider.stream(**kwargs):
        collected.append(evt)
        # Log each event type as it arrives, with its zero-based index.
        logger.debug(f"  [{len(collected) - 1}] {evt.type}: {evt}")
    return collected
# ---------------------------------------------------------------------------
# Test matrix: (model_id, dump_prefix, env_var_for_skip)
# ---------------------------------------------------------------------------
# Each tuple drives one parametrized run: `model_id` is handed to LiteLLM,
# `dump_prefix` names the JSON dump files, and `env_var_for_skip` is checked
# by _has_api_key() to skip the test when no credentials are available.
MODELS = [
    (
        "anthropic/claude-haiku-4-5-20251001",
        "anthropic_claude-haiku-4-5-20251001",
        "ANTHROPIC_API_KEY",
    ),
    ("gpt-4.1-nano", "gpt-4.1-nano", "OPENAI_API_KEY"),
    ("gemini/gemini-2.0-flash", "gemini_gemini-2.0-flash", "GEMINI_API_KEY"),
]
# Tool definitions advertised to the models so streamed responses can
# include tool calls. Parameter schemas are JSON Schema objects.
WEATHER_TOOL = Tool(
    name="get_weather",
    description="Get the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {
            "city": {
                "type": "string",
                "description": "City name, e.g. 'Tokyo'",
            }
        },
        "required": ["city"],
    },
)
SEARCH_TOOL = Tool(
    name="web_search",
    description="Search the web for information.",
    parameters={
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query",
            },
            # Optional: only "query" is required below.
            "num_results": {
                "type": "integer",
                "description": "Number of results to return (1-10)",
            },
        },
        "required": ["query"],
    },
)
CALCULATOR_TOOL = Tool(
    name="calculator",
    description="Perform arithmetic calculations.",
    parameters={
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "Math expression to evaluate, e.g. '2 + 2'",
            }
        },
        "required": ["expression"],
    },
)
def _has_api_key(env_var: str) -> bool:
"""Check if an API key is available (env var or credential store)."""
if os.environ.get(env_var):
return True
# Try credential store
try:
from aden_tools.credentials import CredentialStoreAdapter
creds = CredentialStoreAdapter.with_env_storage()
provider_name = env_var.replace("_API_KEY", "").lower()
return creds.is_available(provider_name)
except (ImportError, Exception):
return False
# ---------------------------------------------------------------------------
# Real API tests — text streaming
# ---------------------------------------------------------------------------
class TestRealAPITextStreaming:
    """Stream a simple text response from each provider and dump events."""

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_text_stream(self, model: str, prefix: str, env_var: str):
        """Stream a multi-paragraph response to exercise chunked delivery."""
        # Live-API test: skipped entirely when no credentials are present.
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")
        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "Explain in 3 numbered paragraphs how a CPU executes an instruction. "
                        "Cover fetch, decode, and execute stages. Be concise but thorough."
                    ),
                }
            ],
            system="You are a computer science teacher. Give clear, structured explanations.",
            max_tokens=512,
        )
        # Dump to file
        _dump_events(events, f"{prefix}_text.json")
        # Basic structural assertions
        assert len(events) >= 4, f"Expected at least 4 events, got {len(events)}"
        # Must have multiple text deltas for a longer response
        text_deltas = [e for e in events if isinstance(e, TextDeltaEvent)]
        assert len(text_deltas) >= 3, f"Expected 3+ TextDeltaEvents, got {len(text_deltas)}"
        # Snapshot must accumulate monotonically
        for i in range(1, len(text_deltas)):
            assert len(text_deltas[i].snapshot) > len(text_deltas[i - 1].snapshot), (
                f"Snapshot did not grow at index {i}"
            )
        # Must end with TextEndEvent then FinishEvent
        text_ends = [e for e in events if isinstance(e, TextEndEvent)]
        assert len(text_ends) == 1, f"Expected 1 TextEndEvent, got {len(text_ends)}"
        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1, f"Expected 1 FinishEvent, got {len(finish_events)}"
        # Providers report either OpenAI-style "stop" or Anthropic-style "end_turn".
        assert finish_events[0].stop_reason in ("stop", "end_turn")
        # TextEndEvent.full_text should match last snapshot
        assert text_ends[0].full_text == text_deltas[-1].snapshot
        # Response should actually contain multi-paragraph content
        full_text = text_ends[0].full_text
        assert len(full_text) > 200, f"Response too short ({len(full_text)} chars)"
# ---------------------------------------------------------------------------
# Real API tests — tool call streaming
# ---------------------------------------------------------------------------
class TestRealAPIToolCallStreaming:
    """Stream a tool call response from each provider and dump events."""

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_tool_call_stream(self, model: str, prefix: str, env_var: str):
        """Stream a single tool call with complex arguments."""
        # Live-API test: skipped entirely when no credentials are present.
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")
        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": "Search the web for 'Python 3.13 release notes'.",
                }
            ],
            system="You have access to tools. Use the appropriate tool.",
            # All three tools offered; the prompt should select web_search.
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            max_tokens=512,
        )
        # Dump to file
        _dump_events(events, f"{prefix}_tool_call.json")
        # Basic structural assertions
        assert len(events) >= 2, f"Expected at least 2 events, got {len(events)}"
        # Must have a tool call event
        tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
        assert len(tool_calls) >= 1, "No ToolCallEvent received"
        tc = tool_calls[0]
        assert tc.tool_name == "web_search"
        assert "query" in tc.tool_input
        assert tc.tool_use_id != ""
        # Must end with FinishEvent
        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1
        # Stop reason varies by provider ("tool_calls"/"tool_use"/"stop").
        assert finish_events[0].stop_reason in ("tool_calls", "tool_use", "stop")

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_multi_tool_call_stream(self, model: str, prefix: str, env_var: str):
        """Stream a response that should invoke multiple tool calls."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")
        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "I need three things done in parallel: "
                        "1) Get the weather in London, "
                        "2) Get the weather in New York, "
                        "3) Calculate 1337 * 42. "
                        "Use the tools for all three."
                    ),
                }
            ],
            system=(
                "You have access to tools. When the user asks for multiple things, "
                "call all the needed tools. Always use tools, never guess results."
            ),
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            max_tokens=512,
        )
        # Dump to file
        _dump_events(events, f"{prefix}_multi_tool.json")
        # Must have multiple tool call events
        tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
        assert len(tool_calls) >= 2, (
            f"Expected 2+ ToolCallEvents for parallel requests, got {len(tool_calls)}"
        )
        # Verify tool names used
        tool_names = {tc.tool_name for tc in tool_calls}
        assert "get_weather" in tool_names, "Expected get_weather tool call"
        # All tool calls should have non-empty IDs
        for tc in tool_calls:
            assert tc.tool_use_id != "", f"Empty tool_use_id on {tc.tool_name}"
            assert tc.tool_input, f"Empty tool_input on {tc.tool_name}"
        # Must end with FinishEvent
        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1
# ---------------------------------------------------------------------------
# Convenience runner for manual invocation
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    """Run all streaming tests and dump results. Usage: python tests/test_litellm_streaming.py"""
    ALL_TOOLS = [WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL]

    async def _run_all():
        # Mirrors the pytest scenarios (text, tool_call, multi_tool) for each
        # configured model, printing events and writing the same JSON dumps.
        for model, prefix, env_var in MODELS:
            if not _has_api_key(env_var):
                print(f"SKIP {prefix}: {env_var} not set")
                continue
            provider = LiteLLMProvider(model=model)
            # Text streaming (multi-paragraph)
            print(f"\n--- {prefix} text ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": (
                            "Explain in 3 numbered paragraphs how a CPU executes an instruction. "
                            "Cover fetch, decode, and execute stages. Be concise but thorough."
                        ),
                    }
                ],
                system="You are a computer science teacher. Give clear, structured explanations.",
                max_tokens=512,
            )
            path = _dump_events(events, f"{prefix}_text.json")
            print(f"  {len(events)} events -> {path}")
            for i, e in enumerate(events):
                print(f"  [{i}] {e.type}: {e}")
            # Tool call streaming
            print(f"\n--- {prefix} tool_call ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": "Search the web for 'Python 3.13 release notes'.",
                    }
                ],
                system="You have access to tools. Use the appropriate tool.",
                tools=ALL_TOOLS,
                max_tokens=512,
            )
            path = _dump_events(events, f"{prefix}_tool_call.json")
            print(f"  {len(events)} events -> {path}")
            for i, e in enumerate(events):
                print(f"  [{i}] {e.type}: {e}")
            # Multi-tool call streaming
            print(f"\n--- {prefix} multi_tool ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": (
                            "I need three things done in parallel: "
                            "1) Get the weather in London, "
                            "2) Get the weather in New York, "
                            "3) Calculate 1337 * 42. "
                            "Use the tools for all three."
                        ),
                    }
                ],
                system=(
                    "You have access to tools. When the user asks for multiple things, "
                    "call all the needed tools. Always use tools, never guess results."
                ),
                tools=ALL_TOOLS,
                max_tokens=512,
            )
            path = _dump_events(events, f"{prefix}_multi_tool.json")
            print(f"  {len(events)} events -> {path}")
            for i, e in enumerate(events):
                print(f"  [{i}] {e.type}: {e}")

    # DEBUG level so per-event logging in _collect_stream is visible.
    logging.basicConfig(level=logging.DEBUG)
    asyncio.run(_run_all())
+372
View File
@@ -168,6 +168,68 @@ class TestNodeConversation:
await conv.add_user_message("a" * 400)
assert conv.estimate_tokens() == 100
@pytest.mark.asyncio
async def test_update_token_count_overrides_estimate(self):
"""When actual API token count is provided, estimate_tokens uses it."""
conv = NodeConversation()
await conv.add_user_message("a" * 400)
assert conv.estimate_tokens() == 100 # chars/4 fallback
conv.update_token_count(500)
assert conv.estimate_tokens() == 500 # actual API value
@pytest.mark.asyncio
async def test_compact_resets_token_count(self):
"""After compaction, actual token count is cleared (recalibrates on next LLM call)."""
conv = NodeConversation()
await conv.add_user_message("a" * 400)
conv.update_token_count(500)
assert conv.estimate_tokens() == 500
await conv.compact("summary", keep_recent=0)
# Falls back to chars/4 for the summary message
assert conv.estimate_tokens() == len("summary") // 4
@pytest.mark.asyncio
async def test_clear_resets_token_count(self):
"""clear() also resets the actual token count."""
conv = NodeConversation()
await conv.add_user_message("hello")
conv.update_token_count(1000)
assert conv.estimate_tokens() == 1000
await conv.clear()
assert conv.estimate_tokens() == 0
@pytest.mark.asyncio
async def test_usage_ratio(self):
"""usage_ratio returns estimate / max_history_tokens."""
conv = NodeConversation(max_history_tokens=1000)
await conv.add_user_message("a" * 400)
assert conv.usage_ratio() == pytest.approx(0.1) # 100/1000
conv.update_token_count(800)
assert conv.usage_ratio() == pytest.approx(0.8) # 800/1000
@pytest.mark.asyncio
async def test_usage_ratio_zero_budget(self):
"""usage_ratio returns 0 when max_history_tokens is 0 (unlimited)."""
conv = NodeConversation(max_history_tokens=0)
await conv.add_user_message("a" * 400)
assert conv.usage_ratio() == 0.0
@pytest.mark.asyncio
async def test_needs_compaction_with_actual_tokens(self):
"""needs_compaction uses actual API token count when available."""
conv = NodeConversation(max_history_tokens=1000, compaction_threshold=0.8)
await conv.add_user_message("a" * 100) # chars/4 = 25, well under 800
assert conv.needs_compaction() is False
# Simulate API reporting much higher actual token usage
conv.update_token_count(850)
assert conv.needs_compaction() is True
@pytest.mark.asyncio
async def test_needs_compaction(self):
conv = NodeConversation(max_history_tokens=100, compaction_threshold=0.8)
@@ -558,3 +620,313 @@ class TestFileConversationStore:
assert (base / "cursor.json").exists()
assert (base / "parts" / "0000000000.json").exists()
assert (base / "parts" / "0000000001.json").exists()
# ===================================================================
# Integration tests — real FileConversationStore, no mocks
# ===================================================================
class TestConversationIntegration:
"""End-to-end tests using real FileConversationStore on disk.
Every test creates a fresh directory, writes real JSON files,
and restores from a *new* store instance (simulating process restart).
"""
@pytest.mark.asyncio
async def test_multi_turn_agent_conversation(self, tmp_path):
"""Simulate a realistic agent conversation with multiple turns,
tool calls, and tool results then restore from disk."""
base = tmp_path / "agent_conv"
store = FileConversationStore(base)
conv = NodeConversation(
system_prompt="You are a helpful travel agent.",
max_history_tokens=16000,
store=store,
)
# Turn 1: user asks, assistant responds with tool call
await conv.add_user_message("Find me flights from NYC to London next Friday.")
await conv.add_assistant_message(
"Let me search for flights.",
tool_calls=[
{
"id": "call_flight_1",
"type": "function",
"function": {
"name": "search_flights",
"arguments": '{"origin":"JFK","destination":"LHR","date":"2025-06-13"}',
},
}
],
)
await conv.add_tool_result(
"call_flight_1",
'{"flights":[{"airline":"BA","price":450,"departure":"08:00"},{"airline":"AA","price":520,"departure":"14:30"}]}',
)
# Turn 2: assistant presents results, user picks one
await conv.add_assistant_message(
"I found 2 flights:\n"
"1. British Airways at $450, departing 08:00\n"
"2. American Airlines at $520, departing 14:30\n"
"Which one would you like?"
)
await conv.add_user_message("Book the British Airways one.")
await conv.add_assistant_message(
"Booking the BA flight now.",
tool_calls=[
{
"id": "call_book_1",
"type": "function",
"function": {
"name": "book_flight",
"arguments": '{"flight_id":"BA-JFK-LHR-0800","passenger":"user"}',
},
}
],
)
await conv.add_tool_result(
"call_book_1",
'{"confirmation":"BA-12345","status":"confirmed"}',
)
await conv.add_assistant_message("Your flight is booked! Confirmation: BA-12345.")
# Verify in-memory state
assert conv.turn_count == 2
assert conv.message_count == 8
assert conv.next_seq == 8
# --- Simulate process restart: new store, same path ---
store2 = FileConversationStore(base)
restored = await NodeConversation.restore(store2)
assert restored is not None
assert restored.system_prompt == "You are a helpful travel agent."
assert restored.turn_count == 2
assert restored.message_count == 8
assert restored.next_seq == 8
# Verify message content integrity
msgs = restored.messages
assert msgs[0].role == "user"
assert "NYC to London" in msgs[0].content
assert msgs[1].role == "assistant"
assert msgs[1].tool_calls[0]["id"] == "call_flight_1"
assert msgs[2].role == "tool"
assert msgs[2].tool_use_id == "call_flight_1"
assert "BA" in msgs[2].content
assert msgs[7].content == "Your flight is booked! Confirmation: BA-12345."
# Verify LLM-format output
llm_msgs = restored.to_llm_messages()
assert llm_msgs[0] == {"role": "user", "content": msgs[0].content}
assert llm_msgs[2]["role"] == "tool"
assert llm_msgs[2]["tool_call_id"] == "call_flight_1"
@pytest.mark.asyncio
async def test_compaction_and_restore_preserves_continuity(self, tmp_path):
"""Build up a long conversation, compact it, continue adding
messages, then restore verifying seq continuity and content."""
base = tmp_path / "compact_conv"
store = FileConversationStore(base)
conv = NodeConversation(
system_prompt="research assistant",
store=store,
)
# Build 10 messages (5 turns)
for i in range(5):
await conv.add_user_message(f"question {i}")
await conv.add_assistant_message(f"answer {i}")
assert conv.message_count == 10
assert conv.next_seq == 10
# Compact: keep last 2 messages (question 4, answer 4)
await conv.compact("Summary of questions 0-3 and their answers.", keep_recent=2)
assert conv.message_count == 3 # summary + 2 recent
assert conv.messages[0].content == "Summary of questions 0-3 and their answers."
assert conv.messages[1].content == "question 4"
assert conv.messages[2].content == "answer 4"
# Continue the conversation post-compaction
await conv.add_user_message("question 5")
await conv.add_assistant_message("answer 5")
assert conv.next_seq == 12
# Verify disk: old part files (seq 0-7) should be deleted
parts_dir = base / "parts"
part_files = sorted(parts_dir.glob("*.json"))
part_seqs = [int(f.stem) for f in part_files]
# Should have: summary (seq 7), question 4 (seq 8), answer 4 (seq 9),
# question 5 (seq 10), answer 5 (seq 11)
assert all(s >= 7 for s in part_seqs), f"Stale parts found: {part_seqs}"
# Restore from fresh store
store2 = FileConversationStore(base)
restored = await NodeConversation.restore(store2)
assert restored is not None
assert restored.next_seq == 12
assert restored.message_count == 5
assert "Summary of questions 0-3" in restored.messages[0].content
assert restored.messages[-1].content == "answer 5"
# Verify seq monotonicity across all restored messages
seqs = [m.seq for m in restored.messages]
assert seqs == sorted(seqs), f"Seqs not monotonic: {seqs}"
@pytest.mark.asyncio
async def test_output_key_preservation_through_compact_and_restore(self, tmp_path):
"""Output keys in compacted messages survive disk persistence."""
base = tmp_path / "output_key_conv"
store = FileConversationStore(base)
conv = NodeConversation(
system_prompt="classifier",
output_keys=["classification", "confidence"],
store=store,
)
await conv.add_user_message("Classify this email: 'You won a prize!'")
await conv.add_assistant_message('{"classification": "spam", "confidence": "0.97"}')
await conv.add_user_message("What about: 'Meeting at 3pm'")
await conv.add_assistant_message('{"classification": "ham", "confidence": "0.99"}')
await conv.add_user_message("And: 'Buy cheap meds now'")
await conv.add_assistant_message('{"classification": "spam", "confidence": "0.95"}')
# Compact keeping only the last 2 messages
await conv.compact("Classified 3 emails.", keep_recent=2)
# The summary should contain preserved output keys from discarded messages
summary_content = conv.messages[0].content
assert "PRESERVED VALUES" in summary_content
# Most recent values from discarded messages (msgs 0-3) are "ham"/"0.99"
assert "ham" in summary_content or "spam" in summary_content
# Restore and verify the preserved values survived
store2 = FileConversationStore(base)
restored = await NodeConversation.restore(store2)
assert restored is not None
assert "PRESERVED VALUES" in restored.messages[0].content
@pytest.mark.asyncio
async def test_tool_error_roundtrip(self, tmp_path):
"""Tool errors persist and restore with ERROR: prefix in LLM output."""
base = tmp_path / "error_conv"
store = FileConversationStore(base)
conv = NodeConversation(store=store)
await conv.add_user_message("Calculate 1/0")
await conv.add_assistant_message(
"Let me calculate that.",
tool_calls=[
{
"id": "call_calc",
"type": "function",
"function": {"name": "calculator", "arguments": '{"expr":"1/0"}'},
}
],
)
await conv.add_tool_result(
"call_calc", "ZeroDivisionError: division by zero", is_error=True
)
await conv.add_assistant_message("The calculation failed: division by zero is undefined.")
# Restore
store2 = FileConversationStore(base)
restored = await NodeConversation.restore(store2)
assert restored is not None
tool_msg = restored.messages[2]
assert tool_msg.role == "tool"
assert tool_msg.is_error is True
assert tool_msg.tool_use_id == "call_calc"
llm_dict = tool_msg.to_llm_dict()
assert llm_dict["content"].startswith("ERROR: ")
assert "ZeroDivisionError" in llm_dict["content"]
assert llm_dict["tool_call_id"] == "call_calc"
@pytest.mark.asyncio
async def test_concurrent_conversations_isolated(self, tmp_path):
"""Two conversations in separate directories don't interfere."""
store_a = FileConversationStore(tmp_path / "conv_a")
store_b = FileConversationStore(tmp_path / "conv_b")
conv_a = NodeConversation(system_prompt="Agent A", store=store_a)
conv_b = NodeConversation(system_prompt="Agent B", store=store_b)
await conv_a.add_user_message("Hello from A")
await conv_b.add_user_message("Hello from B")
await conv_a.add_assistant_message("Response A")
await conv_b.add_assistant_message("Response B")
await conv_b.add_user_message("Follow-up B")
# Restore independently
restored_a = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_a"))
restored_b = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_b"))
assert restored_a.system_prompt == "Agent A"
assert restored_b.system_prompt == "Agent B"
assert restored_a.message_count == 2
assert restored_b.message_count == 3
assert restored_a.messages[0].content == "Hello from A"
assert restored_b.messages[2].content == "Follow-up B"
@pytest.mark.asyncio
async def test_destroy_removes_all_files(self, tmp_path):
    """destroy() wipes the entire conversation directory."""
    conv_dir = tmp_path / "doomed_conv"
    file_store = FileConversationStore(conv_dir)
    conversation = NodeConversation(system_prompt="temp", store=file_store)
    await conversation.add_user_message("ephemeral")
    await conversation.add_assistant_message("gone soon")
    # Sanity check: the on-disk layout exists before destruction.
    for path in (conv_dir, conv_dir / "meta.json", conv_dir / "parts"):
        assert path.exists()
    await file_store.destroy()
    # The whole directory tree must be gone, not just its contents.
    assert not conv_dir.exists()
@pytest.mark.asyncio
async def test_restore_empty_store_returns_none(self, tmp_path):
    """Restoring from a path that was never written to returns None."""
    untouched_store = FileConversationStore(tmp_path / "empty")
    assert await NodeConversation.restore(untouched_store) is None
@pytest.mark.asyncio
async def test_clear_then_continue_then_restore(self, tmp_path):
    """clear() removes messages but preserves seq counter for new messages."""
    conv_dir = tmp_path / "clear_conv"
    conversation = NodeConversation(
        system_prompt="s", store=FileConversationStore(conv_dir)
    )
    await conversation.add_user_message("old msg 0")
    await conversation.add_assistant_message("old msg 1")
    assert conversation.next_seq == 2
    await conversation.clear()
    # Messages are gone, but the sequence counter must not rewind.
    assert conversation.message_count == 0
    assert conversation.next_seq == 2
    # New messages continue numbering from where the old ones left off.
    await conversation.add_user_message("new msg")
    await conversation.add_assistant_message("new response")
    assert conversation.next_seq == 4
    assert conversation.messages[0].seq == 2
    assert conversation.messages[1].seq == 3
    # A fresh restore sees only the post-clear state, seqs included.
    reloaded = await NodeConversation.restore(FileConversationStore(conv_dir))
    assert reloaded is not None
    assert reloaded.message_count == 2
    assert reloaded.next_seq == 4
    assert reloaded.messages[0].content == "new msg"
    assert reloaded.messages[0].seq == 2
+318
View File
@@ -0,0 +1,318 @@
"""Tests for stream event dataclasses.
Validates construction, defaults, immutability, serialization, and the
StreamEvent discriminated union type.
"""
from dataclasses import FrozenInstanceError, asdict, fields
import pytest
from framework.llm.stream_events import (
FinishEvent,
ReasoningDeltaEvent,
ReasoningStartEvent,
StreamErrorEvent,
StreamEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
ToolResultEvent,
)
# All concrete event classes in the union.
# Order matters only for readability of parametrized test ids; membership is
# checked against StreamEvent.__args__ in TestStreamEventUnion.
ALL_EVENT_CLASSES: list[type] = [
    TextDeltaEvent,
    TextEndEvent,
    ToolCallEvent,
    ToolResultEvent,
    ReasoningStartEvent,
    ReasoningDeltaEvent,
    FinishEvent,
    StreamErrorEvent,
]
# ---------------------------------------------------------------------------
# Construction & defaults
# ---------------------------------------------------------------------------
class TestEventDefaults:
    """Each event class should be constructible with zero arguments."""

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_default_construction(self, cls):
        # Zero-arg construction must succeed and carry a non-empty discriminator.
        assert cls().type != ""

    def test_text_delta_defaults(self):
        evt = TextDeltaEvent()
        assert evt.type == "text_delta"
        assert evt.content == ""
        assert evt.snapshot == ""

    def test_text_end_defaults(self):
        evt = TextEndEvent()
        assert evt.type == "text_end"
        assert evt.full_text == ""

    def test_tool_call_defaults(self):
        evt = ToolCallEvent()
        assert evt.type == "tool_call"
        assert evt.tool_use_id == ""
        assert evt.tool_name == ""
        assert evt.tool_input == {}

    def test_tool_result_defaults(self):
        evt = ToolResultEvent()
        assert evt.type == "tool_result"
        assert evt.tool_use_id == ""
        assert evt.content == ""
        assert evt.is_error is False

    def test_reasoning_start_defaults(self):
        assert ReasoningStartEvent().type == "reasoning_start"

    def test_reasoning_delta_defaults(self):
        evt = ReasoningDeltaEvent()
        assert evt.type == "reasoning_delta"
        assert evt.content == ""

    def test_finish_defaults(self):
        evt = FinishEvent()
        assert evt.type == "finish"
        assert evt.stop_reason == ""
        assert evt.input_tokens == 0
        assert evt.output_tokens == 0
        assert evt.model == ""

    def test_stream_error_defaults(self):
        evt = StreamErrorEvent()
        assert evt.type == "error"
        assert evt.error == ""
        assert evt.recoverable is False
# ---------------------------------------------------------------------------
# Construction with values
# ---------------------------------------------------------------------------
class TestEventConstruction:
    """Events should store provided field values correctly."""

    def test_text_delta_with_values(self):
        evt = TextDeltaEvent(content="hello", snapshot="hello world")
        assert evt.content == "hello"
        assert evt.snapshot == "hello world"

    def test_text_end_with_values(self):
        evt = TextEndEvent(full_text="the complete response")
        assert evt.full_text == "the complete response"

    def test_tool_call_with_values(self):
        # tool_input is an arbitrary dict; it must be stored as given.
        evt = ToolCallEvent(
            tool_use_id="call_abc123",
            tool_name="web_search",
            tool_input={"query": "python", "num_results": 5},
        )
        assert evt.tool_use_id == "call_abc123"
        assert evt.tool_name == "web_search"
        assert evt.tool_input == {"query": "python", "num_results": 5}

    def test_tool_result_with_values(self):
        evt = ToolResultEvent(
            tool_use_id="call_abc123",
            content="search results here",
            is_error=False,
        )
        assert evt.tool_use_id == "call_abc123"
        assert evt.content == "search results here"
        assert evt.is_error is False

    def test_tool_result_error(self):
        # is_error must round-trip as True for failed tool invocations.
        evt = ToolResultEvent(tool_use_id="call_fail", content="timeout", is_error=True)
        assert evt.is_error is True

    def test_reasoning_delta_with_content(self):
        evt = ReasoningDeltaEvent(content="Let me think about this...")
        assert evt.content == "Let me think about this..."

    def test_finish_with_values(self):
        evt = FinishEvent(
            stop_reason="end_turn",
            input_tokens=150,
            output_tokens=300,
            model="claude-haiku-4-5",
        )
        assert evt.stop_reason == "end_turn"
        assert evt.input_tokens == 150
        assert evt.output_tokens == 300
        assert evt.model == "claude-haiku-4-5"

    def test_stream_error_with_values(self):
        evt = StreamErrorEvent(error="rate limit exceeded", recoverable=True)
        assert evt.error == "rate limit exceeded"
        assert evt.recoverable is True
# ---------------------------------------------------------------------------
# Frozen immutability
# ---------------------------------------------------------------------------
class TestEventImmutability:
    """All events are frozen dataclasses — fields cannot be reassigned."""

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_frozen(self, cls):
        instance = cls()
        with pytest.raises(FrozenInstanceError):
            instance.type = "modified"

    def test_text_delta_frozen_content(self):
        evt = TextDeltaEvent(content="hello")
        with pytest.raises(FrozenInstanceError):
            evt.content = "modified"

    def test_tool_call_frozen_input(self):
        # Rebinding the dict field is blocked; frozen guards attribute
        # assignment (the dict object itself would still be mutable).
        evt = ToolCallEvent(tool_input={"key": "value"})
        with pytest.raises(FrozenInstanceError):
            evt.tool_input = {}
# ---------------------------------------------------------------------------
# Type literal values
# ---------------------------------------------------------------------------
class TestTypeLiterals:
    """Each event's `type` field should match its Literal annotation."""

    # Single source of truth mapping event class -> expected `type` string;
    # drives the parametrized test below.
    EXPECTED_TYPES = {
        TextDeltaEvent: "text_delta",
        TextEndEvent: "text_end",
        ToolCallEvent: "tool_call",
        ToolResultEvent: "tool_result",
        ReasoningStartEvent: "reasoning_start",
        ReasoningDeltaEvent: "reasoning_delta",
        FinishEvent: "finish",
        # Note: StreamErrorEvent uses "error", not "stream_error".
        StreamErrorEvent: "error",
    }

    @pytest.mark.parametrize(
        "cls,expected_type",
        EXPECTED_TYPES.items(),
        # ids receives both classes and strings; render the class name for types.
        ids=lambda x: x.__name__ if isinstance(x, type) else x,
    )
    def test_type_value(self, cls, expected_type):
        assert cls().type == expected_type

    def test_all_types_unique(self):
        # Duplicate discriminators would break consumers that dispatch on `type`.
        types = [cls().type for cls in ALL_EVENT_CLASSES]
        assert len(types) == len(set(types)), f"Duplicate type values: {types}"
# ---------------------------------------------------------------------------
# Serialization via dataclasses.asdict
# ---------------------------------------------------------------------------
class TestEventSerialization:
    """Events should round-trip through asdict for JSON serialization."""

    def test_text_delta_asdict(self):
        payload = asdict(TextDeltaEvent(content="chunk", snapshot="full chunk"))
        assert payload == {"type": "text_delta", "content": "chunk", "snapshot": "full chunk"}

    def test_tool_call_asdict(self):
        evt = ToolCallEvent(
            tool_use_id="id_1",
            tool_name="calc",
            tool_input={"expression": "2+2"},
        )
        payload = asdict(evt)
        assert payload["tool_name"] == "calc"
        assert payload["tool_input"] == {"expression": "2+2"}

    def test_finish_asdict(self):
        evt = FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=20, model="gpt-4")
        expected = {
            "type": "finish",
            "stop_reason": "stop",
            "input_tokens": 10,
            "output_tokens": 20,
            "model": "gpt-4",
        }
        assert asdict(evt) == expected

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_asdict_contains_type(self, cls):
        # The discriminator must survive serialization for downstream routing.
        assert "type" in asdict(cls())

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_asdict_keys_match_fields(self, cls):
        # asdict must expose exactly the declared dataclass fields — no extras.
        declared = {f.name for f in fields(cls)}
        assert set(asdict(cls()).keys()) == declared
# ---------------------------------------------------------------------------
# StreamEvent union type
# ---------------------------------------------------------------------------
class TestStreamEventUnion:
    """The StreamEvent union should include all event classes."""

    def test_union_contains_all_classes(self):
        # StreamEvent is a UnionType (PEP 604 syntax: X | Y | Z); members
        # are exposed via __args__.
        members = StreamEvent.__args__  # type: ignore[attr-defined]
        for cls in ALL_EVENT_CLASSES:
            assert cls in members, f"{cls.__name__} not in StreamEvent union"

    def test_union_has_exactly_expected_members(self):
        actual = set(StreamEvent.__args__)  # type: ignore[attr-defined]
        assert actual == set(ALL_EVENT_CLASSES)

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_isinstance_check(self, cls):
        """Each event instance should be an instance of its class (basic sanity)."""
        assert isinstance(cls(), cls)
# ---------------------------------------------------------------------------
# Equality & hashing (frozen dataclasses support both)
# ---------------------------------------------------------------------------
class TestEventEquality:
"""Frozen dataclasses support equality and hashing."""
def test_equal_events(self):
    """Two events built from identical field values compare equal."""
    first = TextDeltaEvent(content="hi", snapshot="hi")
    second = TextDeltaEvent(content="hi", snapshot="hi")
    assert first == second
def test_unequal_events(self):
    """Differing field values make same-class events unequal."""
    assert TextDeltaEvent(content="hi") != TextDeltaEvent(content="bye")
def test_different_types_not_equal(self):
    """Identical field values in different event classes never compare equal."""
    assert TextDeltaEvent(content="hi") != ReasoningDeltaEvent(content="hi")
def test_hashable(self):
    """Frozen events with hashable fields can live in sets."""
    evt = FinishEvent(stop_reason="stop", model="gpt-4")
    assert evt in {evt}
def test_equal_events_same_hash(self):
    """Equal events must hash identically (the hash/eq contract)."""
    first = FinishEvent(stop_reason="stop", model="gpt-4")
    second = FinishEvent(stop_reason="stop", model="gpt-4")
    assert hash(first) == hash(second)
def test_events_with_dict_not_hashable(self):
"""Events containing dict fields (e.g. tool_input) are not hashable."""
e = ToolCallEvent(tool_use_id="x", tool_name="y", tool_input={"key": "val"})
with pytest.raises(TypeError, match="unhashable type"):
hash(e)