Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d49e858d32 | |||
| d7afa5dcf2 | |||
| 22e816bf86 | |||
| 3240616808 | |||
| b9f83d4d61 | |||
| 9c16826ad3 | |||
| df4d0ad3fd | |||
| 9034d1dc71 | |||
| 537172d8ce | |||
| 20b2e4b3dd | |||
| a86043a2ec | |||
| 3947da2cf1 | |||
| 17caab6563 | |||
| a5ae071a03 | |||
| 9c33da7b8d | |||
| 94d31743b0 | |||
| 70db618c6e | |||
| 23146c8dae | |||
| c52ce6bb49 | |||
| bcddd4ce77 | |||
| 017872f71b | |||
| 7e670ce0a8 | |||
| 3ee6d98905 |
@@ -32,9 +32,13 @@
|
||||
"mcp__agent-builder__verify_credentials",
|
||||
"Bash(PYTHONPATH=/home/timothy/oss/hive/core:/home/timothy/oss/hive/exports python:*)",
|
||||
"Bash(PYTHONPATH=core:exports:tools/src python -m hubspot_input:*)",
|
||||
"mcp__agent-builder__export_graph"
|
||||
"mcp__agent-builder__export_graph",
|
||||
"Bash(python3:*)"
|
||||
]
|
||||
},
|
||||
"enabledMcpjsonServers": ["agent-builder", "tools"],
|
||||
"enableAllProjectMcpServers": true
|
||||
"enableAllProjectMcpServers": true,
|
||||
"enabledMcpjsonServers": [
|
||||
"agent-builder",
|
||||
"tools"
|
||||
]
|
||||
}
|
||||
|
||||
+4
-1
@@ -69,4 +69,7 @@ exports/*
|
||||
|
||||
.agent-builder-sessions/*
|
||||
|
||||
.venv
|
||||
.venv
|
||||
|
||||
docs/github-issues/*
|
||||
core/tests/*dumps/*
|
||||
|
||||
@@ -0,0 +1,779 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
EventLoopNode WebSocket Demo
|
||||
|
||||
Real LLM, real FileConversationStore, real EventBus.
|
||||
Streams EventLoopNode execution to a browser via WebSocket.
|
||||
|
||||
Usage:
|
||||
cd /home/timothy/oss/hive/core
|
||||
python demos/event_loop_wss_demo.py
|
||||
|
||||
Then open http://localhost:8765 in your browser.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import tempfile
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from bs4 import BeautifulSoup
|
||||
from websockets.http11 import Request, Response
|
||||
|
||||
# Add core, tools, and hive root to path
|
||||
_CORE_DIR = Path(__file__).resolve().parent.parent
|
||||
_HIVE_DIR = _CORE_DIR.parent
|
||||
sys.path.insert(0, str(_CORE_DIR)) # framework.*
|
||||
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
|
||||
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
|
||||
|
||||
import os # noqa: E402
|
||||
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
|
||||
from core.framework.credentials import CredentialStore # noqa: E402
|
||||
|
||||
from framework.credentials.storage import ( # noqa: E402
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
from framework.graph.event_loop_node import EventLoopNode, JudgeVerdict, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.provider import Tool # noqa: E402
|
||||
from framework.runner.tool_registry import ToolRegistry # noqa: E402
|
||||
from framework.runtime.core import Runtime # noqa: E402
|
||||
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
|
||||
from framework.storage.conversation_store import FileConversationStore # noqa: E402
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
||||
logger = logging.getLogger("demo")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Persistent state (shared across WebSocket connections)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_demo_"))
|
||||
STORE = FileConversationStore(STORE_DIR / "conversation")
|
||||
RUNTIME = Runtime(STORE_DIR / "runtime")
|
||||
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tool Registry — real tools via ToolRegistry (same pattern as GraphExecutor)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
# Credential store: Aden sync (OAuth2 tokens) + encrypted files + env var fallback
|
||||
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
|
||||
_local_storage = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
|
||||
)
|
||||
|
||||
if os.environ.get("ADEN_API_KEY"):
|
||||
try:
|
||||
from framework.credentials.aden import ( # noqa: E402
|
||||
AdenCachedStorage,
|
||||
AdenClientConfig,
|
||||
AdenCredentialClient,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
_client = AdenCredentialClient(AdenClientConfig(base_url="https://api.adenhq.com"))
|
||||
_provider = AdenSyncProvider(client=_client)
|
||||
_storage = AdenCachedStorage(
|
||||
local_storage=_local_storage,
|
||||
aden_provider=_provider,
|
||||
)
|
||||
_cred_store = CredentialStore(storage=_storage, providers=[_provider], auto_refresh=True)
|
||||
_synced = _provider.sync_all(_cred_store)
|
||||
logger.info("Synced %d credentials from Aden", _synced)
|
||||
except Exception as e:
|
||||
logger.warning("Aden sync unavailable: %s", e)
|
||||
_cred_store = CredentialStore(storage=_local_storage)
|
||||
else:
|
||||
logger.info("ADEN_API_KEY not set, using local credential storage")
|
||||
_cred_store = CredentialStore(storage=_local_storage)
|
||||
|
||||
CREDENTIALS = CredentialStoreAdapter(_cred_store)
|
||||
|
||||
# Debug: log which credentials resolved
|
||||
for _name in ["brave_search", "hubspot", "anthropic"]:
|
||||
_val = CREDENTIALS.get(_name)
|
||||
if _val:
|
||||
logger.debug("credential %s: OK (len=%d)", _name, len(_val))
|
||||
else:
|
||||
logger.debug("credential %s: not found", _name)
|
||||
|
||||
# --- web_search (Brave Search API) ---
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_search",
|
||||
tool=Tool(
|
||||
name="web_search",
|
||||
description=(
|
||||
"Search the web for current information. "
|
||||
"Returns titles, URLs, and snippets from search results."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query (1-500 characters)",
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (1-20, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_search(inputs),
|
||||
)
|
||||
|
||||
|
||||
def _exec_web_search(inputs: dict) -> dict:
    """Execute a Brave web search and return normalized results.

    Args:
        inputs: Tool input dict with "query" (str, required) and
            "num_results" (int, optional; clamped to 1-20, default 10).

    Returns:
        ``{"query", "results", "total"}`` on success, or ``{"error": ...}``
        when the credential is missing, the API returns a non-200 status,
        or the request fails at the network level.
    """
    api_key = CREDENTIALS.get("brave_search")
    if not api_key:
        return {"error": "brave_search credential not configured"}
    query = inputs.get("query", "")
    # Clamp to the 1-20 range the Brave API accepts; 0/negative would 4xx.
    num_results = max(1, min(inputs.get("num_results", 10), 20))
    try:
        resp = httpx.get(
            "https://api.search.brave.com/res/v1/web/search",
            params={"q": query, "count": num_results},
            headers={"X-Subscription-Token": api_key, "Accept": "application/json"},
            timeout=30.0,
        )
    except httpx.TimeoutException:
        # Consistent with the other tool executors: surface failures as
        # data for the LLM instead of raising into the event loop.
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"Search failed: {e}"}
    if resp.status_code != 200:
        return {"error": f"Brave API HTTP {resp.status_code}"}
    data = resp.json()
    results = [
        {
            "title": item.get("title", ""),
            "url": item.get("url", ""),
            "snippet": item.get("description", ""),
        }
        for item in data.get("web", {}).get("results", [])[:num_results]
    ]
    return {"query": query, "results": results, "total": len(results)}
|
||||
|
||||
|
||||
# --- web_scrape (httpx + BeautifulSoup, no playwright for sync compat) ---
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_scrape",
|
||||
tool=Tool(
|
||||
name="web_scrape",
|
||||
description=(
|
||||
"Scrape and extract text content from a webpage URL. "
|
||||
"Returns the page title and main text content."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the webpage to scrape",
|
||||
},
|
||||
"max_length": {
|
||||
"type": "integer",
|
||||
"description": "Maximum text length (default 50000)",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_scrape(inputs),
|
||||
)
|
||||
|
||||
_SCRAPE_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "text/html,application/xhtml+xml",
|
||||
}
|
||||
|
||||
|
||||
def _exec_web_scrape(inputs: dict) -> dict:
    """Fetch a webpage and return its title plus whitespace-normalized text.

    Args:
        inputs: Tool input dict with "url" (str, required; an https://
            prefix is added if no scheme is present) and "max_length"
            (int, optional; clamped to 1000-500000, default 50000).

    Returns:
        ``{"url", "title", "content", "length"}`` on success, or
        ``{"error": ...}`` on HTTP/network failure.
    """
    target = inputs.get("url", "")
    limit = max(1000, min(inputs.get("max_length", 50000), 500000))
    if not target.startswith(("http://", "https://")):
        target = "https://" + target
    try:
        response = httpx.get(
            target, timeout=30.0, follow_redirects=True, headers=_SCRAPE_HEADERS
        )
        if response.status_code != 200:
            return {"error": f"HTTP {response.status_code}"}
        soup = BeautifulSoup(response.text, "html.parser")
        # Drop boilerplate elements before extracting readable text.
        for noise in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
            noise.decompose()
        page_title = soup.title.get_text(strip=True) if soup.title else ""
        # Prefer semantic content containers, falling back to <body>.
        container = (
            soup.find("article")
            or soup.find("main")
            or soup.find(attrs={"role": "main"})
            or soup.find("body")
        )
        body_text = container.get_text(separator=" ", strip=True) if container else ""
        body_text = " ".join(body_text.split())
        if len(body_text) > limit:
            body_text = body_text[:limit] + "..."
        return {"url": target, "title": page_title, "content": body_text, "length": len(body_text)}
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"Scrape failed: {e}"}
|
||||
|
||||
|
||||
# --- HubSpot CRM tools (optional, requires HUBSPOT_ACCESS_TOKEN) ---
|
||||
|
||||
_HUBSPOT_API = "https://api.hubapi.com"
|
||||
|
||||
|
||||
def _hubspot_headers() -> dict | None:
    """Build HubSpot REST headers, or return None if no token is configured.

    Returns:
        Header dict with bearer auth and JSON content negotiation, or
        ``None`` when the "hubspot" credential is absent (callers surface
        this as a tool error).
    """
    token = CREDENTIALS.get("hubspot")
    # Single truthiness check (the original tested `token` twice).
    if not token:
        logger.debug("HubSpot token: not found")
        return None
    # Debug-only fingerprint of the credential; never log the full token.
    logger.debug("HubSpot token: %s...%s (len=%d)", token[:8], token[-4:], len(token))
    return {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
|
||||
|
||||
|
||||
def _exec_hubspot_search(inputs: dict) -> dict:
    """Search HubSpot CRM objects via the v3 search endpoint.

    Args:
        inputs: Tool input dict with "object_type" (str, default
            "contacts"), optional "query" (str), and optional "limit"
            (int; clamped to 1-100, default 10).

    Returns:
        The raw HubSpot JSON response on success, or ``{"error": ...}``
        when the token is missing or the request fails.
    """
    headers = _hubspot_headers()
    if not headers:
        return {"error": "HUBSPOT_ACCESS_TOKEN not set"}
    object_type = inputs.get("object_type", "contacts")
    query = inputs.get("query", "")
    # Clamp to the 1-100 range advertised in the tool schema (the original
    # only capped the upper bound, letting 0/negative through to the API).
    limit = max(1, min(inputs.get("limit", 10), 100))
    body: dict = {"limit": limit}
    if query:
        body["query"] = query
    try:
        resp = httpx.post(
            f"{_HUBSPOT_API}/crm/v3/objects/{object_type}/search",
            headers=headers,
            json=body,
            timeout=30.0,
        )
        if resp.status_code != 200:
            return {"error": f"HubSpot API HTTP {resp.status_code}: {resp.text[:200]}"}
        return resp.json()
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"HubSpot error: {e}"}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="hubspot_search",
|
||||
tool=Tool(
|
||||
name="hubspot_search",
|
||||
description=(
|
||||
"Search HubSpot CRM objects (contacts, companies, or deals). "
|
||||
"Returns matching records with their properties."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"object_type": {
|
||||
"type": "string",
|
||||
"description": "CRM object type: 'contacts', 'companies', or 'deals'",
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query (name, email, domain, etc.)",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results (1-100, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["object_type"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_hubspot_search(inputs),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ToolRegistry loaded: %s",
|
||||
", ".join(TOOL_REGISTRY.get_registered_names()),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# ChatJudge — keeps the event loop alive between user messages
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ChatJudge:
    """Judge that parks the event loop between user messages.

    When the LLM finishes a turn, ``evaluate`` optionally notifies the
    client via ``on_ready``, then blocks until either a new user message
    is signalled (→ RETRY, loop continues) or shutdown is requested
    (→ ACCEPT, loop ends).
    """

    def __init__(self, on_ready=None):
        # Set whenever a new user message arrives (or on shutdown).
        self._message_ready = asyncio.Event()
        # Flipped by signal_shutdown(); makes evaluate() end the loop.
        self._shutdown = False
        # Optional async callback fired when we start waiting for input.
        self._on_ready = on_ready

    async def evaluate(self, context: dict) -> JudgeVerdict:
        # Tell the client the LLM is done and we are waiting for input.
        callback = self._on_ready
        if callback:
            await callback()

        # Park until the next user message (or a shutdown request).
        self._message_ready.clear()
        await self._message_ready.wait()

        action = "ACCEPT" if self._shutdown else "RETRY"
        return JudgeVerdict(action=action)

    def signal_message(self):
        """Unblock the judge — a new user message has been injected."""
        self._message_ready.set()

    def signal_shutdown(self):
        """Unblock the judge and let the loop exit cleanly."""
        self._shutdown = True
        self._message_ready.set()
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTML page (embedded)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
HTML_PAGE = ( # noqa: E501
|
||||
"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>EventLoopNode Live Demo</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body {
|
||||
font-family: 'SF Mono', 'Fira Code', monospace;
|
||||
background: #0d1117; color: #c9d1d9;
|
||||
height: 100vh; display: flex; flex-direction: column;
|
||||
}
|
||||
header {
|
||||
background: #161b22; padding: 12px 20px;
|
||||
border-bottom: 1px solid #30363d;
|
||||
display: flex; align-items: center; gap: 16px;
|
||||
}
|
||||
header h1 { font-size: 16px; color: #58a6ff; font-weight: 600; }
|
||||
.status {
|
||||
font-size: 12px; padding: 3px 10px; border-radius: 12px;
|
||||
background: #21262d; color: #8b949e;
|
||||
}
|
||||
.status.running { background: #1a4b2e; color: #3fb950; }
|
||||
.status.done { background: #1a3a5c; color: #58a6ff; }
|
||||
.status.error { background: #4b1a1a; color: #f85149; }
|
||||
.chat { flex: 1; overflow-y: auto; padding: 16px; }
|
||||
.msg {
|
||||
margin: 8px 0; padding: 10px 14px; border-radius: 8px;
|
||||
line-height: 1.6; white-space: pre-wrap; word-wrap: break-word;
|
||||
}
|
||||
.msg.user { background: #1a3a5c; color: #58a6ff; }
|
||||
.msg.assistant { background: #161b22; color: #c9d1d9; }
|
||||
.msg.event {
|
||||
background: transparent; color: #8b949e; font-size: 11px;
|
||||
padding: 4px 14px; border-left: 3px solid #30363d;
|
||||
}
|
||||
.msg.event.loop { border-left-color: #58a6ff; }
|
||||
.msg.event.tool { border-left-color: #d29922; }
|
||||
.msg.event.stall { border-left-color: #f85149; }
|
||||
.input-bar {
|
||||
padding: 12px 16px; background: #161b22;
|
||||
border-top: 1px solid #30363d; display: flex; gap: 8px;
|
||||
}
|
||||
.input-bar input {
|
||||
flex: 1; background: #0d1117; border: 1px solid #30363d;
|
||||
color: #c9d1d9; padding: 8px 12px; border-radius: 6px;
|
||||
font-family: inherit; font-size: 14px; outline: none;
|
||||
}
|
||||
.input-bar input:focus { border-color: #58a6ff; }
|
||||
.input-bar button {
|
||||
background: #238636; color: #fff; border: none;
|
||||
padding: 8px 20px; border-radius: 6px; cursor: pointer;
|
||||
font-family: inherit; font-weight: 600;
|
||||
}
|
||||
.input-bar button:hover { background: #2ea043; }
|
||||
.input-bar button:disabled {
|
||||
background: #21262d; color: #484f58; cursor: not-allowed;
|
||||
}
|
||||
.input-bar button.clear { background: #da3633; }
|
||||
.input-bar button.clear:hover { background: #f85149; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>EventLoopNode Live</h1>
|
||||
<span id="status" class="status">Idle</span>
|
||||
<span id="iter" class="status" style="display:none">Step 0</span>
|
||||
</header>
|
||||
<div id="chat" class="chat"></div>
|
||||
<div class="input-bar">
|
||||
<input id="input" type="text"
|
||||
placeholder="Ask anything..." autofocus />
|
||||
<button id="go" onclick="run()">Send</button>
|
||||
<button class="clear"
|
||||
onclick="clearConversation()">Clear</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let currentAssistantEl = null;
|
||||
let iterCount = 0;
|
||||
const chat = document.getElementById('chat');
|
||||
const status = document.getElementById('status');
|
||||
const iterEl = document.getElementById('iter');
|
||||
const goBtn = document.getElementById('go');
|
||||
const inputEl = document.getElementById('input');
|
||||
|
||||
inputEl.addEventListener('keydown', e => {
|
||||
if (e.key === 'Enter') run();
|
||||
});
|
||||
|
||||
function setStatus(text, cls) {
|
||||
status.textContent = text;
|
||||
status.className = 'status ' + cls;
|
||||
}
|
||||
|
||||
function addMsg(text, cls) {
|
||||
const el = document.createElement('div');
|
||||
el.className = 'msg ' + cls;
|
||||
el.textContent = text;
|
||||
chat.appendChild(el);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
return el;
|
||||
}
|
||||
|
||||
function connect() {
|
||||
ws = new WebSocket('ws://' + location.host + '/ws');
|
||||
ws.onopen = () => {
|
||||
setStatus('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
};
|
||||
ws.onmessage = handleEvent;
|
||||
ws.onerror = () => { setStatus('Error', 'error'); };
|
||||
ws.onclose = () => {
|
||||
setStatus('Reconnecting...', '');
|
||||
goBtn.disabled = true;
|
||||
setTimeout(connect, 2000);
|
||||
};
|
||||
}
|
||||
|
||||
function handleEvent(msg) {
|
||||
const evt = JSON.parse(msg.data);
|
||||
|
||||
if (evt.type === 'llm_text_delta') {
|
||||
if (currentAssistantEl) {
|
||||
currentAssistantEl.textContent += evt.content;
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'ready') {
|
||||
setStatus('Ready', 'done');
|
||||
if (currentAssistantEl && !currentAssistantEl.textContent)
|
||||
currentAssistantEl.remove();
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_loop_iteration') {
|
||||
iterCount = evt.iteration || (iterCount + 1);
|
||||
iterEl.textContent = 'Step ' + iterCount;
|
||||
iterEl.style.display = '';
|
||||
}
|
||||
else if (evt.type === 'tool_call_started') {
|
||||
var info = evt.tool_name + '('
|
||||
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
|
||||
addMsg('TOOL ' + info, 'event tool');
|
||||
}
|
||||
else if (evt.type === 'tool_call_completed') {
|
||||
var preview = (evt.result || '').slice(0, 200);
|
||||
var cls = evt.is_error ? 'stall' : 'tool';
|
||||
addMsg('RESULT ' + evt.tool_name + ': ' + preview,
|
||||
'event ' + cls);
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
}
|
||||
else if (evt.type === 'result') {
|
||||
setStatus('Session ended', evt.success ? 'done' : 'error');
|
||||
if (evt.error) addMsg('ERROR ' + evt.error, 'event stall');
|
||||
if (currentAssistantEl && !currentAssistantEl.textContent)
|
||||
currentAssistantEl.remove();
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_stalled') {
|
||||
addMsg('STALLED ' + evt.reason, 'event stall');
|
||||
}
|
||||
else if (evt.type === 'cleared') {
|
||||
chat.innerHTML = '';
|
||||
iterCount = 0;
|
||||
iterEl.textContent = 'Step 0';
|
||||
iterEl.style.display = 'none';
|
||||
setStatus('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
function run() {
|
||||
const text = inputEl.value.trim();
|
||||
if (!text || !ws || ws.readyState !== 1) return;
|
||||
addMsg(text, 'user');
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
inputEl.value = '';
|
||||
setStatus('Running', 'running');
|
||||
goBtn.disabled = true;
|
||||
ws.send(JSON.stringify({ topic: text }));
|
||||
}
|
||||
|
||||
function clearConversation() {
|
||||
if (ws && ws.readyState === 1) {
|
||||
ws.send(JSON.stringify({ command: 'clear' }));
|
||||
}
|
||||
}
|
||||
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# WebSocket handler
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_ws(websocket):
    """Persistent WebSocket: long-lived EventLoopNode kept alive by ChatJudge.

    Lifecycle per connection:
      1. Subscribe a fresh EventBus that forwards node/LLM/tool events to
         the browser as JSON frames.
      2. On the first "topic" message, spin up an EventLoopNode as a
         background task; subsequent topics are injected into the running
         node and the ChatJudge is unblocked.
      3. A {"command": "clear"} frame tears down the loop, wipes the
         conversation store on disk, and rebinds the module-global STORE.
      4. On disconnect, the loop task is stopped and awaited.
    """
    global STORE

    # -- Event forwarding (WebSocket ← EventBus) ----------------------------
    bus = EventBus()

    async def forward_event(event):
        # Best-effort forwarding: a closed socket must not kill the node loop.
        try:
            payload = {"type": event.type.value, **event.data}
            if event.node_id:
                payload["node_id"] = event.node_id
            await websocket.send(json.dumps(payload))
        except Exception:
            pass

    bus.subscribe(
        event_types=[
            EventType.NODE_LOOP_STARTED,
            EventType.NODE_LOOP_ITERATION,
            EventType.NODE_LOOP_COMPLETED,
            EventType.LLM_TEXT_DELTA,
            EventType.TOOL_CALL_STARTED,
            EventType.TOOL_CALL_COMPLETED,
            EventType.NODE_STALLED,
        ],
        handler=forward_event,
    )

    # -- Ready callback (tells browser the LLM is done, waiting for input) --
    async def send_ready():
        try:
            await websocket.send(json.dumps({"type": "ready"}))
        except Exception:
            pass

    # -- Per-connection state -----------------------------------------------
    # NOTE: start_loop/stop_loop read `judge` via closure, so rebinding it
    # in the "clear" branch below is picked up on the next start_loop call.
    judge = ChatJudge(on_ready=send_ready)
    node = None
    loop_task = None

    tools = list(TOOL_REGISTRY.get_tools().values())
    tool_executor = TOOL_REGISTRY.get_executor()

    node_spec = NodeSpec(
        id="assistant",
        name="Chat Assistant",
        description="A conversational assistant that remembers context across messages",
        node_type="event_loop",
        system_prompt=(
            "You are a helpful assistant with access to tools. "
            "You can search the web, scrape webpages, and query HubSpot CRM. "
            "Use tools when the user asks for current information or external data. "
            "You have full conversation history, so you can reference previous messages."
        ),
    )

    async def start_loop(first_message: str):
        """Create an EventLoopNode and run it as a background task."""
        nonlocal node, loop_task

        memory = SharedMemory()
        ctx = NodeContext(
            runtime=RUNTIME,
            node_id="assistant",
            node_spec=node_spec,
            memory=memory,
            input_data={},
            llm=LLM,
            available_tools=tools,
        )
        node = EventLoopNode(
            event_bus=bus,
            judge=judge,
            # Effectively unbounded iterations — the ChatJudge decides when
            # the session ends, not the iteration cap.
            config=LoopConfig(max_iterations=10_000, max_history_tokens=32_000),
            conversation_store=STORE,
            tool_executor=tool_executor,
        )
        # Seed the loop with the first user message before execution starts.
        await node.inject_event(first_message)

        async def _run():
            # Runs for the whole session; reports the final result (or the
            # failure) back to the browser as a single "result" frame.
            try:
                result = await node.execute(ctx)
                try:
                    await websocket.send(
                        json.dumps(
                            {
                                "type": "result",
                                "success": result.success,
                                "output": result.output,
                                "error": result.error,
                                "tokens": result.tokens_used,
                            }
                        )
                    )
                except Exception:
                    pass
                logger.info(f"Loop ended: success={result.success}, tokens={result.tokens_used}")
            except websockets.exceptions.ConnectionClosed:
                logger.info("Loop stopped: WebSocket closed")
            except Exception as e:
                logger.exception("Loop error")
                try:
                    await websocket.send(
                        json.dumps(
                            {
                                "type": "result",
                                "success": False,
                                "error": str(e),
                                "output": {},
                            }
                        )
                    )
                except Exception:
                    pass

        loop_task = asyncio.create_task(_run())

    async def stop_loop():
        """Signal the judge and wait for the loop task to finish."""
        nonlocal node, loop_task
        if loop_task and not loop_task.done():
            judge.signal_shutdown()
            try:
                await asyncio.wait_for(loop_task, timeout=5.0)
            # NOTE(review): asyncio.wait_for raises asyncio.TimeoutError,
            # which is the builtin TimeoutError only on Python >= 3.11 —
            # confirm the minimum supported interpreter version.
            except (TimeoutError, asyncio.CancelledError):
                loop_task.cancel()
        node = None
        loop_task = None

    # -- Message loop (runs for the lifetime of this WebSocket) -------------
    try:
        async for raw in websocket:
            try:
                msg = json.loads(raw)
            except Exception:
                # Ignore malformed frames rather than dropping the socket.
                continue

            # Clear command
            if msg.get("command") == "clear":
                import shutil

                await stop_loop()
                await STORE.close()
                conv_dir = STORE_DIR / "conversation"
                if conv_dir.exists():
                    shutil.rmtree(conv_dir)
                # Rebind the module global so future connections see the
                # fresh store too.
                STORE = FileConversationStore(conv_dir)
                # Reset judge for next session
                judge = ChatJudge(on_ready=send_ready)
                await websocket.send(json.dumps({"type": "cleared"}))
                logger.info("Conversation cleared")
                continue

            topic = msg.get("topic", "")
            if not topic:
                continue

            if node is None:
                # First message — spin up the loop
                logger.info(f"Starting persistent loop: {topic}")
                await start_loop(topic)
            else:
                # Subsequent message — inject and unblock the judge
                logger.info(f"Injecting message: {topic}")
                await node.inject_event(topic)
                judge.signal_message()

    except websockets.exceptions.ConnectionClosed:
        pass
    finally:
        await stop_loop()
        logger.info("WebSocket closed, loop stopped")
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP handler for serving the HTML page
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_request(connection, request: Request):
    """Serve HTML on GET /, upgrade to WebSocket on /ws."""
    if request.path == "/ws":
        # Returning None lets websockets proceed with the upgrade handshake.
        return None
    # Any other path gets the embedded single-page UI.
    headers = websockets.Headers({"Content-Type": "text/html; charset=utf-8"})
    return Response(HTTPStatus.OK, "OK", headers, HTML_PAGE.encode())
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Main
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
    """Run the combined HTTP + WebSocket server until interrupted."""
    port = 8765
    server = websockets.serve(
        handle_ws,
        "0.0.0.0",
        port,
        process_request=process_request,
    )
    async with server:
        logger.info(f"Demo running at http://localhost:{port}")
        logger.info("Open in your browser and enter a topic to research.")
        # Never resolves — keeps the server context alive until Ctrl-C.
        await asyncio.Future()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -0,0 +1,930 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Two-Node ContextHandoff Demo
|
||||
|
||||
Demonstrates ContextHandoff between two EventLoopNode instances:
|
||||
Node A (Researcher) → ContextHandoff → Node B (Analyst)
|
||||
|
||||
Real LLM, real FileConversationStore, real EventBus.
|
||||
Streams both nodes to a browser via WebSocket.
|
||||
|
||||
Usage:
|
||||
cd /home/timothy/oss/hive/core
|
||||
python demos/handoff_demo.py
|
||||
|
||||
Then open http://localhost:8766 in your browser.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import tempfile
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from bs4 import BeautifulSoup
|
||||
from websockets.http11 import Request, Response
|
||||
|
||||
# Add core, tools, and hive root to path
|
||||
_CORE_DIR = Path(__file__).resolve().parent.parent
|
||||
_HIVE_DIR = _CORE_DIR.parent
|
||||
sys.path.insert(0, str(_CORE_DIR)) # framework.*
|
||||
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
|
||||
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
|
||||
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
|
||||
from core.framework.credentials import CredentialStore # noqa: E402
|
||||
|
||||
from framework.credentials.storage import ( # noqa: E402
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
from framework.graph.context_handoff import ContextHandoff # noqa: E402
|
||||
from framework.graph.conversation import NodeConversation # noqa: E402
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.provider import Tool # noqa: E402
|
||||
from framework.runner.tool_registry import ToolRegistry # noqa: E402
|
||||
from framework.runtime.core import Runtime # noqa: E402
|
||||
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
|
||||
from framework.storage.conversation_store import FileConversationStore # noqa: E402
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
||||
logger = logging.getLogger("handoff_demo")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Persistent state
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_handoff_"))
|
||||
RUNTIME = Runtime(STORE_DIR / "runtime")
|
||||
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Credentials
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Composite credential store: encrypted files (primary) + env vars (fallback)
|
||||
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
|
||||
_composite = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
|
||||
)
|
||||
CREDENTIALS = CredentialStoreAdapter(CredentialStore(storage=_composite))
|
||||
|
||||
for _name in ["brave_search", "hubspot"]:
|
||||
_val = CREDENTIALS.get(_name)
|
||||
if _val:
|
||||
logger.debug("credential %s: OK (len=%d)", _name, len(_val))
|
||||
else:
|
||||
logger.debug("credential %s: not found", _name)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tool Registry — web_search + web_scrape for Node A (Researcher)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
|
||||
def _exec_web_search(inputs: dict) -> dict:
    """Execute a web search against the Brave Search API.

    Args:
        inputs: Tool input dict with keys:
            query: The search query string.
            num_results: Optional result count (clamped to 1-20, default 10).

    Returns:
        On success, a dict with ``query``, ``results`` (list of
        title/url/snippet dicts) and ``total``. On failure, a dict with a
        single ``error`` key — tool errors are reported as values, never
        raised, so the agent loop can see and react to them.
    """
    api_key = CREDENTIALS.get("brave_search")
    if not api_key:
        return {"error": "brave_search credential not configured"}
    query = inputs.get("query", "")
    # Clamp to the 1-20 range the API accepts; previously a zero/negative
    # request count was passed through to the API unclamped.
    num_results = max(1, min(inputs.get("num_results", 10), 20))
    try:
        resp = httpx.get(
            "https://api.search.brave.com/res/v1/web/search",
            params={"q": query, "count": num_results},
            headers={
                "X-Subscription-Token": api_key,
                "Accept": "application/json",
            },
            timeout=30.0,
        )
    except httpx.TimeoutException:
        # Mirror _exec_web_scrape: network failures become tool errors
        # instead of crashing the agent loop.
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"Search failed: {e}"}
    if resp.status_code != 200:
        return {"error": f"Brave API HTTP {resp.status_code}"}
    data = resp.json()
    results = [
        {
            "title": item.get("title", ""),
            "url": item.get("url", ""),
            "snippet": item.get("description", ""),
        }
        for item in data.get("web", {}).get("results", [])[:num_results]
    ]
    return {"query": query, "results": results, "total": len(results)}
|
||||
|
||||
|
||||
# Expose web_search to the researcher node. The executor is the module
# function itself — it already takes a single inputs dict.
TOOL_REGISTRY.register(
    name="web_search",
    tool=Tool(
        name="web_search",
        description=(
            "Search the web for current information. "
            "Returns titles, URLs, and snippets from search results."
        ),
        parameters={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query (1-500 characters)",
                },
                "num_results": {
                    "type": "integer",
                    "description": "Number of results (1-20, default 10)",
                },
            },
            "required": ["query"],
        },
    ),
    executor=_exec_web_search,
)
|
||||
|
||||
# Browser-like request headers so sites that reject default HTTP-client
# user agents still serve HTML to the scraper.
_SCRAPE_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/131.0.0.0 Safari/537.36"
    ),
    "Accept": "text/html,application/xhtml+xml",
}
|
||||
|
||||
|
||||
def _exec_web_scrape(inputs: dict) -> dict:
    """Fetch a URL and return its title plus the main visible text.

    Boilerplate tags are stripped, the <article>/<main>/role=main region is
    preferred over the full body, whitespace is collapsed, and the text is
    truncated to ``max_length`` characters (clamped to 1000-500000).
    Failures are reported as an ``error`` key rather than raised.
    """
    url = inputs.get("url", "")
    max_length = max(1000, min(inputs.get("max_length", 50000), 500000))
    if not url.startswith(("http://", "https://")):
        url = "https://" + url  # default to HTTPS for bare hostnames
    try:
        response = httpx.get(
            url,
            timeout=30.0,
            follow_redirects=True,
            headers=_SCRAPE_HEADERS,
        )
        if response.status_code != 200:
            return {"error": f"HTTP {response.status_code}"}

        page = BeautifulSoup(response.text, "html.parser")
        # Remove non-content elements before extracting text.
        for junk in page(
            ["script", "style", "nav", "footer", "header", "aside", "noscript"]
        ):
            junk.decompose()

        title = page.title.get_text(strip=True) if page.title else ""
        main_region = (
            page.find("article")
            or page.find("main")
            or page.find(attrs={"role": "main"})
            or page.find("body")
        )
        text = main_region.get_text(separator=" ", strip=True) if main_region else ""
        text = " ".join(text.split())
        if len(text) > max_length:
            text = text[:max_length] + "..."
        return {
            "url": url,
            "title": title,
            "content": text,
            "length": len(text),
        }
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"Scrape failed: {e}"}
|
||||
|
||||
|
||||
# Expose web_scrape to the researcher node; like web_search, the executor
# is the module function itself (single inputs-dict signature).
TOOL_REGISTRY.register(
    name="web_scrape",
    tool=Tool(
        name="web_scrape",
        description=(
            "Scrape and extract text content from a webpage URL. "
            "Returns the page title and main text content."
        ),
        parameters={
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "URL of the webpage to scrape",
                },
                "max_length": {
                    "type": "integer",
                    "description": "Maximum text length (default 50000)",
                },
            },
            "required": ["url"],
        },
    ),
    executor=_exec_web_scrape,
)

logger.info(
    "ToolRegistry loaded: %s",
    ", ".join(TOOL_REGISTRY.get_registered_names()),
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Node Specs
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Node A: gathers material on the user's topic via web_search/web_scrape
# and must publish its findings under the "research_summary" output key
# (the implicit judge accepts once output_keys are filled).
RESEARCHER_SPEC = NodeSpec(
    id="researcher",
    name="Researcher",
    description="Researches a topic using web search and scraping tools",
    node_type="event_loop",
    input_keys=["topic"],
    output_keys=["research_summary"],
    system_prompt=(
        "You are a thorough research assistant. Your job is to research "
        "the given topic using the web_search and web_scrape tools.\n\n"
        "1. Search for relevant information on the topic\n"
        "2. Scrape 1-2 of the most promising URLs for details\n"
        "3. Synthesize your findings into a comprehensive summary\n"
        "4. Use set_output with key='research_summary' to save your "
        "findings\n\n"
        "Be thorough but efficient. Aim for 2-4 search/scrape calls, "
        "then summarize and set_output."
    ),
)
|
||||
|
||||
# Node B: consumes the researcher's handoff (input key "context") and must
# publish its conclusions under the "analysis" output key; runs tool-free.
ANALYST_SPEC = NodeSpec(
    id="analyst",
    name="Analyst",
    description="Analyzes research findings and provides insights",
    node_type="event_loop",
    input_keys=["context"],
    output_keys=["analysis"],
    system_prompt=(
        "You are a strategic analyst. You receive research findings from "
        "a previous researcher and must:\n\n"
        "1. Identify key themes and patterns\n"
        "2. Assess the reliability and significance of the findings\n"
        "3. Provide actionable insights and recommendations\n"
        "4. Use set_output with key='analysis' to save your analysis\n\n"
        "Be concise but insightful. Focus on what matters most."
    ),
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTML page
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
HTML_PAGE = ( # noqa: E501
|
||||
"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>ContextHandoff Demo</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
body {
|
||||
font-family: 'SF Mono', 'Fira Code', monospace;
|
||||
background: #0d1117;
|
||||
color: #c9d1d9;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
header {
|
||||
background: #161b22;
|
||||
padding: 12px 20px;
|
||||
border-bottom: 1px solid #30363d;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 16px;
|
||||
}
|
||||
header h1 {
|
||||
font-size: 16px;
|
||||
color: #58a6ff;
|
||||
font-weight: 600;
|
||||
}
|
||||
.badge {
|
||||
font-size: 12px;
|
||||
padding: 3px 10px;
|
||||
border-radius: 12px;
|
||||
background: #21262d;
|
||||
color: #8b949e;
|
||||
}
|
||||
.badge.researcher {
|
||||
background: #1a3a5c;
|
||||
color: #58a6ff;
|
||||
}
|
||||
.badge.analyst {
|
||||
background: #1a4b2e;
|
||||
color: #3fb950;
|
||||
}
|
||||
.badge.handoff {
|
||||
background: #3d1f00;
|
||||
color: #d29922;
|
||||
}
|
||||
.badge.done {
|
||||
background: #21262d;
|
||||
color: #8b949e;
|
||||
}
|
||||
.badge.error {
|
||||
background: #4b1a1a;
|
||||
color: #f85149;
|
||||
}
|
||||
.chat {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 16px;
|
||||
}
|
||||
.msg {
|
||||
margin: 8px 0;
|
||||
padding: 10px 14px;
|
||||
border-radius: 8px;
|
||||
line-height: 1.6;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
.msg.user {
|
||||
background: #1a3a5c;
|
||||
color: #58a6ff;
|
||||
}
|
||||
.msg.assistant {
|
||||
background: #161b22;
|
||||
color: #c9d1d9;
|
||||
}
|
||||
.msg.assistant.analyst-msg {
|
||||
border-left: 3px solid #3fb950;
|
||||
}
|
||||
.msg.event {
|
||||
background: transparent;
|
||||
color: #8b949e;
|
||||
font-size: 11px;
|
||||
padding: 4px 14px;
|
||||
border-left: 3px solid #30363d;
|
||||
}
|
||||
.msg.event.loop {
|
||||
border-left-color: #58a6ff;
|
||||
}
|
||||
.msg.event.tool {
|
||||
border-left-color: #d29922;
|
||||
}
|
||||
.msg.event.stall {
|
||||
border-left-color: #f85149;
|
||||
}
|
||||
.handoff-banner {
|
||||
margin: 16px 0;
|
||||
padding: 16px;
|
||||
background: #1c1200;
|
||||
border: 1px solid #d29922;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.handoff-banner h3 {
|
||||
color: #d29922;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.handoff-banner p, .result-banner p {
|
||||
color: #8b949e;
|
||||
font-size: 12px;
|
||||
line-height: 1.5;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap;
|
||||
text-align: left;
|
||||
}
|
||||
.result-banner {
|
||||
margin: 16px 0;
|
||||
padding: 16px;
|
||||
background: #0a2614;
|
||||
border: 1px solid #3fb950;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.result-banner h3 {
|
||||
color: #3fb950;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.result-banner .label {
|
||||
color: #58a6ff;
|
||||
font-size: 11px;
|
||||
font-weight: 600;
|
||||
margin-top: 10px;
|
||||
margin-bottom: 2px;
|
||||
}
|
||||
.result-banner .tokens {
|
||||
color: #484f58;
|
||||
font-size: 11px;
|
||||
text-align: center;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.input-bar {
|
||||
padding: 12px 16px;
|
||||
background: #161b22;
|
||||
border-top: 1px solid #30363d;
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
}
|
||||
.input-bar input {
|
||||
flex: 1;
|
||||
background: #0d1117;
|
||||
border: 1px solid #30363d;
|
||||
color: #c9d1d9;
|
||||
padding: 8px 12px;
|
||||
border-radius: 6px;
|
||||
font-family: inherit;
|
||||
font-size: 14px;
|
||||
outline: none;
|
||||
}
|
||||
.input-bar input:focus {
|
||||
border-color: #58a6ff;
|
||||
}
|
||||
.input-bar button {
|
||||
background: #238636;
|
||||
color: #fff;
|
||||
border: none;
|
||||
padding: 8px 20px;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
font-weight: 600;
|
||||
}
|
||||
.input-bar button:hover {
|
||||
background: #2ea043;
|
||||
}
|
||||
.input-bar button:disabled {
|
||||
background: #21262d;
|
||||
color: #484f58;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>ContextHandoff Demo</h1>
|
||||
<span id="phase" class="badge">Idle</span>
|
||||
<span id="iter" class="badge" style="display:none">Step 0</span>
|
||||
</header>
|
||||
<div id="chat" class="chat"></div>
|
||||
<div class="input-bar">
|
||||
<input id="input" type="text"
|
||||
placeholder="Enter a research topic..." autofocus />
|
||||
<button id="go" onclick="run()">Research</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let currentAssistantEl = null;
|
||||
let iterCount = 0;
|
||||
let currentPhase = 'idle';
|
||||
const chat = document.getElementById('chat');
|
||||
const phase = document.getElementById('phase');
|
||||
const iterEl = document.getElementById('iter');
|
||||
const goBtn = document.getElementById('go');
|
||||
const inputEl = document.getElementById('input');
|
||||
|
||||
inputEl.addEventListener('keydown', e => {
|
||||
if (e.key === 'Enter') run();
|
||||
});
|
||||
|
||||
function setPhase(text, cls) {
|
||||
phase.textContent = text;
|
||||
phase.className = 'badge ' + cls;
|
||||
currentPhase = cls;
|
||||
}
|
||||
|
||||
function addMsg(text, cls) {
|
||||
const el = document.createElement('div');
|
||||
el.className = 'msg ' + cls;
|
||||
el.textContent = text;
|
||||
chat.appendChild(el);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
return el;
|
||||
}
|
||||
|
||||
function addHandoffBanner(summary) {
|
||||
const banner = document.createElement('div');
|
||||
banner.className = 'handoff-banner';
|
||||
const h3 = document.createElement('h3');
|
||||
h3.textContent = 'Context Handoff: Researcher -> Analyst';
|
||||
const p = document.createElement('p');
|
||||
p.textContent = summary || 'Passing research context...';
|
||||
banner.appendChild(h3);
|
||||
banner.appendChild(p);
|
||||
chat.appendChild(banner);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
|
||||
function addResultBanner(researcher, analyst, tokens) {
|
||||
const banner = document.createElement('div');
|
||||
banner.className = 'result-banner';
|
||||
const h3 = document.createElement('h3');
|
||||
h3.textContent = 'Pipeline Complete';
|
||||
banner.appendChild(h3);
|
||||
|
||||
if (researcher && researcher.research_summary) {
|
||||
const lbl = document.createElement('div');
|
||||
lbl.className = 'label';
|
||||
lbl.textContent = 'RESEARCH SUMMARY';
|
||||
banner.appendChild(lbl);
|
||||
const p = document.createElement('p');
|
||||
p.textContent = researcher.research_summary;
|
||||
banner.appendChild(p);
|
||||
}
|
||||
|
||||
if (analyst && analyst.analysis) {
|
||||
const lbl = document.createElement('div');
|
||||
lbl.className = 'label';
|
||||
lbl.textContent = 'ANALYSIS';
|
||||
lbl.style.color = '#3fb950';
|
||||
banner.appendChild(lbl);
|
||||
const p = document.createElement('p');
|
||||
p.textContent = analyst.analysis;
|
||||
banner.appendChild(p);
|
||||
}
|
||||
|
||||
if (tokens) {
|
||||
const t = document.createElement('div');
|
||||
t.className = 'tokens';
|
||||
t.textContent = 'Total tokens: ' + tokens.toLocaleString();
|
||||
banner.appendChild(t);
|
||||
}
|
||||
|
||||
chat.appendChild(banner);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
|
||||
function connect() {
|
||||
ws = new WebSocket('ws://' + location.host + '/ws');
|
||||
ws.onopen = () => {
|
||||
setPhase('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
};
|
||||
ws.onmessage = handleEvent;
|
||||
ws.onerror = () => { setPhase('Error', 'error'); };
|
||||
ws.onclose = () => {
|
||||
setPhase('Reconnecting...', '');
|
||||
goBtn.disabled = true;
|
||||
setTimeout(connect, 2000);
|
||||
};
|
||||
}
|
||||
|
||||
function handleEvent(msg) {
|
||||
const evt = JSON.parse(msg.data);
|
||||
|
||||
if (evt.type === 'phase') {
|
||||
if (evt.phase === 'researcher') {
|
||||
setPhase('Researcher', 'researcher');
|
||||
} else if (evt.phase === 'handoff') {
|
||||
setPhase('Handoff', 'handoff');
|
||||
} else if (evt.phase === 'analyst') {
|
||||
setPhase('Analyst', 'analyst');
|
||||
}
|
||||
iterCount = 0;
|
||||
iterEl.style.display = 'none';
|
||||
}
|
||||
else if (evt.type === 'llm_text_delta') {
|
||||
if (currentAssistantEl) {
|
||||
currentAssistantEl.textContent += evt.content;
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'node_loop_iteration') {
|
||||
iterCount = evt.iteration || (iterCount + 1);
|
||||
iterEl.textContent = 'Step ' + iterCount;
|
||||
iterEl.style.display = '';
|
||||
}
|
||||
else if (evt.type === 'tool_call_started') {
|
||||
var info = evt.tool_name + '('
|
||||
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
|
||||
addMsg('TOOL ' + info, 'event tool');
|
||||
}
|
||||
else if (evt.type === 'tool_call_completed') {
|
||||
var preview = (evt.result || '').slice(0, 200);
|
||||
var cls = evt.is_error ? 'stall' : 'tool';
|
||||
addMsg(
|
||||
'RESULT ' + evt.tool_name + ': ' + preview,
|
||||
'event ' + cls
|
||||
);
|
||||
var assistCls = currentPhase === 'analyst'
|
||||
? 'assistant analyst-msg' : 'assistant';
|
||||
currentAssistantEl = addMsg('', assistCls);
|
||||
}
|
||||
else if (evt.type === 'handoff_context') {
|
||||
addHandoffBanner(evt.summary);
|
||||
var assistCls = 'assistant analyst-msg';
|
||||
currentAssistantEl = addMsg('', assistCls);
|
||||
}
|
||||
else if (evt.type === 'node_result') {
|
||||
if (evt.node_id === 'researcher') {
|
||||
if (currentAssistantEl
|
||||
&& !currentAssistantEl.textContent) {
|
||||
currentAssistantEl.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'done') {
|
||||
setPhase('Done', 'done');
|
||||
iterEl.style.display = 'none';
|
||||
if (currentAssistantEl
|
||||
&& !currentAssistantEl.textContent) {
|
||||
currentAssistantEl.remove();
|
||||
}
|
||||
currentAssistantEl = null;
|
||||
addResultBanner(
|
||||
evt.researcher, evt.analyst, evt.total_tokens
|
||||
);
|
||||
goBtn.disabled = false;
|
||||
inputEl.placeholder = 'Enter another topic...';
|
||||
}
|
||||
else if (evt.type === 'error') {
|
||||
setPhase('Error', 'error');
|
||||
addMsg('ERROR ' + evt.message, 'event stall');
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_stalled') {
|
||||
addMsg('STALLED ' + evt.reason, 'event stall');
|
||||
}
|
||||
}
|
||||
|
||||
function run() {
|
||||
const text = inputEl.value.trim();
|
||||
if (!text || !ws || ws.readyState !== 1) return;
|
||||
chat.innerHTML = '';
|
||||
addMsg(text, 'user');
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
inputEl.value = '';
|
||||
goBtn.disabled = true;
|
||||
ws.send(JSON.stringify({ topic: text }));
|
||||
}
|
||||
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# WebSocket handler — sequential Node A → Handoff → Node B
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_ws(websocket):
    """Drive one handoff pipeline per incoming browser message.

    Each frame is expected to be a JSON object with a "topic" key; frames
    that fail to parse or lack a topic are silently skipped. Pipeline
    failures are reported back to the browser as error events; a closed
    socket ends the handler quietly.
    """
    try:
        async for raw in websocket:
            try:
                payload = json.loads(raw)
            except Exception:
                continue  # not JSON — ignore the frame

            topic = payload.get("topic", "")
            if not topic:
                continue  # nothing to research

            logger.info(f"Starting handoff pipeline for: {topic}")
            try:
                await _run_pipeline(websocket, topic)
            except websockets.exceptions.ConnectionClosed:
                # Browser went away mid-run; stop handling this socket.
                logger.info("WebSocket closed during pipeline")
                return
            except Exception as exc:
                logger.exception("Pipeline error")
                try:
                    await websocket.send(
                        json.dumps({"type": "error", "message": str(exc)})
                    )
                except Exception:
                    pass  # socket may already be gone

    except websockets.exceptions.ConnectionClosed:
        pass  # normal client disconnect
|
||||
|
||||
|
||||
async def _run_pipeline(websocket, topic: str):
    """Execute: Node A (research) → ContextHandoff → Node B (analysis).

    Streams loop/tool/LLM events to the browser while each node runs,
    performs a conversation-summarization handoff between the nodes, and
    finally reports both outputs plus total token usage.

    Args:
        websocket: Open WebSocket connection to the browser.
        topic: Research topic supplied by the user.
    """
    import shutil

    # Fresh stores for each run
    run_dir = Path(tempfile.mkdtemp(prefix="hive_run_", dir=STORE_DIR))
    store_a = FileConversationStore(run_dir / "node_a")
    store_b = FileConversationStore(run_dir / "node_b")

    try:
        # Shared event bus — both nodes publish; we forward to the browser.
        bus = EventBus()

        async def forward_event(event):
            # Relay framework events to the browser as flat JSON payloads.
            try:
                payload = {"type": event.type.value, **event.data}
                if event.node_id:
                    payload["node_id"] = event.node_id
                await websocket.send(json.dumps(payload))
            except Exception:
                pass  # browser may have disconnected mid-event

        bus.subscribe(
            event_types=[
                EventType.NODE_LOOP_STARTED,
                EventType.NODE_LOOP_ITERATION,
                EventType.NODE_LOOP_COMPLETED,
                EventType.LLM_TEXT_DELTA,
                EventType.TOOL_CALL_STARTED,
                EventType.TOOL_CALL_COMPLETED,
                EventType.NODE_STALLED,
            ],
            handler=forward_event,
        )

        tools = list(TOOL_REGISTRY.get_tools().values())
        tool_executor = TOOL_REGISTRY.get_executor()

        # ---- Phase 1: Researcher --------------------------------------
        await websocket.send(json.dumps({"type": "phase", "phase": "researcher"}))

        node_a = EventLoopNode(
            event_bus=bus,
            judge=None,  # implicit judge: accept when output_keys filled
            config=LoopConfig(
                max_iterations=20,
                max_tool_calls_per_turn=10,
                max_history_tokens=32_000,
            ),
            conversation_store=store_a,
            tool_executor=tool_executor,
        )

        ctx_a = NodeContext(
            runtime=RUNTIME,
            node_id="researcher",
            node_spec=RESEARCHER_SPEC,
            memory=SharedMemory(),
            input_data={"topic": topic},
            llm=LLM,
            available_tools=tools,
        )

        result_a = await node_a.execute(ctx_a)
        logger.info(
            "Researcher done: success=%s, tokens=%s",
            result_a.success,
            result_a.tokens_used,
        )

        await websocket.send(
            json.dumps(
                {
                    "type": "node_result",
                    "node_id": "researcher",
                    "success": result_a.success,
                    "output": result_a.output,
                }
            )
        )

        if not result_a.success:
            await websocket.send(
                json.dumps(
                    {
                        "type": "error",
                        "message": f"Researcher failed: {result_a.error}",
                    }
                )
            )
            return

        # ---- Phase 2: Context Handoff ---------------------------------
        await websocket.send(json.dumps({"type": "phase", "phase": "handoff"}))

        # Restore the researcher's conversation from store
        conversation_a = await NodeConversation.restore(store_a)
        if conversation_a is None:
            await websocket.send(
                json.dumps(
                    {
                        "type": "error",
                        "message": "Failed to restore researcher conversation",
                    }
                )
            )
            return

        handoff_engine = ContextHandoff(llm=LLM)
        handoff_context = handoff_engine.summarize_conversation(
            conversation=conversation_a,
            node_id="researcher",
            output_keys=["research_summary"],
        )

        formatted_handoff = ContextHandoff.format_as_input(handoff_context)
        logger.info(
            "Handoff: %d turns, ~%d tokens, keys=%s",
            handoff_context.turn_count,
            handoff_context.total_tokens_used,
            list(handoff_context.key_outputs.keys()),
        )

        # Send handoff context to browser (summary truncated for display)
        await websocket.send(
            json.dumps(
                {
                    "type": "handoff_context",
                    "summary": handoff_context.summary[:500],
                    "turn_count": handoff_context.turn_count,
                    "tokens": handoff_context.total_tokens_used,
                    "key_outputs": handoff_context.key_outputs,
                }
            )
        )

        # ---- Phase 3: Analyst -----------------------------------------
        await websocket.send(json.dumps({"type": "phase", "phase": "analyst"}))

        node_b = EventLoopNode(
            event_bus=bus,
            judge=None,  # implicit judge
            config=LoopConfig(
                max_iterations=10,
                max_tool_calls_per_turn=5,
                max_history_tokens=32_000,
            ),
            conversation_store=store_b,
        )

        ctx_b = NodeContext(
            runtime=RUNTIME,
            node_id="analyst",
            node_spec=ANALYST_SPEC,
            memory=SharedMemory(),
            input_data={"context": formatted_handoff},
            llm=LLM,
            available_tools=[],  # analyst reasons over the handoff only
        )

        result_b = await node_b.execute(ctx_b)
        logger.info(
            "Analyst done: success=%s, tokens=%s",
            result_b.success,
            result_b.tokens_used,
        )

        # ---- Done -----------------------------------------------------
        await websocket.send(
            json.dumps(
                {
                    "type": "done",
                    "researcher": result_a.output,
                    "analyst": result_b.output,
                    "total_tokens": (
                        (result_a.tokens_used or 0) + (result_b.tokens_used or 0)
                    ),
                }
            )
        )
    finally:
        # FIX: the temp stores were previously removed only on the fully
        # successful path, leaking run_dir on early returns (researcher
        # failure, restore failure) and on exceptions. Always clean up.
        shutil.rmtree(run_dir, ignore_errors=True)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP handler
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_request(connection, request: Request):
    """Serve the demo page over plain HTTP; let /ws continue the upgrade.

    Returning None tells the websockets library to proceed with the
    WebSocket handshake; any other path gets the embedded HTML page.
    """
    if request.path == "/ws":
        return None  # fall through to the WebSocket handshake

    headers = websockets.Headers({"Content-Type": "text/html; charset=utf-8"})
    return Response(HTTPStatus.OK, "OK", headers, HTML_PAGE.encode())
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Main
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
    """Run the combined HTTP + WebSocket server until cancelled."""
    port = 8766
    server = websockets.serve(
        handle_ws,
        "0.0.0.0",
        port,
        process_request=process_request,
    )
    async with server:
        logger.info(f"Handoff demo at http://localhost:{port}")
        logger.info("Enter a research topic to start the pipeline.")
        await asyncio.Future()  # block forever; Ctrl-C to stop
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the demo server event loop until interrupted.
    asyncio.run(main())
|
||||
@@ -64,6 +64,8 @@ class AdenCachedStorage(CredentialStorage):
|
||||
- **Reads**: Try local cache first, fallback to Aden if stale/missing
|
||||
- **Writes**: Always write to local cache
|
||||
- **Offline resilience**: Uses cached credentials when Aden is unreachable
|
||||
- **Provider-based lookup**: Match credentials by provider name (e.g., "hubspot")
|
||||
when direct ID lookup fails, since Aden uses hash-based IDs internally.
|
||||
|
||||
The cache TTL determines how long to trust local credentials before
|
||||
checking with the Aden server for updates. This balances:
|
||||
@@ -85,6 +87,7 @@ class AdenCachedStorage(CredentialStorage):
|
||||
|
||||
# First access fetches from Aden
|
||||
# Subsequent accesses use cache until TTL expires
|
||||
# Can look up by provider name OR credential ID
|
||||
token = store.get_key("hubspot", "access_token")
|
||||
"""
|
||||
|
||||
@@ -111,21 +114,24 @@ class AdenCachedStorage(CredentialStorage):
|
||||
self._cache_ttl = timedelta(seconds=cache_ttl_seconds)
|
||||
self._prefer_local = prefer_local
|
||||
self._cache_timestamps: dict[str, datetime] = {}
|
||||
# Index: provider name (e.g., "hubspot") -> credential hash ID
|
||||
self._provider_index: dict[str, str] = {}
|
||||
|
||||
def save(self, credential: CredentialObject) -> None:
|
||||
"""
|
||||
Save credential to local cache.
|
||||
Save credential to local cache and update provider index.
|
||||
|
||||
Args:
|
||||
credential: The credential to save.
|
||||
"""
|
||||
self._local.save(credential)
|
||||
self._cache_timestamps[credential.id] = datetime.now(UTC)
|
||||
self._index_provider(credential)
|
||||
logger.debug(f"Cached credential '{credential.id}'")
|
||||
|
||||
def load(self, credential_id: str) -> CredentialObject | None:
|
||||
"""
|
||||
Load credential from cache, with Aden fallback.
|
||||
Load credential from cache, with Aden fallback and provider-based lookup.
|
||||
|
||||
The loading strategy depends on the `prefer_local` setting:
|
||||
|
||||
@@ -141,8 +147,37 @@ class AdenCachedStorage(CredentialStorage):
|
||||
2. Update local cache with response
|
||||
3. Fall back to local cache only if Aden fails
|
||||
|
||||
Provider-based lookup:
|
||||
When a provider index mapping exists for the credential_id (e.g.,
|
||||
"hubspot" → hash ID), the Aden-synced credential is loaded first.
|
||||
This ensures fresh OAuth tokens from Aden take priority over stale
|
||||
local credentials (env vars, old encrypted files).
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier.
|
||||
credential_id: The credential identifier or provider name.
|
||||
|
||||
Returns:
|
||||
CredentialObject if found, None otherwise.
|
||||
"""
|
||||
# Check provider index first — Aden-synced credentials take priority
|
||||
resolved_id = self._provider_index.get(credential_id)
|
||||
if resolved_id and resolved_id != credential_id:
|
||||
result = self._load_by_id(resolved_id)
|
||||
if result is not None:
|
||||
logger.info(
|
||||
f"Loaded credential '{credential_id}' via provider index (id='{resolved_id}')"
|
||||
)
|
||||
return result
|
||||
|
||||
# Direct lookup (exact credential_id match)
|
||||
return self._load_by_id(credential_id)
|
||||
|
||||
def _load_by_id(self, credential_id: str) -> CredentialObject | None:
|
||||
"""
|
||||
Load credential by exact ID from cache, with Aden fallback.
|
||||
|
||||
Args:
|
||||
credential_id: The exact credential identifier.
|
||||
|
||||
Returns:
|
||||
CredentialObject if found, None otherwise.
|
||||
@@ -200,15 +235,21 @@ class AdenCachedStorage(CredentialStorage):
|
||||
|
||||
def exists(self, credential_id: str) -> bool:
|
||||
"""
|
||||
Check if credential exists in local cache.
|
||||
Check if credential exists in local cache (by ID or provider name).
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier.
|
||||
credential_id: The credential identifier or provider name.
|
||||
|
||||
Returns:
|
||||
True if credential exists locally.
|
||||
"""
|
||||
return self._local.exists(credential_id)
|
||||
if self._local.exists(credential_id):
|
||||
return True
|
||||
# Check provider index
|
||||
resolved_id = self._provider_index.get(credential_id)
|
||||
if resolved_id and resolved_id != credential_id:
|
||||
return self._local.exists(resolved_id)
|
||||
return False
|
||||
|
||||
def _is_cache_fresh(self, credential_id: str) -> bool:
|
||||
"""
|
||||
@@ -242,6 +283,47 @@ class AdenCachedStorage(CredentialStorage):
|
||||
self._cache_timestamps.clear()
|
||||
logger.debug("Invalidated all cache entries")
|
||||
|
||||
def _index_provider(self, credential: CredentialObject) -> None:
|
||||
"""
|
||||
Index a credential by its provider/integration type.
|
||||
|
||||
Aden credentials carry an ``_integration_type`` key whose value is
|
||||
the provider name (e.g., ``hubspot``). This method maps that
|
||||
provider name to the credential's hash ID so that subsequent
|
||||
``load("hubspot")`` calls resolve to the correct credential.
|
||||
|
||||
Args:
|
||||
credential: The credential to index.
|
||||
"""
|
||||
integration_type_key = credential.keys.get("_integration_type")
|
||||
if integration_type_key is None:
|
||||
return
|
||||
provider_name = integration_type_key.value.get_secret_value()
|
||||
if provider_name:
|
||||
self._provider_index[provider_name] = credential.id
|
||||
logger.debug(f"Indexed provider '{provider_name}' -> '{credential.id}'")
|
||||
|
||||
def rebuild_provider_index(self) -> int:
|
||||
"""
|
||||
Rebuild the provider index from all locally cached credentials.
|
||||
|
||||
Useful after loading from disk when the in-memory index is empty.
|
||||
|
||||
Returns:
|
||||
Number of provider mappings indexed.
|
||||
"""
|
||||
self._provider_index.clear()
|
||||
indexed = 0
|
||||
for cred_id in self._local.list_all():
|
||||
cred = self._local.load(cred_id)
|
||||
if cred:
|
||||
before = len(self._provider_index)
|
||||
self._index_provider(cred)
|
||||
if len(self._provider_index) > before:
|
||||
indexed += 1
|
||||
logger.debug(f"Rebuilt provider index with {indexed} mappings")
|
||||
return indexed
|
||||
|
||||
def sync_all_from_aden(self) -> int:
|
||||
"""
|
||||
Sync all credentials from Aden server to local cache.
|
||||
|
||||
@@ -589,6 +589,149 @@ class TestAdenCachedStorage:
|
||||
assert info["stale"]["is_fresh"] is False
|
||||
assert info["stale"]["ttl_remaining_seconds"] == 0
|
||||
|
||||
def test_save_indexes_provider(self, cached_storage):
|
||||
"""Test save builds the provider index from _integration_type key."""
|
||||
cred = CredentialObject(
|
||||
id="aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1",
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"access_token": CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr("token-value"),
|
||||
),
|
||||
"_integration_type": CredentialKey(
|
||||
name="_integration_type",
|
||||
value=SecretStr("hubspot"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cached_storage.save(cred)
|
||||
|
||||
assert cached_storage._provider_index["hubspot"] == "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
|
||||
|
||||
def test_load_by_provider_name(self, cached_storage):
    """Test load resolves provider name to hash-based credential ID."""
    hash_id = "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
    cred = CredentialObject(
        id=hash_id,
        credential_type=CredentialType.OAUTH2,
        keys={
            "access_token": CredentialKey(
                name="access_token",
                value=SecretStr("hubspot-token"),
            ),
            # Provider mapping consumed by the index built in save().
            "_integration_type": CredentialKey(
                name="_integration_type",
                value=SecretStr("hubspot"),
            ),
        },
    )

    # Save builds the index
    cached_storage.save(cred)

    # Load by provider name should resolve to the hash ID
    loaded = cached_storage.load("hubspot")

    assert loaded is not None
    assert loaded.id == hash_id
    assert loaded.keys["access_token"].value.get_secret_value() == "hubspot-token"
|
||||
|
||||
def test_load_by_direct_id_still_works(self, cached_storage):
    """Test load by direct hash ID still works as before."""
    hash_id = "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
    cred = CredentialObject(
        id=hash_id,
        credential_type=CredentialType.OAUTH2,
        keys={
            "access_token": CredentialKey(
                name="access_token",
                value=SecretStr("token"),
            ),
            # Provider mapping present but must not interfere with direct lookup.
            "_integration_type": CredentialKey(
                name="_integration_type",
                value=SecretStr("hubspot"),
            ),
        },
    )

    cached_storage.save(cred)

    # Direct ID lookup should still work
    loaded = cached_storage.load(hash_id)

    assert loaded is not None
    assert loaded.id == hash_id
|
||||
|
||||
def test_exists_by_provider_name(self, cached_storage):
    """Test exists resolves provider name to hash-based credential ID."""
    hash_id = "c2xhY2s6dGVzdDo5OTk="
    cred = CredentialObject(
        id=hash_id,
        credential_type=CredentialType.OAUTH2,
        keys={
            "access_token": CredentialKey(
                name="access_token",
                value=SecretStr("slack-token"),
            ),
            # Provider mapping so exists("slack") can be resolved.
            "_integration_type": CredentialKey(
                name="_integration_type",
                value=SecretStr("slack"),
            ),
        },
    )

    cached_storage.save(cred)

    # Both lookup styles must succeed; unknown names must not.
    assert cached_storage.exists("slack") is True
    assert cached_storage.exists(hash_id) is True
    assert cached_storage.exists("nonexistent") is False
|
||||
|
||||
def test_rebuild_provider_index(self, cached_storage, local_storage):
    """Test rebuild_provider_index reconstructs from local storage."""
    # Manually save credentials to local storage (bypassing cached_storage.save)
    # so the in-memory provider index stays empty.
    for provider_name, hash_id in [("hubspot", "hash_hub"), ("slack", "hash_slack")]:
        cred = CredentialObject(
            id=hash_id,
            credential_type=CredentialType.OAUTH2,
            keys={
                "_integration_type": CredentialKey(
                    name="_integration_type",
                    value=SecretStr(provider_name),
                ),
            },
        )
        local_storage.save(cred)

    # Index should be empty (we bypassed save)
    assert len(cached_storage._provider_index) == 0

    # Rebuild
    indexed = cached_storage.rebuild_provider_index()

    # One mapping per credential carrying _integration_type.
    assert indexed == 2
    assert cached_storage._provider_index["hubspot"] == "hash_hub"
    assert cached_storage._provider_index["slack"] == "hash_slack"
|
||||
|
||||
def test_save_without_integration_type_no_index(self, cached_storage):
    """Test save does not index credentials without _integration_type key."""
    # Plain API-key credential: no provider metadata at all.
    cred = CredentialObject(
        id="plain-cred",
        credential_type=CredentialType.API_KEY,
        keys={
            "api_key": CredentialKey(
                name="api_key",
                value=SecretStr("key-value"),
            ),
        },
    )

    cached_storage.save(cred)

    # Without _integration_type the index must remain untouched.
    assert "plain-cred" not in cached_storage._provider_index
    assert len(cached_storage._provider_index) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests
|
||||
|
||||
@@ -1,8 +1,22 @@
|
||||
"""Graph structures: Goals, Nodes, Edges, and Flexible Execution."""
|
||||
|
||||
from framework.graph.client_io import (
|
||||
ActiveNodeClientIO,
|
||||
ClientIOGateway,
|
||||
InertNodeClientIO,
|
||||
NodeClientIO,
|
||||
)
|
||||
from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec
|
||||
from framework.graph.context_handoff import ContextHandoff, HandoffContext
|
||||
from framework.graph.conversation import ConversationStore, Message, NodeConversation
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.event_loop_node import (
|
||||
EventLoopNode,
|
||||
JudgeProtocol,
|
||||
JudgeVerdict,
|
||||
LoopConfig,
|
||||
OutputAccumulator,
|
||||
)
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.flexible_executor import ExecutorConfig, FlexibleGraphExecutor
|
||||
from framework.graph.goal import Constraint, Goal, GoalStatus, SuccessCriterion
|
||||
@@ -77,4 +91,18 @@ __all__ = [
|
||||
"NodeConversation",
|
||||
"ConversationStore",
|
||||
"Message",
|
||||
# Event Loop
|
||||
"EventLoopNode",
|
||||
"LoopConfig",
|
||||
"OutputAccumulator",
|
||||
"JudgeProtocol",
|
||||
"JudgeVerdict",
|
||||
# Context Handoff
|
||||
"ContextHandoff",
|
||||
"HandoffContext",
|
||||
# Client I/O
|
||||
"NodeClientIO",
|
||||
"ActiveNodeClientIO",
|
||||
"InertNodeClientIO",
|
||||
"ClientIOGateway",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Client I/O gateway for graph nodes.
|
||||
|
||||
Provides the bridge between node code and external clients:
|
||||
- ActiveNodeClientIO: for client_facing=True nodes (streams output, accepts input)
|
||||
- InertNodeClientIO: for client_facing=False nodes (logs internally, redirects input)
|
||||
- ClientIOGateway: factory that creates the right variant per node
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeClientIO(ABC):
    """Abstract base for node client I/O.

    Concrete variants: ActiveNodeClientIO for client-facing nodes and
    InertNodeClientIO for internal nodes; instances are produced per node
    by ClientIOGateway.create_io().
    """

    @abstractmethod
    async def emit_output(self, content: str, is_final: bool = False) -> None:
        """Emit output content. If is_final=True, signal end of stream."""

    @abstractmethod
    async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
        """Request input. Behavior depends on whether the node is client-facing."""
|
||||
|
||||
|
||||
class ActiveNodeClientIO(NodeClientIO):
    """
    Client I/O for client_facing=True nodes.

    Output is buffered through an internal queue and mirrored to the
    EventBus as CLIENT_OUTPUT_DELTA events. Input is requested by
    publishing CLIENT_INPUT_REQUESTED and waiting for an external call
    to provide_input(). output_stream() drains the queue until the final
    sentinel.
    """

    def __init__(
        self,
        node_id: str,
        event_bus: EventBus | None = None,
    ) -> None:
        self.node_id = node_id
        self._event_bus = event_bus

        # Queue of emitted chunks; None is the end-of-stream sentinel.
        self._output_queue: asyncio.Queue[str | None] = asyncio.Queue()
        # Running concatenation of everything emitted so far.
        self._output_snapshot = ""

        # Non-None only while a request_input() call is pending.
        self._input_event: asyncio.Event | None = None
        self._input_result: str | None = None

    async def emit_output(self, content: str, is_final: bool = False) -> None:
        """Queue *content*, publish a delta event, optionally end the stream."""
        self._output_snapshot = self._output_snapshot + content
        await self._output_queue.put(content)

        bus = self._event_bus
        if bus is not None:
            await bus.emit_client_output_delta(
                stream_id=self.node_id,
                node_id=self.node_id,
                content=content,
                snapshot=self._output_snapshot,
            )

        if is_final:
            # Sentinel releases any output_stream() consumer.
            await self._output_queue.put(None)

    async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
        """Publish an input request, then block until provide_input() answers."""
        if self._input_event is not None:
            raise RuntimeError("request_input already pending for this node")

        waiter = asyncio.Event()
        self._input_event = waiter
        self._input_result = None

        bus = self._event_bus
        if bus is not None:
            await bus.emit_client_input_requested(
                stream_id=self.node_id,
                node_id=self.node_id,
                prompt=prompt,
            )

        try:
            if timeout is None:
                await waiter.wait()
            else:
                await asyncio.wait_for(waiter.wait(), timeout=timeout)
        finally:
            # Clear pending state even on timeout/cancellation.
            self._input_event = None

        if self._input_result is None:
            raise RuntimeError("input event was set but no input was provided")
        answer = self._input_result
        self._input_result = None
        return answer

    async def provide_input(self, content: str) -> None:
        """Called externally to fulfill a pending request_input()."""
        pending = self._input_event
        if pending is None:
            raise RuntimeError("no pending request_input to fulfill")
        self._input_result = content
        pending.set()

    async def output_stream(self) -> AsyncIterator[str]:
        """Async iterator that yields output chunks until the final sentinel."""
        while (chunk := await self._output_queue.get()) is not None:
            yield chunk
|
||||
|
||||
|
||||
class InertNodeClientIO(NodeClientIO):
    """
    Client I/O for client_facing=False nodes.

    Output is routed to the EventBus as NODE_INTERNAL_OUTPUT rather than
    discarded; input requests are reported as NODE_INPUT_BLOCKED and
    answered with a fixed redirect message instead of real user input.
    """

    def __init__(
        self,
        node_id: str,
        event_bus: EventBus | None = None,
    ) -> None:
        self.node_id = node_id
        self._event_bus = event_bus

    async def emit_output(self, content: str, is_final: bool = False) -> None:
        """Forward *content* to the EventBus as internal (non-client) output."""
        bus = self._event_bus
        if bus is None:
            return
        await bus.emit_node_internal_output(
            stream_id=self.node_id,
            node_id=self.node_id,
            content=content,
        )

    async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
        """Report the blocked input request; return the redirect message."""
        bus = self._event_bus
        if bus is not None:
            await bus.emit_node_input_blocked(
                stream_id=self.node_id,
                node_id=self.node_id,
                prompt=prompt,
            )
        # Fixed redirect fed back to the LLM in place of user input.
        return (
            "You are an internal processing node. There is no user to interact with."
            " Work with the data provided in your inputs to complete your task."
        )
|
||||
|
||||
|
||||
class ClientIOGateway:
    """Factory that creates the appropriate NodeClientIO for a node."""

    def __init__(self, event_bus: EventBus | None = None) -> None:
        self._event_bus = event_bus

    def create_io(self, node_id: str, client_facing: bool) -> NodeClientIO:
        """Return an active or inert client I/O depending on *client_facing*."""
        io_cls = ActiveNodeClientIO if client_facing else InertNodeClientIO
        return io_cls(node_id=node_id, event_bus=self._event_bus)
|
||||
@@ -0,0 +1,191 @@
|
||||
"""Context handoff: summarize a completed NodeConversation for the next graph node."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.graph.conversation import _try_extract_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.llm.provider import LLMProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TRUNCATE_CHARS = 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class HandoffContext:
    """Structured summary of a completed node conversation."""

    # ID of the graph node whose conversation was summarized.
    source_node_id: str
    # Abstractive (LLM) summary, or the extractive fallback text.
    summary: str
    # Values recovered for requested output keys (JSON-encoded when not str).
    key_outputs: dict[str, Any]
    # Number of turns in the source conversation.
    turn_count: int
    # Token estimate for the source conversation (see estimate_tokens()).
    total_tokens_used: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ContextHandoff
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ContextHandoff:
    """Summarize a completed NodeConversation into a HandoffContext.

    Parameters
    ----------
    llm : LLMProvider | None
        Optional LLM provider for abstractive summarization.
        When *None*, all summarization uses the extractive fallback.
    """

    def __init__(self, llm: LLMProvider | None = None) -> None:
        self.llm = llm

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def summarize_conversation(
        self,
        conversation: NodeConversation,
        node_id: str,
        output_keys: list[str] | None = None,
    ) -> HandoffContext:
        """Produce a HandoffContext from *conversation*.

        1. Extracts turn_count & total_tokens_used (sync properties).
        2. Extracts key_outputs by scanning assistant messages most-recent-first.
        3. Builds a summary via the LLM (if available) or extractive fallback.
        """
        turn_count = conversation.turn_count
        total_tokens_used = conversation.estimate_tokens()
        messages = conversation.messages  # defensive copy

        # --- key outputs ---------------------------------------------------
        # Reverse scan so the most recent assistant answer wins per key.
        key_outputs: dict[str, Any] = {}
        if output_keys:
            remaining = set(output_keys)
            for msg in reversed(messages):
                if msg.role != "assistant" or not remaining:
                    continue
                for key in list(remaining):
                    value = _try_extract_key(msg.content, key)
                    if value is not None:
                        key_outputs[key] = value
                        remaining.discard(key)

        # --- summary -------------------------------------------------------
        # LLM summarization is best-effort: any failure degrades to the
        # extractive fallback instead of propagating.
        if self.llm is not None:
            try:
                summary = self._llm_summary(messages, output_keys or [])
            except Exception:
                logger.warning(
                    "LLM summarization failed; falling back to extractive.",
                    exc_info=True,
                )
                summary = self._extractive_summary(messages)
        else:
            summary = self._extractive_summary(messages)

        return HandoffContext(
            source_node_id=node_id,
            summary=summary,
            key_outputs=key_outputs,
            turn_count=turn_count,
            total_tokens_used=total_tokens_used,
        )

    @staticmethod
    def format_as_input(handoff: HandoffContext) -> str:
        """Render *handoff* as structured plain text for the next node's input."""
        header = (
            f"--- CONTEXT FROM: {handoff.source_node_id} "
            f"({handoff.turn_count} turns, ~{handoff.total_tokens_used} tokens) ---"
        )

        sections: list[str] = [header, ""]

        if handoff.key_outputs:
            sections.append("KEY OUTPUTS:")
            for k, v in handoff.key_outputs.items():
                sections.append(f"- {k}: {v}")
            sections.append("")

        summary_text = handoff.summary or "No summary available."
        sections.append("SUMMARY:")
        sections.append(summary_text)
        sections.append("")
        sections.append("--- END CONTEXT ---")

        return "\n".join(sections)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _extractive_summary(messages: list) -> str:
        """Build a summary from key assistant messages without an LLM.

        Strategy:
        - Include the first assistant message (initial assessment).
        - Include the last assistant message (final conclusion).
        - Truncate each to ~500 chars.
        """
        if not messages:
            return "Empty conversation."

        assistant_msgs = [m for m in messages if m.role == "assistant"]
        if not assistant_msgs:
            return "No assistant responses."

        parts: list[str] = []

        first = assistant_msgs[0].content
        parts.append(first[:_TRUNCATE_CHARS])

        # Only add the last message when it differs from the first.
        if len(assistant_msgs) > 1:
            last = assistant_msgs[-1].content
            parts.append(last[:_TRUNCATE_CHARS])

        return "\n\n".join(parts)

    def _llm_summary(self, messages: list, output_keys: list[str]) -> str:
        """Produce a summary by calling the LLM provider.

        Raises ValueError if called without a provider; the caller guards
        against this and catches any provider exception.
        """
        if self.llm is None:
            raise ValueError("_llm_summary called without an LLM provider")

        conversation_text = "\n".join(f"[{m.role}]: {m.content}" for m in messages)

        key_hint = ""
        if output_keys:
            key_hint = (
                "\nThe following output keys are especially important: "
                + ", ".join(output_keys)
                + ".\n"
            )

        system_prompt = (
            "You are a concise summarizer. Given the conversation below, "
            "produce a brief summary (at most ~500 tokens) that captures the "
            "key decisions, findings, and outcomes. Focus on what was concluded "
            "rather than the back-and-forth process." + key_hint
        )

        response = self.llm.complete(
            messages=[{"role": "user", "content": conversation_text}],
            system=system_prompt,
            max_tokens=500,
        )

        return response.content.strip()
|
||||
@@ -108,6 +108,50 @@ class ConversationStore(Protocol):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _try_extract_key(content: str, key: str) -> str | None:
    """Try 4 strategies to extract a *key*'s value from message content.

    Strategies (in order):
    1. Whole message is JSON — ``json.loads``, check for key.
    2. Embedded JSON via ``find_json_object`` helper.
    3. Colon format: ``key: value``.
    4. Equals format: ``key = value``.
    """
    from framework.graph.node import find_json_object

    def _from_json(text: str) -> str | None:
        # Shared JSON-dict extraction for strategies 1 and 2.
        try:
            obj = json.loads(text)
        except (json.JSONDecodeError, TypeError):
            return None
        if isinstance(obj, dict) and key in obj:
            found = obj[key]
            return found if isinstance(found, str) else json.dumps(found)
        return None

    # 1. Whole message is JSON
    value = _from_json(content)
    if value is not None:
        return value

    # 2. Embedded JSON fragment
    embedded = find_json_object(content)
    if embedded:
        value = _from_json(embedded)
        if value is not None:
            return value

    # 3 & 4. Line formats: "key: value" then "key = value"
    for separator in (":", "="):
        found = re.search(rf"\b{re.escape(key)}\s*{separator}\s*(.+)", content)
        if found:
            return found.group(1).strip()

    return None
|
||||
|
||||
|
||||
class NodeConversation:
|
||||
"""Message history for a graph node with optional write-through persistence.
|
||||
|
||||
@@ -133,6 +177,7 @@ class NodeConversation:
|
||||
self._messages: list[Message] = []
|
||||
self._next_seq: int = 0
|
||||
self._meta_persisted: bool = False
|
||||
self._last_api_input_tokens: int | None = None
|
||||
|
||||
# --- Properties --------------------------------------------------------
|
||||
|
||||
@@ -205,14 +250,78 @@ class NodeConversation:
|
||||
# --- Query -------------------------------------------------------------
|
||||
|
||||
def to_llm_messages(self) -> list[dict[str, Any]]:
    """Return messages as OpenAI-format dicts (system prompt excluded).

    Automatically repairs orphaned tool_use blocks (assistant messages
    with tool_calls that lack corresponding tool-result messages). This
    can happen when a loop is cancelled mid-tool-execution.
    """
    msgs = [m.to_llm_dict() for m in self._messages]
    return self._repair_orphaned_tool_calls(msgs)
|
||||
|
||||
@staticmethod
def _repair_orphaned_tool_calls(
    msgs: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Ensure every tool_call has a matching tool-result message.

    For each assistant message carrying tool_calls, scans the run of
    "tool" messages immediately following it and appends a synthetic
    error result for any call ID left unanswered.
    """
    repaired: list[dict[str, Any]] = []
    for position, message in enumerate(msgs):
        repaired.append(message)
        calls = message.get("tool_calls")
        if not calls or message.get("role") != "assistant":
            continue
        # IDs already answered by the contiguous tool-message run.
        answered_ids: set[str] = set()
        for follower in msgs[position + 1:]:
            if follower.get("role") != "tool":
                break  # run of tool results ends here
            follower_id = follower.get("tool_call_id")
            if follower_id:
                answered_ids.add(follower_id)
        # Synthesize results for any unanswered call.
        for call in calls:
            call_id = call.get("id")
            if call_id and call_id not in answered_ids:
                repaired.append(
                    {
                        "role": "tool",
                        "tool_call_id": call_id,
                        "content": "ERROR: Tool execution was interrupted.",
                    }
                )
    return repaired
|
||||
|
||||
def estimate_tokens(self) -> int:
    """Best available token estimate.

    Uses actual API input token count when available (set via
    :meth:`update_token_count`), otherwise falls back to the rough
    ``total_chars / 4`` heuristic.
    """
    if self._last_api_input_tokens is not None:
        return self._last_api_input_tokens
    # Heuristic: roughly 4 characters per token.
    total_chars = sum(len(m.content) for m in self._messages)
    return total_chars // 4
|
||||
|
||||
def update_token_count(self, actual_input_tokens: int) -> None:
    """Store actual API input token count for more accurate compaction.

    Called by EventLoopNode after each LLM call with the ``input_tokens``
    value from the API response. This value includes system prompt and
    tool definitions, so it may be higher than a message-only estimate.
    """
    # Consumed by estimate_tokens(); reset to None on compact()/clear().
    self._last_api_input_tokens = actual_input_tokens
|
||||
|
||||
def usage_ratio(self) -> float:
    """Current token usage as a fraction of *max_history_tokens*.

    Returns 0.0 when ``max_history_tokens`` is zero (unlimited).
    """
    # Guard: a non-positive budget means "unlimited" (and avoids div-by-zero).
    if self._max_history_tokens <= 0:
        return 0.0
    return self.estimate_tokens() / self._max_history_tokens
|
||||
|
||||
def needs_compaction(self) -> bool:
    """Return True once estimated usage crosses the compaction-threshold fraction."""
    return self.estimate_tokens() >= self._max_history_tokens * self._compaction_threshold
|
||||
|
||||
@@ -244,39 +353,7 @@ class NodeConversation:
|
||||
|
||||
def _try_extract_key(self, content: str, key: str) -> str | None:
    """Try 4 strategies to extract a key's value from message content."""
    # Delegates to the module-level helper (shared with ContextHandoff,
    # which imports it from this module).
    return _try_extract_key(content, key)
|
||||
|
||||
# --- Lifecycle ---------------------------------------------------------
|
||||
|
||||
@@ -330,6 +407,7 @@ class NodeConversation:
|
||||
await self._store.write_cursor({"next_seq": self._next_seq})
|
||||
|
||||
self._messages = [summary_msg] + recent_messages
|
||||
self._last_api_input_tokens = None # reset; next LLM call will recalibrate
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Remove all messages, keep system prompt, preserve ``_next_seq``."""
|
||||
@@ -337,6 +415,7 @@ class NodeConversation:
|
||||
await self._store.delete_parts_before(self._next_seq)
|
||||
await self._store.write_cursor({"next_seq": self._next_seq})
|
||||
self._messages.clear()
|
||||
self._last_api_input_tokens = None
|
||||
|
||||
def export_summary(self) -> str:
|
||||
"""Structured summary with [STATS], [CONFIG], [RECENT_MESSAGES] sections."""
|
||||
|
||||
@@ -11,7 +11,6 @@ our edges can be created dynamically by a Builder agent based on the goal.
|
||||
|
||||
Edge Types:
|
||||
- always: Always traverse after source completes
|
||||
- on_success: Traverse only if source succeeds
|
||||
- on_failure: Traverse only if source fails
|
||||
- conditional: Traverse based on expression evaluation (SAFE SUBSET ONLY)
|
||||
@@ -609,4 +608,40 @@ class GraphSpec(BaseModel):
|
||||
continue
|
||||
errors.append(f"Node '{node.id}' is unreachable from entry")
|
||||
|
||||
# Client-facing fan-out validation
|
||||
fan_outs = self.detect_fan_out_nodes()
|
||||
for source_id, targets in fan_outs.items():
|
||||
client_facing_targets = [
|
||||
t
|
||||
for t in targets
|
||||
if self.get_node(t) and getattr(self.get_node(t), "client_facing", False)
|
||||
]
|
||||
if len(client_facing_targets) > 1:
|
||||
errors.append(
|
||||
f"Fan-out from '{source_id}' has multiple client-facing nodes: "
|
||||
f"{client_facing_targets}. Only one branch may be client-facing."
|
||||
)
|
||||
|
||||
# Output key overlap on parallel event_loop nodes
|
||||
for source_id, targets in fan_outs.items():
|
||||
event_loop_targets = [
|
||||
t
|
||||
for t in targets
|
||||
if self.get_node(t) and getattr(self.get_node(t), "node_type", "") == "event_loop"
|
||||
]
|
||||
if len(event_loop_targets) > 1:
|
||||
seen_keys: dict[str, str] = {}
|
||||
for node_id in event_loop_targets:
|
||||
node = self.get_node(node_id)
|
||||
for key in getattr(node, "output_keys", []):
|
||||
if key in seen_keys:
|
||||
errors.append(
|
||||
f"Fan-out from '{source_id}': event_loop nodes "
|
||||
f"'{seen_keys[key]}' and '{node_id}' both write to "
|
||||
f"output_key '{key}'. Parallel event_loop nodes must "
|
||||
f"have disjoint output_keys to prevent last-wins data loss."
|
||||
)
|
||||
else:
|
||||
seen_keys[key] = node_id
|
||||
|
||||
return errors
|
||||
|
||||
@@ -0,0 +1,879 @@
|
||||
"""EventLoopNode: Multi-turn LLM streaming loop with tool execution and judge evaluation.
|
||||
|
||||
Implements NodeProtocol and runs a streaming event loop:
|
||||
1. Calls LLMProvider.stream() to get streaming events
|
||||
2. Processes text deltas, tool calls, and finish events
|
||||
3. Executes tools and feeds results back to the conversation
|
||||
4. Uses judge evaluation (or implicit stop-reason) to decide loop termination
|
||||
5. Publishes lifecycle events to EventBus
|
||||
6. Persists conversation and outputs via write-through to ConversationStore
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from framework.graph.conversation import ConversationStore, NodeConversation
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult
|
||||
from framework.llm.provider import Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Judge protocol (simple 3-action interface for event loop evaluation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class JudgeVerdict:
    """Result of judge evaluation for the event loop."""

    # ACCEPT terminates the loop; RETRY continues it internally;
    # ESCALATE presumably defers to an outer handler -- confirm against
    # the EventLoopNode run loop.
    action: Literal["ACCEPT", "RETRY", "ESCALATE"]
    # Optional guidance attached to the verdict (e.g. what to fix on RETRY).
    feedback: str = ""
|
||||
|
||||
|
||||
@runtime_checkable
class JudgeProtocol(Protocol):
    """Protocol for event-loop judges.

    Implementations evaluate the current state of the event loop and
    decide whether to accept the output, retry with feedback, or escalate.
    ``context`` is a free-form dict snapshot of loop state assembled by
    the caller.
    """

    async def evaluate(self, context: dict[str, Any]) -> JudgeVerdict: ...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class LoopConfig:
    """Configuration for the event loop."""

    # Hard cap on loop iterations before forced termination.
    max_iterations: int = 50
    # Upper bound on tool executions within a single LLM turn.
    max_tool_calls_per_turn: int = 10
    # Run the judge every N turns (1 = every turn).
    judge_every_n_turns: int = 1
    # Turns without progress before the loop is treated as stalled
    # -- TODO confirm exact semantics against EventLoopNode's loop body.
    stall_detection_threshold: int = 3
    # Presumably the conversation-history token budget (name matches
    # NodeConversation's max_history_tokens) -- verify at the call site.
    max_history_tokens: int = 32_000
    # Prefix for conversation-store keys -- presumably namespaces
    # per-node persisted state; verify against store usage.
    store_prefix: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Output accumulator with write-through persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class OutputAccumulator:
    """Accumulates output key-value pairs with optional write-through persistence.

    Values live in memory; when a ConversationStore is attached, every
    set() is also written through to the store's cursor data under the
    "outputs" key so the accumulator survives a crash.
    """

    values: dict[str, Any] = field(default_factory=dict)
    store: ConversationStore | None = None

    async def set(self, key: str, value: Any) -> None:
        """Set a key-value pair, persisting immediately if store is available."""
        self.values[key] = value
        if not self.store:
            return
        # Read-modify-write of the cursor's "outputs" sub-dict.
        cursor = await self.store.read_cursor() or {}
        cursor.setdefault("outputs", {})[key] = value
        await self.store.write_cursor(cursor)

    def get(self, key: str) -> Any | None:
        """Get a value by key, or None if not present."""
        return self.values.get(key)

    def to_dict(self) -> dict[str, Any]:
        """Return a copy of all accumulated values."""
        return dict(self.values)

    def has_all_keys(self, required: list[str]) -> bool:
        """Check if all required keys have been set (non-None)."""
        return all(self.values.get(name) is not None for name in required)

    @classmethod
    async def restore(cls, store: ConversationStore) -> OutputAccumulator:
        """Restore an OutputAccumulator from a store's cursor data."""
        cursor = await store.read_cursor()
        saved = cursor.get("outputs", {}) if cursor else {}
        return cls(values=saved, store=store)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EventLoopNode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EventLoopNode(NodeProtocol):
|
||||
"""Multi-turn LLM streaming loop with tool execution and judge evaluation.
|
||||
|
||||
Lifecycle:
|
||||
1. Try to restore from durable state (crash recovery)
|
||||
2. If no prior state, init from NodeSpec.system_prompt + input_keys
|
||||
3. Loop: drain injection queue -> stream LLM -> execute tools -> judge
|
||||
(each add_* and set_output writes through to store immediately)
|
||||
4. Publish events to EventBus at each stage
|
||||
5. Write cursor after each iteration
|
||||
6. Terminate when judge returns ACCEPT (or max iterations)
|
||||
7. Build output dict from OutputAccumulator
|
||||
|
||||
Always returns NodeResult with retryable=False semantics. The executor
|
||||
must NOT retry event loop nodes -- retry is handled internally by the
|
||||
judge (RETRY action continues the loop). See WP-7 enforcement.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        event_bus: EventBus | None = None,
        judge: JudgeProtocol | None = None,
        config: LoopConfig | None = None,
        tool_executor: Callable[[ToolUse], ToolResult | Awaitable[ToolResult]] | None = None,
        conversation_store: ConversationStore | None = None,
    ) -> None:
        """Initialize the event loop node.

        Args:
            event_bus: Optional bus for publishing lifecycle/stream events.
            judge: Optional judge; when None an implicit accept/retry policy
                is used (see _evaluate).
            config: Loop tuning knobs; defaults to LoopConfig().
            tool_executor: Callable that executes tool calls; may be sync or
                return an awaitable.
            conversation_store: Optional store enabling write-through
                persistence and crash recovery.
        """
        self._event_bus = event_bus
        self._judge = judge
        self._config = config or LoopConfig()
        self._tool_executor = tool_executor
        self._conversation_store = conversation_store
        # Pending external events; drained into the conversation at the
        # start of each loop iteration.
        self._injection_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
|
||||
def validate_input(self, ctx: NodeContext) -> list[str]:
|
||||
"""Validate hard requirements only.
|
||||
|
||||
Event loop nodes are LLM-powered and can reason about flexible input,
|
||||
so input_keys are treated as hints — not strict requirements.
|
||||
Only the LLM provider is a hard dependency.
|
||||
"""
|
||||
errors = []
|
||||
if ctx.llm is None:
|
||||
errors.append("LLM provider is required for EventLoopNode")
|
||||
return errors
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Public API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
    async def execute(self, ctx: NodeContext) -> NodeResult:
        """Run the event loop.

        Orchestrates restore -> iterate (inject, compact, LLM turn, stall
        check, checkpoint, judge) -> terminate. Returns a NodeResult whose
        output is the accumulated set_output values; success is True on
        judge ACCEPT or pause, False on stall, ESCALATE, or max iterations.
        """
        start_time = time.time()
        total_input_tokens = 0
        total_output_tokens = 0
        stream_id = ctx.node_id
        node_id = ctx.node_id

        # 1. Guard: LLM required
        if ctx.llm is None:
            return NodeResult(success=False, error="LLM provider not available")

        # 2. Restore or create new conversation + accumulator
        conversation, accumulator, start_iteration = await self._restore(ctx)
        if conversation is None:
            conversation = NodeConversation(
                system_prompt=ctx.node_spec.system_prompt or "",
                max_history_tokens=self._config.max_history_tokens,
                output_keys=ctx.node_spec.output_keys or None,
                store=self._conversation_store,
            )
            accumulator = OutputAccumulator(store=self._conversation_store)
            start_iteration = 0

            # Add initial user message from input data
            # NOTE(review): done only for a fresh conversation — a restored
            # one already contains its initial message. Confirm intended.
            initial_message = self._build_initial_message(ctx)
            if initial_message:
                await conversation.add_user_message(initial_message)

        # 3. Build tool list: node tools + synthetic set_output tool
        tools = list(ctx.available_tools)
        set_output_tool = self._build_set_output_tool(ctx.node_spec.output_keys)
        if set_output_tool:
            tools.append(set_output_tool)

        # 4. Publish loop started
        await self._publish_loop_started(stream_id, node_id)

        # 5. Stall detection state: sliding window of recent assistant texts
        recent_responses: list[str] = []

        # 6. Main loop
        for iteration in range(start_iteration, self._config.max_iterations):
            # 6a. Check pause — a pause exits successfully with partial output
            if await self._check_pause(ctx, conversation, iteration):
                latency_ms = int((time.time() - start_time) * 1000)
                return NodeResult(
                    success=True,
                    output=accumulator.to_dict(),
                    tokens_used=total_input_tokens + total_output_tokens,
                    latency_ms=latency_ms,
                )

            # 6b. Drain injection queue
            await self._drain_injection_queue(conversation)

            # 6c. Publish iteration event
            await self._publish_iteration(stream_id, node_id, iteration)

            # 6d. Pre-turn compaction check (tiered)
            if conversation.needs_compaction():
                await self._compact_tiered(ctx, conversation)

            # 6e. Run single LLM turn
            assistant_text, tool_results_list, turn_tokens = await self._run_single_turn(
                ctx, conversation, tools, iteration, accumulator
            )
            total_input_tokens += turn_tokens.get("input", 0)
            total_output_tokens += turn_tokens.get("output", 0)

            # 6e'. Feed actual API token count back for accurate estimation
            turn_input = turn_tokens.get("input", 0)
            if turn_input > 0:
                conversation.update_token_count(turn_input)

            # 6e''. Post-turn compaction check (catches tool-result bloat)
            if conversation.needs_compaction():
                await self._compact_tiered(ctx, conversation)

            # 6f. Stall detection
            recent_responses.append(assistant_text)
            if len(recent_responses) > self._config.stall_detection_threshold:
                recent_responses.pop(0)
            if self._is_stalled(recent_responses):
                await self._publish_stalled(stream_id, node_id)
                latency_ms = int((time.time() - start_time) * 1000)
                return NodeResult(
                    success=False,
                    error=(
                        f"Node stalled: {self._config.stall_detection_threshold} "
                        "consecutive identical responses"
                    ),
                    output=accumulator.to_dict(),
                    tokens_used=total_input_tokens + total_output_tokens,
                    latency_ms=latency_ms,
                )

            # 6g. Write cursor checkpoint
            await self._write_cursor(ctx, conversation, accumulator, iteration)

            # 6h. Judge evaluation
            should_judge = (
                (iteration + 1) % self._config.judge_every_n_turns == 0
                or not tool_results_list  # no tool calls = natural stop
            )

            if should_judge:
                verdict = await self._evaluate(
                    ctx,
                    conversation,
                    accumulator,
                    assistant_text,
                    tool_results_list,
                    iteration,
                )

                if verdict.action == "ACCEPT":
                    # Check for missing output keys — an explicit judge may
                    # accept before all declared outputs were set.
                    missing = self._get_missing_output_keys(accumulator, ctx.node_spec.output_keys)
                    if missing and self._judge is not None:
                        hint = (
                            f"Missing required output keys: {missing}. "
                            "Use set_output to provide them."
                        )
                        await conversation.add_user_message(hint)
                        continue

                    # Write outputs to shared memory
                    for key, value in accumulator.to_dict().items():
                        ctx.memory.write(key, value, validate=False)

                    await self._publish_loop_completed(stream_id, node_id, iteration + 1)
                    latency_ms = int((time.time() - start_time) * 1000)
                    return NodeResult(
                        success=True,
                        output=accumulator.to_dict(),
                        tokens_used=total_input_tokens + total_output_tokens,
                        latency_ms=latency_ms,
                    )

                elif verdict.action == "ESCALATE":
                    await self._publish_loop_completed(stream_id, node_id, iteration + 1)
                    latency_ms = int((time.time() - start_time) * 1000)
                    return NodeResult(
                        success=False,
                        error=f"Judge escalated: {verdict.feedback}",
                        output=accumulator.to_dict(),
                        tokens_used=total_input_tokens + total_output_tokens,
                        latency_ms=latency_ms,
                    )

                elif verdict.action == "RETRY":
                    if verdict.feedback:
                        await conversation.add_user_message(f"[Judge feedback]: {verdict.feedback}")
                    continue

        # 7. Max iterations exhausted
        await self._publish_loop_completed(stream_id, node_id, self._config.max_iterations)
        latency_ms = int((time.time() - start_time) * 1000)
        return NodeResult(
            success=False,
            error=(f"Max iterations ({self._config.max_iterations}) reached without acceptance"),
            output=accumulator.to_dict(),
            tokens_used=total_input_tokens + total_output_tokens,
            latency_ms=latency_ms,
        )
|
||||
|
||||
    async def inject_event(self, content: str) -> None:
        """Inject an external event into the running loop.

        The content becomes a user message prepended to the next iteration
        (drained by _drain_injection_queue at the start of each iteration).

        Safe to call from any task on the same event loop. Note that
        asyncio.Queue is NOT thread-safe — to inject from another thread,
        schedule this coroutine with asyncio.run_coroutine_threadsafe.
        """
        await self._injection_queue.put(content)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Single LLM turn with caller-managed tool orchestration
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
    async def _run_single_turn(
        self,
        ctx: NodeContext,
        conversation: NodeConversation,
        tools: list[Tool],
        iteration: int,
        accumulator: OutputAccumulator,
    ) -> tuple[str, list[dict], dict[str, int]]:
        """Run a single LLM turn with streaming and tool execution.

        Streams the model, executes any tool calls it makes (recording each
        result in the conversation), then re-streams with the updated
        conversation until the model produces a response with no tool calls.

        Returns (assistant_text, tool_results, token_counts).

        NOTE(review): ``tool_results`` is rebuilt on each pass of the inner
        loop and the only return path is the no-tool-calls exit, which
        returns ``[]`` — so callers always receive an empty tool_results
        list. Confirm whether the last batch of results was meant to be
        surfaced to the judge.
        """
        stream_id = ctx.node_id
        node_id = ctx.node_id
        token_counts: dict[str, int] = {"input": 0, "output": 0}
        tool_call_count = 0
        final_text = ""

        # Inner tool loop: stream may produce tool calls requiring re-invocation
        while True:
            # Pre-send guard: if context is at or over budget, compact before
            # calling the LLM — prevents API context-length errors.
            if conversation.usage_ratio() >= 1.0:
                logger.warning(
                    "Pre-send guard: context at %.0f%% of budget, compacting",
                    conversation.usage_ratio() * 100,
                )
                await self._compact_tiered(ctx, conversation)

            messages = conversation.to_llm_messages()
            accumulated_text = ""
            tool_calls: list[ToolCallEvent] = []

            # Stream LLM response
            async for event in ctx.llm.stream(
                messages=messages,
                system=conversation.system_prompt,
                tools=tools if tools else None,
                max_tokens=ctx.max_tokens,
            ):
                if isinstance(event, TextDeltaEvent):
                    # snapshot is the full text so far; keep the latest one
                    accumulated_text = event.snapshot
                    await self._publish_text_delta(
                        stream_id, node_id, event.content, event.snapshot, ctx
                    )

                elif isinstance(event, ToolCallEvent):
                    tool_calls.append(event)

                elif isinstance(event, FinishEvent):
                    token_counts["input"] += event.input_tokens
                    token_counts["output"] += event.output_tokens

                elif isinstance(event, StreamErrorEvent):
                    # Unrecoverable errors abort the turn; recoverable ones
                    # are logged and the stream is allowed to continue.
                    if not event.recoverable:
                        raise RuntimeError(f"Stream error: {event.error}")
                    logger.warning(f"Recoverable stream error: {event.error}")

            final_text = accumulated_text

            # Record assistant message (write-through via conversation store)
            tc_dicts = None
            if tool_calls:
                # Serialize tool calls in OpenAI-style function-call shape
                tc_dicts = [
                    {
                        "id": tc.tool_use_id,
                        "type": "function",
                        "function": {
                            "name": tc.tool_name,
                            "arguments": json.dumps(tc.tool_input),
                        },
                    }
                    for tc in tool_calls
                ]
            await conversation.add_assistant_message(
                content=accumulated_text,
                tool_calls=tc_dicts,
            )

            # If no tool calls, turn is complete
            if not tool_calls:
                return final_text, [], token_counts

            # Execute tool calls
            tool_results: list[dict] = []
            for tc in tool_calls:
                tool_call_count += 1
                # Cap counts calls across ALL inner passes of this turn
                if tool_call_count > self._config.max_tool_calls_per_turn:
                    logger.warning(
                        f"Max tool calls per turn ({self._config.max_tool_calls_per_turn}) exceeded"
                    )
                    break

                # Publish tool call started
                await self._publish_tool_started(
                    stream_id, node_id, tc.tool_use_id, tc.tool_name, tc.tool_input
                )

                # Handle set_output synthetic tool
                if tc.tool_name == "set_output":
                    result = self._handle_set_output(tc.tool_input, ctx.node_spec.output_keys)
                    # Re-wrap to attach the real tool_use_id
                    result = ToolResult(
                        tool_use_id=tc.tool_use_id,
                        content=result.content,
                        is_error=result.is_error,
                    )
                    # Async write-through for set_output
                    if not result.is_error:
                        await accumulator.set(tc.tool_input["key"], tc.tool_input["value"])
                else:
                    # Execute real tool
                    result = await self._execute_tool(tc)

                # Record tool result in conversation (write-through)
                await conversation.add_tool_result(
                    tool_use_id=tc.tool_use_id,
                    content=result.content,
                    is_error=result.is_error,
                )
                tool_results.append(
                    {
                        "tool_use_id": tc.tool_use_id,
                        "tool_name": tc.tool_name,
                        "content": result.content,
                        "is_error": result.is_error,
                    }
                )

                # Publish tool call completed
                await self._publish_tool_completed(
                    stream_id,
                    node_id,
                    tc.tool_use_id,
                    tc.tool_name,
                    result.content,
                    result.is_error,
                )

            # Tool calls processed -- loop back to stream with updated conversation
|
||||
# -------------------------------------------------------------------
|
||||
# set_output synthetic tool
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def _build_set_output_tool(self, output_keys: list[str] | None) -> Tool | None:
|
||||
"""Build the synthetic set_output tool for explicit output declaration."""
|
||||
if not output_keys:
|
||||
return None
|
||||
return Tool(
|
||||
name="set_output",
|
||||
description=(
|
||||
"Set an output value for this node. Call once per output key. "
|
||||
f"Valid keys: {output_keys}"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": f"Output key. Must be one of: {output_keys}",
|
||||
"enum": output_keys,
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "The output value to store.",
|
||||
},
|
||||
},
|
||||
"required": ["key", "value"],
|
||||
},
|
||||
)
|
||||
|
||||
def _handle_set_output(
|
||||
self,
|
||||
tool_input: dict[str, Any],
|
||||
output_keys: list[str] | None,
|
||||
) -> ToolResult:
|
||||
"""Handle set_output tool call. Returns ToolResult (sync)."""
|
||||
key = tool_input.get("key", "")
|
||||
valid_keys = output_keys or []
|
||||
|
||||
if key not in valid_keys:
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=f"Invalid output key '{key}'. Valid keys: {valid_keys}",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=f"Output '{key}' set successfully.",
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Judge evaluation
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
    async def _evaluate(
        self,
        ctx: NodeContext,
        conversation: NodeConversation,
        accumulator: OutputAccumulator,
        assistant_text: str,
        tool_results: list[dict],
        iteration: int,
    ) -> JudgeVerdict:
        """Evaluate the current state using judge or implicit logic.

        With an explicit judge, the full turn context (text, tool results,
        accumulated outputs, missing keys, conversation summary) is handed
        to it. Without one, an implicit policy applies: accept once the
        model made no tool calls and all declared output keys are set,
        otherwise retry.
        """
        if self._judge is not None:
            context = {
                "assistant_text": assistant_text,
                "tool_calls": tool_results,
                "output_accumulator": accumulator.to_dict(),
                "iteration": iteration,
                "conversation_summary": conversation.export_summary(),
                "output_keys": ctx.node_spec.output_keys,
                "missing_keys": self._get_missing_output_keys(
                    accumulator, ctx.node_spec.output_keys
                ),
            }
            return await self._judge.evaluate(context)

        # Implicit judge: accept when no tool calls and all output keys present
        if not tool_results:
            missing = self._get_missing_output_keys(accumulator, ctx.node_spec.output_keys)
            if not missing:
                return JudgeVerdict(action="ACCEPT")
            else:
                # Outputs incomplete — steer the model toward set_output
                return JudgeVerdict(
                    action="RETRY",
                    feedback=(
                        f"Missing output keys: {missing}. Use set_output tool to provide them."
                    ),
                )

        # Tool calls were made -- continue loop
        return JudgeVerdict(action="RETRY", feedback="")
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def _build_initial_message(self, ctx: NodeContext) -> str:
|
||||
"""Build the initial user message from input data and memory.
|
||||
|
||||
Includes ALL input_data (not just declared input_keys) so that
|
||||
upstream handoff data flows through regardless of key naming.
|
||||
Declared input_keys are also checked in shared memory as fallback.
|
||||
"""
|
||||
parts = []
|
||||
seen: set[str] = set()
|
||||
# Include everything from input_data (flexible handoff)
|
||||
for key, value in ctx.input_data.items():
|
||||
if value is not None:
|
||||
parts.append(f"{key}: {value}")
|
||||
seen.add(key)
|
||||
# Fallback: check memory for declared input_keys not already covered
|
||||
for key in ctx.node_spec.input_keys:
|
||||
if key not in seen:
|
||||
value = ctx.memory.read(key)
|
||||
if value is not None:
|
||||
parts.append(f"{key}: {value}")
|
||||
if ctx.goal_context:
|
||||
parts.append(f"\nGoal: {ctx.goal_context}")
|
||||
return "\n".join(parts) if parts else "Begin."
|
||||
|
||||
def _get_missing_output_keys(
|
||||
self,
|
||||
accumulator: OutputAccumulator,
|
||||
output_keys: list[str] | None,
|
||||
) -> list[str]:
|
||||
"""Return output keys that have not been set yet."""
|
||||
if not output_keys:
|
||||
return []
|
||||
return [k for k in output_keys if accumulator.get(k) is None]
|
||||
|
||||
def _is_stalled(self, recent_responses: list[str]) -> bool:
|
||||
"""Detect stall: N consecutive identical non-empty responses."""
|
||||
if len(recent_responses) < self._config.stall_detection_threshold:
|
||||
return False
|
||||
if not recent_responses[0]:
|
||||
return False
|
||||
return all(r == recent_responses[0] for r in recent_responses)
|
||||
|
||||
async def _execute_tool(self, tc: ToolCallEvent) -> ToolResult:
|
||||
"""Execute a tool call, handling both sync and async executors."""
|
||||
if self._tool_executor is None:
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=f"No tool executor configured for '{tc.tool_name}'",
|
||||
is_error=True,
|
||||
)
|
||||
tool_use = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
|
||||
result = self._tool_executor(tool_use)
|
||||
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
    async def _compact_tiered(
        self,
        ctx: NodeContext,
        conversation: NodeConversation,
    ) -> None:
        """Run compaction with aggressiveness scaled to usage level.

        | Usage           | Strategy                                    |
        |-----------------|---------------------------------------------|
        | 80-100%         | Normal: LLM summary, keep 4 recent messages |
        | 100-120%        | Aggressive: LLM summary, keep 2 recent      |
        | >= 120%         | Emergency: static summary, keep 1 recent    |

        NOTE(review): the 80% lower bound of the normal tier is implied by
        the caller's needs_compaction() gate, not enforced here — confirm.
        """
        ratio = conversation.usage_ratio()

        if ratio >= 1.2:
            # Emergency -- don't risk another LLM call on a bloated context
            logger.warning("Emergency compaction triggered (usage %.0f%%)", ratio * 100)
            await conversation.compact(
                "Previous conversation context (emergency compaction).",
                keep_recent=1,
            )
        elif ratio >= 1.0:
            logger.info("Aggressive compaction triggered (usage %.0f%%)", ratio * 100)
            summary = await self._generate_compaction_summary(ctx, conversation)
            await conversation.compact(summary, keep_recent=2)
        else:
            # Normal tier: LLM summary, larger recent-message tail
            summary = await self._generate_compaction_summary(ctx, conversation)
            await conversation.compact(summary, keep_recent=4)
|
||||
|
||||
async def _generate_compaction_summary(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
) -> str:
|
||||
"""Use LLM to generate a conversation summary for compaction."""
|
||||
messages_text = "\n".join(
|
||||
f"[{m.role}]: {m.content[:200]}" for m in conversation.messages[-10:]
|
||||
)
|
||||
prompt = (
|
||||
"Summarize this conversation so far in 2-3 sentences, "
|
||||
"preserving key decisions and results:\n\n"
|
||||
f"{messages_text}"
|
||||
)
|
||||
try:
|
||||
response = ctx.llm.complete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
system="Summarize conversations concisely.",
|
||||
max_tokens=300,
|
||||
)
|
||||
return response.content
|
||||
except Exception as e:
|
||||
logger.warning(f"Compaction summary generation failed: {e}")
|
||||
return "Previous conversation context (summary unavailable)."
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Persistence: restore, cursor, injection, pause
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
    async def _restore(
        self,
        ctx: NodeContext,
    ) -> tuple[NodeConversation | None, OutputAccumulator | None, int]:
        """Attempt to restore from a previous checkpoint.

        Returns (conversation, accumulator, start_iteration); (None, None, 0)
        when no store is configured or no prior conversation exists.
        """
        if self._conversation_store is None:
            return None, None, 0

        conversation = await NodeConversation.restore(self._conversation_store)
        if conversation is None:
            return None, None, 0

        accumulator = await OutputAccumulator.restore(self._conversation_store)

        cursor = await self._conversation_store.read_cursor()
        # Resume at the iteration AFTER the last checkpointed one.
        start_iteration = cursor.get("iteration", 0) + 1 if cursor else 0

        logger.info(
            f"Restored event loop: iteration={start_iteration}, "
            f"messages={conversation.message_count}, "
            f"outputs={list(accumulator.values.keys())}"
        )
        return conversation, accumulator, start_iteration
|
||||
|
||||
async def _write_cursor(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
accumulator: OutputAccumulator,
|
||||
iteration: int,
|
||||
) -> None:
|
||||
"""Write checkpoint cursor for crash recovery."""
|
||||
if self._conversation_store:
|
||||
cursor = await self._conversation_store.read_cursor() or {}
|
||||
cursor.update(
|
||||
{
|
||||
"iteration": iteration,
|
||||
"node_id": ctx.node_id,
|
||||
"next_seq": conversation.next_seq,
|
||||
"outputs": accumulator.to_dict(),
|
||||
}
|
||||
)
|
||||
await self._conversation_store.write_cursor(cursor)
|
||||
|
||||
async def _drain_injection_queue(self, conversation: NodeConversation) -> int:
|
||||
"""Drain all pending injected events as user messages. Returns count."""
|
||||
count = 0
|
||||
while not self._injection_queue.empty():
|
||||
try:
|
||||
content = self._injection_queue.get_nowait()
|
||||
await conversation.add_user_message(f"[External event]: {content}")
|
||||
count += 1
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
return count
|
||||
|
||||
async def _check_pause(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
iteration: int,
|
||||
) -> bool:
|
||||
"""Check if pause has been requested. Returns True if paused."""
|
||||
pause_requested = ctx.input_data.get("pause_requested", False)
|
||||
if not pause_requested:
|
||||
pause_requested = ctx.memory.read("pause_requested") or False
|
||||
if pause_requested:
|
||||
logger.info(f"Pause requested at iteration {iteration}")
|
||||
return True
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# EventBus publishing helpers
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
async def _publish_loop_started(self, stream_id: str, node_id: str) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_node_loop_started(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
max_iterations=self._config.max_iterations,
|
||||
)
|
||||
|
||||
async def _publish_iteration(self, stream_id: str, node_id: str, iteration: int) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_node_loop_iteration(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
iteration=iteration,
|
||||
)
|
||||
|
||||
async def _publish_loop_completed(self, stream_id: str, node_id: str, iterations: int) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_node_loop_completed(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
async def _publish_stalled(self, stream_id: str, node_id: str) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_node_stalled(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
reason="Consecutive identical responses detected",
|
||||
)
|
||||
|
||||
async def _publish_text_delta(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
snapshot: str,
|
||||
ctx: NodeContext,
|
||||
) -> None:
|
||||
if self._event_bus:
|
||||
if ctx.node_spec.client_facing:
|
||||
await self._event_bus.emit_client_output_delta(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
content=content,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
else:
|
||||
await self._event_bus.emit_llm_text_delta(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
content=content,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
async def _publish_tool_started(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
tool_input: dict,
|
||||
) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_tool_call_started(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
tool_use_id=tool_use_id,
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
)
|
||||
|
||||
async def _publish_tool_completed(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
result: str,
|
||||
is_error: bool,
|
||||
) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_tool_call_completed(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
tool_use_id=tool_use_id,
|
||||
tool_name=tool_name,
|
||||
result=result,
|
||||
is_error=is_error,
|
||||
)
|
||||
@@ -11,6 +11,7 @@ The executor:
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
@@ -380,6 +381,15 @@ class GraphExecutor:
|
||||
# [CORRECTED] Use node_spec.max_retries instead of hardcoded 3
|
||||
max_retries = getattr(node_spec, "max_retries", 3)
|
||||
|
||||
# Event loop nodes handle retry internally via judge —
|
||||
# executor retry is catastrophic (retry multiplication)
|
||||
if node_spec.node_type == "event_loop" and max_retries > 0:
|
||||
self.logger.warning(
|
||||
f"EventLoopNode '{node_spec.id}' has max_retries={max_retries}. "
|
||||
"Overriding to 0 — event loop nodes handle retry internally via judge."
|
||||
)
|
||||
max_retries = 0
|
||||
|
||||
if node_retry_counts[current_node_id] < max_retries:
|
||||
# Retry - don't increment steps for retries
|
||||
steps -= 1
|
||||
@@ -658,7 +668,15 @@ class GraphExecutor:
|
||||
)
|
||||
|
||||
# Valid node types - no ambiguous "llm" type allowed
|
||||
VALID_NODE_TYPES = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
|
||||
VALID_NODE_TYPES = {
|
||||
"llm_tool_use",
|
||||
"llm_generate",
|
||||
"router",
|
||||
"function",
|
||||
"human_input",
|
||||
"event_loop",
|
||||
}
|
||||
DEPRECATED_NODE_TYPES = {"llm_tool_use": "event_loop", "llm_generate": "event_loop"}
|
||||
|
||||
def _get_node_implementation(
|
||||
self, node_spec: NodeSpec, cleanup_llm_model: str | None = None
|
||||
@@ -676,6 +694,17 @@ class GraphExecutor:
|
||||
f"Use 'llm_tool_use' for nodes that call tools, 'llm_generate' for text generation."
|
||||
)
|
||||
|
||||
# Warn on deprecated node types
|
||||
if node_spec.node_type in self.DEPRECATED_NODE_TYPES:
|
||||
replacement = self.DEPRECATED_NODE_TYPES[node_spec.node_type]
|
||||
warnings.warn(
|
||||
f"Node type '{node_spec.node_type}' is deprecated. "
|
||||
f"Use '{replacement}' instead. "
|
||||
f"Node: '{node_spec.id}'",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Create based on type
|
||||
if node_spec.node_type == "llm_tool_use":
|
||||
if not node_spec.tools:
|
||||
@@ -713,6 +742,13 @@ class GraphExecutor:
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
)
|
||||
|
||||
if node_spec.node_type == "event_loop":
|
||||
# Event loop nodes must be pre-registered (like function nodes)
|
||||
raise RuntimeError(
|
||||
f"EventLoopNode '{node_spec.id}' not found in registry. "
|
||||
"Register it with executor.register_node() before execution."
|
||||
)
|
||||
|
||||
# Should never reach here due to validation above
|
||||
raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")
|
||||
|
||||
@@ -909,6 +945,17 @@ class GraphExecutor:
|
||||
branch.status = "failed"
|
||||
branch.error = f"Node {branch.node_id} not found in graph"
|
||||
return branch, RuntimeError(branch.error)
|
||||
|
||||
effective_max_retries = node_spec.max_retries
|
||||
if node_spec.node_type == "event_loop":
|
||||
if effective_max_retries > 1:
|
||||
self.logger.warning(
|
||||
f"EventLoopNode '{node_spec.id}' has "
|
||||
f"max_retries={effective_max_retries}. Overriding "
|
||||
"to 1 — event loop nodes handle retry internally."
|
||||
)
|
||||
effective_max_retries = 1
|
||||
|
||||
branch.status = "running"
|
||||
|
||||
try:
|
||||
@@ -942,7 +989,7 @@ class GraphExecutor:
|
||||
|
||||
# Execute with retries
|
||||
last_result = None
|
||||
for attempt in range(node_spec.max_retries):
|
||||
for attempt in range(effective_max_retries):
|
||||
branch.retry_count = attempt
|
||||
|
||||
# Build context for this branch
|
||||
@@ -970,7 +1017,7 @@ class GraphExecutor:
|
||||
|
||||
self.logger.warning(
|
||||
f" ↻ Branch {node_spec.name}: "
|
||||
f"retry {attempt + 1}/{node_spec.max_retries}"
|
||||
f"retry {attempt + 1}/{effective_max_retries}"
|
||||
)
|
||||
|
||||
# All retries exhausted
|
||||
@@ -979,7 +1026,7 @@ class GraphExecutor:
|
||||
branch.result = last_result
|
||||
self.logger.error(
|
||||
f" ✗ Branch {node_spec.name}: "
|
||||
f"failed after {node_spec.max_retries} attempts"
|
||||
f"failed after {effective_max_retries} attempts"
|
||||
)
|
||||
return branch, last_result
|
||||
|
||||
|
||||
@@ -153,7 +153,10 @@ class NodeSpec(BaseModel):
|
||||
# Node behavior type
|
||||
node_type: str = Field(
|
||||
default="llm_tool_use",
|
||||
description="Type: 'llm_tool_use', 'llm_generate', 'function', 'router', 'human_input'",
|
||||
description=(
|
||||
"Type: 'event_loop', 'function', 'router', 'human_input'. "
|
||||
"Deprecated: 'llm_tool_use', 'llm_generate' (use 'event_loop' instead)."
|
||||
),
|
||||
)
|
||||
|
||||
# Data flow
|
||||
@@ -218,6 +221,12 @@ class NodeSpec(BaseModel):
|
||||
description="Maximum retries when Pydantic validation fails (with feedback to LLM)",
|
||||
)
|
||||
|
||||
# Client-facing behavior
|
||||
client_facing: bool = Field(
|
||||
default=False,
|
||||
description="If True, this node streams output to the end user and can request input.",
|
||||
)
|
||||
|
||||
model_config = {"extra": "allow", "arbitrary_types_allowed": True}
|
||||
|
||||
|
||||
@@ -1348,7 +1357,9 @@ Expected output keys: {output_keys}
|
||||
LLM Response:
|
||||
{raw_response}
|
||||
|
||||
Output ONLY the JSON object, nothing else."""
|
||||
Output ONLY the JSON object, nothing else.
|
||||
If no valid JSON object exists in the response, output exactly: {{"error": "NO_JSON_FOUND"}}
|
||||
Do NOT fabricate data or return empty objects."""
|
||||
|
||||
try:
|
||||
result = cleaner_llm.complete(
|
||||
@@ -1395,6 +1406,14 @@ Output ONLY the JSON object, nothing else."""
|
||||
parsed = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
parsed = json.loads(_fix_unescaped_newlines_in_json(cleaned))
|
||||
|
||||
# Validate LLM didn't return empty or fabricated data
|
||||
if parsed.get("error") == "NO_JSON_FOUND":
|
||||
raise ValueError("Cannot parse JSON from response")
|
||||
if not parsed or parsed == {}:
|
||||
raise ValueError("Cannot parse JSON from response")
|
||||
if all(v is None for v in parsed.values()):
|
||||
raise ValueError("Cannot parse JSON from response")
|
||||
logger.info(" ✓ LLM cleaned JSON output")
|
||||
return parsed
|
||||
|
||||
|
||||
@@ -1,8 +1,31 @@
|
||||
"""LLM provider abstraction."""
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
ReasoningDeltaEvent,
|
||||
ReasoningStartEvent,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse"]
|
||||
__all__ = [
|
||||
"LLMProvider",
|
||||
"LLMResponse",
|
||||
"StreamEvent",
|
||||
"TextDeltaEvent",
|
||||
"TextEndEvent",
|
||||
"ToolCallEvent",
|
||||
"ToolResultEvent",
|
||||
"ReasoningStartEvent",
|
||||
"ReasoningDeltaEvent",
|
||||
"FinishEvent",
|
||||
"StreamErrorEvent",
|
||||
]
|
||||
|
||||
try:
|
||||
from framework.llm.anthropic import AnthropicProvider # noqa: F401
|
||||
|
||||
@@ -7,10 +7,11 @@ Groq, and local models.
|
||||
See: https://docs.litellm.ai/docs/providers
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -23,6 +24,7 @@ except ImportError:
|
||||
RateLimitError = Exception # type: ignore[assignment, misc]
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import StreamEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -425,3 +427,167 @@ class LiteLLMProvider(LLMProvider):
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream a completion via litellm.acompletion(stream=True).
|
||||
|
||||
Yields StreamEvent objects as chunks arrive from the provider.
|
||||
Tool call arguments are accumulated across chunks and yielded as
|
||||
a single ToolCallEvent with fully parsed JSON when complete.
|
||||
|
||||
Empty responses (e.g. Gemini stealth rate-limits that return 200
|
||||
with no content) are retried with exponential backoff, mirroring
|
||||
the retry behaviour of ``_completion_with_rate_limit_retry``.
|
||||
"""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if system:
|
||||
full_messages.append({"role": "system", "content": system})
|
||||
full_messages.extend(messages)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": full_messages,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if tools:
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
|
||||
for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
|
||||
buffered_events: list[StreamEvent] = []
|
||||
accumulated_text = ""
|
||||
tool_calls_acc: dict[int, dict[str, str]] = {}
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
try:
|
||||
response = await litellm.acompletion(**kwargs) # type: ignore[union-attr]
|
||||
|
||||
async for chunk in response:
|
||||
choice = chunk.choices[0] if chunk.choices else None
|
||||
if not choice:
|
||||
continue
|
||||
|
||||
delta = choice.delta
|
||||
|
||||
# --- Text content ---
|
||||
if delta and delta.content:
|
||||
accumulated_text += delta.content
|
||||
buffered_events.append(
|
||||
TextDeltaEvent(
|
||||
content=delta.content,
|
||||
snapshot=accumulated_text,
|
||||
)
|
||||
)
|
||||
|
||||
# --- Tool calls (accumulate across chunks) ---
|
||||
if delta and delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
|
||||
if idx not in tool_calls_acc:
|
||||
tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
|
||||
if tc.id:
|
||||
tool_calls_acc[idx]["id"] = tc.id
|
||||
if tc.function:
|
||||
if tc.function.name:
|
||||
tool_calls_acc[idx]["name"] = tc.function.name
|
||||
if tc.function.arguments:
|
||||
tool_calls_acc[idx]["arguments"] += tc.function.arguments
|
||||
|
||||
# --- Finish ---
|
||||
if choice.finish_reason:
|
||||
for _idx, tc_data in sorted(tool_calls_acc.items()):
|
||||
try:
|
||||
parsed_args = json.loads(tc_data["arguments"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
parsed_args = {"_raw": tc_data.get("arguments", "")}
|
||||
buffered_events.append(
|
||||
ToolCallEvent(
|
||||
tool_use_id=tc_data["id"],
|
||||
tool_name=tc_data["name"],
|
||||
tool_input=parsed_args,
|
||||
)
|
||||
)
|
||||
|
||||
if accumulated_text:
|
||||
buffered_events.append(TextEndEvent(full_text=accumulated_text))
|
||||
|
||||
usage = getattr(chunk, "usage", None)
|
||||
if usage:
|
||||
input_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
output_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
|
||||
buffered_events.append(
|
||||
FinishEvent(
|
||||
stop_reason=choice.finish_reason,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
model=self.model,
|
||||
)
|
||||
)
|
||||
|
||||
# Check whether the stream produced any real content.
|
||||
has_content = accumulated_text or tool_calls_acc
|
||||
if not has_content and attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
|
||||
token_count, token_method = _estimate_tokens(
|
||||
self.model,
|
||||
full_messages,
|
||||
)
|
||||
dump_path = _dump_failed_request(
|
||||
model=self.model,
|
||||
kwargs=kwargs,
|
||||
error_type="empty_stream",
|
||||
attempt=attempt,
|
||||
)
|
||||
logger.warning(
|
||||
f"[stream-retry] {self.model} returned empty stream — "
|
||||
f"~{token_count} tokens ({token_method}). "
|
||||
f"Request dumped to: {dump_path}. "
|
||||
f"Retrying in {wait}s "
|
||||
f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
|
||||
# Success (or final attempt) — flush buffered events.
|
||||
for event in buffered_events:
|
||||
yield event
|
||||
return
|
||||
|
||||
except RateLimitError as e:
|
||||
if attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
|
||||
logger.warning(
|
||||
f"[stream-retry] {self.model} rate limited (429): {e!s}. "
|
||||
f"Retrying in {wait}s "
|
||||
f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
return
|
||||
|
||||
@@ -2,10 +2,16 @@
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
)
|
||||
|
||||
|
||||
class MockLLMProvider(LLMProvider):
|
||||
@@ -175,3 +181,28 @@ class MockLLMProvider(LLMProvider):
|
||||
output_tokens=0,
|
||||
stop_reason="mock_complete",
|
||||
)
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream a mock completion as word-level TextDeltaEvents.
|
||||
|
||||
Splits the mock response into words and yields each as a separate
|
||||
TextDeltaEvent with an accumulating snapshot, exercising the full
|
||||
streaming pipeline without any API calls.
|
||||
"""
|
||||
content = self._generate_mock_response(system=system, json_mode=False)
|
||||
words = content.split(" ")
|
||||
accumulated = ""
|
||||
|
||||
for i, word in enumerate(words):
|
||||
chunk = word if i == 0 else " " + word
|
||||
accumulated += chunk
|
||||
yield TextDeltaEvent(content=chunk, snapshot=accumulated)
|
||||
|
||||
yield TextEndEvent(full_text=accumulated)
|
||||
yield FinishEvent(stop_reason="mock_complete", model=self.model)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""LLM Provider abstraction for pluggable LLM backends."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -108,3 +108,45 @@ class LLMProvider(ABC):
|
||||
Final LLMResponse after tool use completes
|
||||
"""
|
||||
pass
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator["StreamEvent"]:
|
||||
"""
|
||||
Stream a completion as an async iterator of StreamEvents.
|
||||
|
||||
Default implementation wraps complete() with synthetic events.
|
||||
Subclasses SHOULD override for true streaming.
|
||||
|
||||
Tool orchestration is the CALLER's responsibility:
|
||||
- Caller detects ToolCallEvent, executes tool, adds result
|
||||
to messages, calls stream() again.
|
||||
"""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
)
|
||||
|
||||
response = self.complete(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
yield TextDeltaEvent(content=response.content, snapshot=response.content)
|
||||
yield TextEndEvent(full_text=response.content)
|
||||
yield FinishEvent(
|
||||
stop_reason=response.stop_reason,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
|
||||
# Deferred import target for type annotation
|
||||
from framework.llm.stream_events import StreamEvent as StreamEvent # noqa: E402, F401
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Stream event types for LLM streaming responses.
|
||||
|
||||
Defines a discriminated union of frozen dataclasses representing every event
|
||||
a streaming LLM call can produce. These types form the contract between the
|
||||
LLM provider layer, EventLoopNode, event bus, persistence, and monitoring.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextDeltaEvent:
|
||||
"""A chunk of text produced by the LLM."""
|
||||
|
||||
type: Literal["text_delta"] = "text_delta"
|
||||
content: str = "" # this chunk's text
|
||||
snapshot: str = "" # accumulated text so far
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextEndEvent:
|
||||
"""Signals that text generation is complete."""
|
||||
|
||||
type: Literal["text_end"] = "text_end"
|
||||
full_text: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolCallEvent:
|
||||
"""The LLM has requested a tool call."""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
tool_use_id: str = ""
|
||||
tool_name: str = ""
|
||||
tool_input: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolResultEvent:
|
||||
"""Result of executing a tool call."""
|
||||
|
||||
type: Literal["tool_result"] = "tool_result"
|
||||
tool_use_id: str = ""
|
||||
content: str = ""
|
||||
is_error: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReasoningStartEvent:
|
||||
"""The LLM has started a reasoning/thinking block."""
|
||||
|
||||
type: Literal["reasoning_start"] = "reasoning_start"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReasoningDeltaEvent:
|
||||
"""A chunk of reasoning/thinking content."""
|
||||
|
||||
type: Literal["reasoning_delta"] = "reasoning_delta"
|
||||
content: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FinishEvent:
|
||||
"""The LLM has finished generating."""
|
||||
|
||||
type: Literal["finish"] = "finish"
|
||||
stop_reason: str = ""
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
model: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StreamErrorEvent:
|
||||
"""An error occurred during streaming."""
|
||||
|
||||
type: Literal["error"] = "error"
|
||||
error: str = ""
|
||||
recoverable: bool = False
|
||||
|
||||
|
||||
# Discriminated union of all stream event types
|
||||
StreamEvent = (
|
||||
TextDeltaEvent
|
||||
| TextEndEvent
|
||||
| ToolCallEvent
|
||||
| ToolResultEvent
|
||||
| ReasoningStartEvent
|
||||
| ReasoningDeltaEvent
|
||||
| FinishEvent
|
||||
| StreamErrorEvent
|
||||
)
|
||||
@@ -41,6 +41,28 @@ class EventType(str, Enum):
|
||||
STREAM_STARTED = "stream_started"
|
||||
STREAM_STOPPED = "stream_stopped"
|
||||
|
||||
# Node event-loop lifecycle
|
||||
NODE_LOOP_STARTED = "node_loop_started"
|
||||
NODE_LOOP_ITERATION = "node_loop_iteration"
|
||||
NODE_LOOP_COMPLETED = "node_loop_completed"
|
||||
|
||||
# LLM streaming observability
|
||||
LLM_TEXT_DELTA = "llm_text_delta"
|
||||
LLM_REASONING_DELTA = "llm_reasoning_delta"
|
||||
|
||||
# Tool lifecycle
|
||||
TOOL_CALL_STARTED = "tool_call_started"
|
||||
TOOL_CALL_COMPLETED = "tool_call_completed"
|
||||
|
||||
# Client I/O (client_facing=True nodes only)
|
||||
CLIENT_OUTPUT_DELTA = "client_output_delta"
|
||||
CLIENT_INPUT_REQUESTED = "client_input_requested"
|
||||
|
||||
# Internal node observability (client_facing=False nodes)
|
||||
NODE_INTERNAL_OUTPUT = "node_internal_output"
|
||||
NODE_INPUT_BLOCKED = "node_input_blocked"
|
||||
NODE_STALLED = "node_stalled"
|
||||
|
||||
# Custom events
|
||||
CUSTOM = "custom"
|
||||
|
||||
@@ -51,6 +73,7 @@ class AgentEvent:
|
||||
|
||||
type: EventType
|
||||
stream_id: str
|
||||
node_id: str | None = None # Which node emitted this event
|
||||
execution_id: str | None = None
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
@@ -61,6 +84,7 @@ class AgentEvent:
|
||||
return {
|
||||
"type": self.type.value,
|
||||
"stream_id": self.stream_id,
|
||||
"node_id": self.node_id,
|
||||
"execution_id": self.execution_id,
|
||||
"data": self.data,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
@@ -80,6 +104,7 @@ class Subscription:
|
||||
event_types: set[EventType]
|
||||
handler: EventHandler
|
||||
filter_stream: str | None = None # Only receive events from this stream
|
||||
filter_node: str | None = None # Only receive events from this node
|
||||
filter_execution: str | None = None # Only receive events from this execution
|
||||
|
||||
|
||||
@@ -138,6 +163,7 @@ class EventBus:
|
||||
event_types: list[EventType],
|
||||
handler: EventHandler,
|
||||
filter_stream: str | None = None,
|
||||
filter_node: str | None = None,
|
||||
filter_execution: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -147,6 +173,7 @@ class EventBus:
|
||||
event_types: Types of events to receive
|
||||
handler: Async function to call when event occurs
|
||||
filter_stream: Only receive events from this stream
|
||||
filter_node: Only receive events from this node
|
||||
filter_execution: Only receive events from this execution
|
||||
|
||||
Returns:
|
||||
@@ -160,6 +187,7 @@ class EventBus:
|
||||
event_types=set(event_types),
|
||||
handler=handler,
|
||||
filter_stream=filter_stream,
|
||||
filter_node=filter_node,
|
||||
filter_execution=filter_execution,
|
||||
)
|
||||
|
||||
@@ -218,6 +246,10 @@ class EventBus:
|
||||
if subscription.filter_stream and subscription.filter_stream != event.stream_id:
|
||||
return False
|
||||
|
||||
# Check node filter
|
||||
if subscription.filter_node and subscription.filter_node != event.node_id:
|
||||
return False
|
||||
|
||||
# Check execution filter
|
||||
if subscription.filter_execution and subscription.filter_execution != event.execution_id:
|
||||
return False
|
||||
@@ -359,6 +391,248 @@ class EventBus:
|
||||
)
|
||||
)
|
||||
|
||||
# === NODE EVENT-LOOP PUBLISHERS ===
|
||||
|
||||
async def emit_node_loop_started(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
execution_id: str | None = None,
|
||||
max_iterations: int | None = None,
|
||||
) -> None:
|
||||
"""Emit node loop started event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_STARTED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"max_iterations": max_iterations},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_loop_iteration(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
iteration: int,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node loop iteration event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_ITERATION,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"iteration": iteration},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_loop_completed(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
iterations: int,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node loop completed event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_COMPLETED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"iterations": iterations},
|
||||
)
|
||||
)
|
||||
|
||||
# === LLM STREAMING PUBLISHERS ===
|
||||
|
||||
async def emit_llm_text_delta(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
snapshot: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM text delta event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content, "snapshot": snapshot},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_llm_reasoning_delta(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM reasoning delta event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_REASONING_DELTA,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content},
|
||||
)
|
||||
)
|
||||
|
||||
# === TOOL LIFECYCLE PUBLISHERS ===
|
||||
|
||||
async def emit_tool_call_started(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any] | None = None,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit tool call started event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TOOL_CALL_STARTED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"tool_use_id": tool_use_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_input": tool_input or {},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_tool_call_completed(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
result: str = "",
|
||||
is_error: bool = False,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit tool call completed event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TOOL_CALL_COMPLETED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"tool_use_id": tool_use_id,
|
||||
"tool_name": tool_name,
|
||||
"result": result,
|
||||
"is_error": is_error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# === CLIENT I/O PUBLISHERS ===
|
||||
|
||||
async def emit_client_output_delta(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
snapshot: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit client output delta event (client_facing=True nodes)."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CLIENT_OUTPUT_DELTA,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content, "snapshot": snapshot},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_client_input_requested(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
prompt: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit client input requested event (client_facing=True nodes)."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CLIENT_INPUT_REQUESTED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"prompt": prompt},
|
||||
)
|
||||
)
|
||||
|
||||
# === INTERNAL NODE PUBLISHERS ===
|
||||
|
||||
async def emit_node_internal_output(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node internal output event (client_facing=False nodes)."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_INTERNAL_OUTPUT,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_stalled(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
reason: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node stalled event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_STALLED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"reason": reason},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_input_blocked(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
prompt: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node input blocked event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_INPUT_BLOCKED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"prompt": prompt},
|
||||
)
|
||||
)
|
||||
|
||||
# === QUERY OPERATIONS ===
|
||||
|
||||
def get_history(
|
||||
@@ -410,6 +684,7 @@ class EventBus:
|
||||
self,
|
||||
event_type: EventType,
|
||||
stream_id: str | None = None,
|
||||
node_id: str | None = None,
|
||||
execution_id: str | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> AgentEvent | None:
|
||||
@@ -419,6 +694,7 @@ class EventBus:
|
||||
Args:
|
||||
event_type: Type of event to wait for
|
||||
stream_id: Filter by stream
|
||||
node_id: Filter by node
|
||||
execution_id: Filter by execution
|
||||
timeout: Maximum time to wait (seconds)
|
||||
|
||||
@@ -438,6 +714,7 @@ class EventBus:
|
||||
event_types=[event_type],
|
||||
handler=handler,
|
||||
filter_stream=stream_id,
|
||||
filter_node=node_id,
|
||||
filter_execution=execution_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
Tests for client-facing fan-out and event_loop output_key overlap validation.
|
||||
|
||||
Validates two rules added to GraphSpec.validate():
|
||||
1. Fan-out must not have multiple client_facing=True targets.
|
||||
2. Parallel event_loop nodes must have disjoint output_keys.
|
||||
"""
|
||||
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.node import NodeSpec
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rule 1: client_facing fan-out
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClientFacingFanOut:
|
||||
"""Fan-out to multiple client_facing=True targets must be rejected."""
|
||||
|
||||
def test_fan_out_two_client_facing_fails(self):
|
||||
"""Two client-facing targets on the same fan-out -> error."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(id="a", name="a", description="Node a", client_facing=True),
|
||||
NodeSpec(id="b", name="b", description="Node b", client_facing=True),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
cf_errors = [e for e in errors if "multiple client-facing" in e]
|
||||
assert len(cf_errors) == 1
|
||||
assert "'src'" in cf_errors[0]
|
||||
|
||||
def test_fan_out_one_client_facing_passes(self):
|
||||
"""Only one client-facing target -> no error."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(id="a", name="a", description="Node a", client_facing=True),
|
||||
NodeSpec(id="b", name="b", description="Node b", client_facing=False),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
cf_errors = [e for e in errors if "multiple client-facing" in e]
|
||||
assert len(cf_errors) == 0
|
||||
|
||||
def test_fan_out_zero_client_facing_passes(self):
|
||||
"""No client-facing targets at all -> no error."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(id="a", name="a", description="Node a"),
|
||||
NodeSpec(id="b", name="b", description="Node b"),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
cf_errors = [e for e in errors if "multiple client-facing" in e]
|
||||
assert len(cf_errors) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rule 2: event_loop output_key overlap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEventLoopOutputKeyOverlap:
|
||||
"""Parallel event_loop nodes with overlapping output_keys must be rejected."""
|
||||
|
||||
def test_overlapping_output_keys_event_loop_fails(self):
|
||||
"""Two event_loop nodes sharing an output_key -> error."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(
|
||||
id="a",
|
||||
name="a",
|
||||
description="Node a",
|
||||
node_type="event_loop",
|
||||
output_keys=["status", "shared"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="b",
|
||||
name="b",
|
||||
description="Node b",
|
||||
node_type="event_loop",
|
||||
output_keys=["result", "shared"],
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
key_errors = [e for e in errors if "output_key" in e]
|
||||
assert len(key_errors) == 1
|
||||
assert "'shared'" in key_errors[0]
|
||||
|
||||
def test_disjoint_output_keys_event_loop_passes(self):
|
||||
"""Two event_loop nodes with disjoint output_keys -> no error."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(
|
||||
id="a",
|
||||
name="a",
|
||||
description="Node a",
|
||||
node_type="event_loop",
|
||||
output_keys=["status"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="b",
|
||||
name="b",
|
||||
description="Node b",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
key_errors = [e for e in errors if "output_key" in e]
|
||||
assert len(key_errors) == 0
|
||||
|
||||
def test_overlapping_keys_non_event_loop_no_error(self):
|
||||
"""Non-event_loop nodes with overlapping keys -> no error (last-wins OK)."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(
|
||||
id="a",
|
||||
name="a",
|
||||
description="Node a",
|
||||
node_type="llm_generate",
|
||||
output_keys=["shared"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="b",
|
||||
name="b",
|
||||
description="Node b",
|
||||
node_type="llm_generate",
|
||||
output_keys=["shared"],
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
key_errors = [e for e in errors if "output_key" in e]
|
||||
assert len(key_errors) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Baseline: no fan-out -> no errors from these rules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoFanOutUnaffected:
|
||||
"""Linear graphs should not trigger either validation rule."""
|
||||
|
||||
def test_no_fan_out_unaffected(self):
|
||||
"""Linear chain with client_facing and event_loop nodes -> no errors."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="a",
|
||||
terminal_nodes=["c"],
|
||||
nodes=[
|
||||
NodeSpec(id="a", name="a", description="Node a", client_facing=True),
|
||||
NodeSpec(
|
||||
id="b",
|
||||
name="b",
|
||||
description="Node b",
|
||||
node_type="event_loop",
|
||||
output_keys=["x"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="c",
|
||||
name="c",
|
||||
description="Node c",
|
||||
client_facing=True,
|
||||
node_type="event_loop",
|
||||
output_keys=["x"],
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="a->b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="b->c", source="b", target="c", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
cf_errors = [e for e in errors if "multiple client-facing" in e]
|
||||
key_errors = [e for e in errors if "output_key" in e]
|
||||
assert len(cf_errors) == 0
|
||||
assert len(key_errors) == 0
|
||||
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Tests for ClientIO gateway (WP-9).
|
||||
|
||||
Covers:
|
||||
- ActiveNodeClientIO: emit_output → output_stream round-trip, request_input, timeout
|
||||
- InertNodeClientIO: emit_output publishes NODE_INTERNAL_OUTPUT, request_input returns redirect
|
||||
- ClientIOGateway: factory creates correct variant
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.client_io import (
|
||||
ActiveNodeClientIO,
|
||||
ClientIOGateway,
|
||||
InertNodeClientIO,
|
||||
NodeClientIO,
|
||||
)
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
_AGENT_EVENT_FIELDS = {"stream_id", "node_id", "execution_id", "correlation_id"}
|
||||
|
||||
|
||||
class MockEventBus:
|
||||
"""Lightweight stand-in for EventBus that records published events."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.events: list[AgentEvent] = []
|
||||
|
||||
async def _record(self, event_type: EventType, **kwargs) -> None:
|
||||
agent_kwargs = {k: v for k, v in kwargs.items() if k in _AGENT_EVENT_FIELDS}
|
||||
data = {k: v for k, v in kwargs.items() if k not in _AGENT_EVENT_FIELDS}
|
||||
self.events.append(AgentEvent(type=event_type, **agent_kwargs, data=data))
|
||||
|
||||
async def emit_client_output_delta(self, **kwargs) -> None:
|
||||
await self._record(EventType.CLIENT_OUTPUT_DELTA, **kwargs)
|
||||
|
||||
async def emit_client_input_requested(self, **kwargs) -> None:
|
||||
await self._record(EventType.CLIENT_INPUT_REQUESTED, **kwargs)
|
||||
|
||||
async def emit_node_internal_output(self, **kwargs) -> None:
|
||||
await self._record(EventType.NODE_INTERNAL_OUTPUT, **kwargs)
|
||||
|
||||
async def emit_node_input_blocked(self, **kwargs) -> None:
|
||||
await self._record(EventType.NODE_INPUT_BLOCKED, **kwargs)
|
||||
|
||||
|
||||
# --- ActiveNodeClientIO tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_emit_and_consume():
|
||||
"""emit_output → output_stream round-trip works correctly."""
|
||||
bus = MockEventBus()
|
||||
io = ActiveNodeClientIO(node_id="n1", event_bus=bus)
|
||||
|
||||
await io.emit_output("Hello ")
|
||||
await io.emit_output("World", is_final=True)
|
||||
|
||||
chunks = []
|
||||
async for chunk in io.output_stream():
|
||||
chunks.append(chunk)
|
||||
|
||||
assert chunks == ["Hello ", "World"]
|
||||
assert len(bus.events) == 2
|
||||
assert all(e.type == EventType.CLIENT_OUTPUT_DELTA for e in bus.events)
|
||||
# Verify snapshot accumulates
|
||||
assert bus.events[0].data["snapshot"] == "Hello "
|
||||
assert bus.events[1].data["snapshot"] == "Hello World"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_request_input():
|
||||
"""request_input blocks until provide_input is called."""
|
||||
bus = MockEventBus()
|
||||
io = ActiveNodeClientIO(node_id="n1", event_bus=bus)
|
||||
|
||||
async def fulfill_later():
|
||||
await asyncio.sleep(0.01)
|
||||
await io.provide_input("user says hi")
|
||||
|
||||
task = asyncio.create_task(fulfill_later())
|
||||
result = await io.request_input(prompt="What?")
|
||||
await task
|
||||
|
||||
assert result == "user says hi"
|
||||
assert len(bus.events) == 1
|
||||
assert bus.events[0].type == EventType.CLIENT_INPUT_REQUESTED
|
||||
assert bus.events[0].data["prompt"] == "What?"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_request_input_timeout():
|
||||
"""request_input raises TimeoutError when timeout expires."""
|
||||
io = ActiveNodeClientIO(node_id="n1")
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
await io.request_input(prompt="waiting", timeout=0.01)
|
||||
|
||||
|
||||
# --- InertNodeClientIO tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inert_emit_publishes_internal():
|
||||
"""InertNodeClientIO.emit_output publishes NODE_INTERNAL_OUTPUT."""
|
||||
bus = MockEventBus()
|
||||
io = InertNodeClientIO(node_id="n2", event_bus=bus)
|
||||
|
||||
await io.emit_output("internal log")
|
||||
|
||||
assert len(bus.events) == 1
|
||||
assert bus.events[0].type == EventType.NODE_INTERNAL_OUTPUT
|
||||
assert bus.events[0].data["content"] == "internal log"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inert_request_input_returns_redirect():
|
||||
"""request_input returns a redirect string and publishes NODE_INPUT_BLOCKED."""
|
||||
bus = MockEventBus()
|
||||
io = InertNodeClientIO(node_id="n2", event_bus=bus)
|
||||
|
||||
result = await io.request_input(prompt="need data")
|
||||
|
||||
assert "internal processing node" in result
|
||||
assert len(bus.events) == 1
|
||||
assert bus.events[0].type == EventType.NODE_INPUT_BLOCKED
|
||||
assert bus.events[0].data["prompt"] == "need data"
|
||||
|
||||
|
||||
# --- ClientIOGateway tests ---
|
||||
|
||||
|
||||
def test_gateway_creates_active_for_client_facing():
|
||||
"""ClientIOGateway.create_io returns ActiveNodeClientIO when client_facing=True."""
|
||||
gateway = ClientIOGateway()
|
||||
io = gateway.create_io(node_id="n1", client_facing=True)
|
||||
|
||||
assert isinstance(io, ActiveNodeClientIO)
|
||||
assert isinstance(io, NodeClientIO)
|
||||
|
||||
|
||||
def test_gateway_creates_inert_for_internal():
|
||||
"""ClientIOGateway.create_io returns InertNodeClientIO when client_facing=False."""
|
||||
gateway = ClientIOGateway()
|
||||
io = gateway.create_io(node_id="n2", client_facing=False)
|
||||
|
||||
assert isinstance(io, InertNodeClientIO)
|
||||
assert isinstance(io, NodeClientIO)
|
||||
@@ -0,0 +1,326 @@
|
||||
"""Tests for ContextHandoff and HandoffContext."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.context_handoff import ContextHandoff, HandoffContext
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
from framework.llm.provider import LLMProvider, LLMResponse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SpyLLMProvider(MockLLMProvider):
|
||||
"""MockLLMProvider that records whether complete() was called."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.complete_called = False
|
||||
self.complete_call_args: dict[str, Any] | None = None
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
|
||||
self.complete_called = True
|
||||
self.complete_call_args = {"messages": messages, **kwargs}
|
||||
return super().complete(messages, **kwargs)
|
||||
|
||||
|
||||
class FailingLLMProvider(LLMProvider):
|
||||
"""LLM provider that always raises."""
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
|
||||
raise RuntimeError("LLM unavailable")
|
||||
|
||||
def complete_with_tools(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list,
|
||||
tool_executor: Any,
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
raise RuntimeError("LLM unavailable")
|
||||
|
||||
|
||||
async def _build_conversation(*pairs: tuple[str, str]) -> NodeConversation:
|
||||
"""Build a NodeConversation from (user, assistant) message pairs."""
|
||||
conv = NodeConversation()
|
||||
for user_msg, assistant_msg in pairs:
|
||||
await conv.add_user_message(user_msg)
|
||||
await conv.add_assistant_message(assistant_msg)
|
||||
return conv
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestHandoffContext
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandoffContext:
|
||||
def test_instantiation(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="node_A",
|
||||
summary="Summary text",
|
||||
key_outputs={"result": "42"},
|
||||
turn_count=3,
|
||||
total_tokens_used=1200,
|
||||
)
|
||||
assert hc.source_node_id == "node_A"
|
||||
assert hc.summary == "Summary text"
|
||||
assert hc.key_outputs == {"result": "42"}
|
||||
assert hc.turn_count == 3
|
||||
assert hc.total_tokens_used == 1200
|
||||
|
||||
def test_field_access(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="n1",
|
||||
summary="s",
|
||||
key_outputs={},
|
||||
turn_count=0,
|
||||
total_tokens_used=0,
|
||||
)
|
||||
assert hc.key_outputs == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestExtractiveSummary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractiveSummary:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_summary_includes_first_last(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("hello", "First response here."),
|
||||
("continue", "Middle response."),
|
||||
("finish", "Final conclusion."),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="test_node")
|
||||
|
||||
assert "First response here." in hc.summary
|
||||
assert "Final conclusion." in hc.summary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_summary_metadata(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("hi", "hello"),
|
||||
("bye", "goodbye"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="node_42")
|
||||
|
||||
assert hc.source_node_id == "node_42"
|
||||
assert hc.turn_count == 2
|
||||
assert hc.total_tokens_used > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_with_output_keys_colon(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("what is the answer?", "answer: 42"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["answer"])
|
||||
|
||||
assert hc.key_outputs["answer"] == "42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_with_output_keys_equals(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("compute", "result = success"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["result"])
|
||||
|
||||
assert hc.key_outputs["result"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_json_output_keys(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("give me json", '{"score": 95, "grade": "A"}'),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"])
|
||||
|
||||
assert hc.key_outputs["score"] == "95"
|
||||
assert hc.key_outputs["grade"] == "A"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_empty_conversation(self) -> None:
|
||||
conv = NodeConversation()
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="empty")
|
||||
|
||||
assert hc.summary == "Empty conversation."
|
||||
assert hc.turn_count == 0
|
||||
assert hc.key_outputs == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_no_assistant_messages(self) -> None:
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("hello?")
|
||||
await conv.add_user_message("anyone there?")
|
||||
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="silent")
|
||||
|
||||
assert hc.summary == "No assistant responses."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_most_recent_wins(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("first", "status: old_value"),
|
||||
("second", "status: new_value"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["status"])
|
||||
|
||||
assert hc.key_outputs["status"] == "new_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_truncation(self) -> None:
|
||||
long_text = "x" * 1000
|
||||
conv = await _build_conversation(
|
||||
("go", long_text),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n")
|
||||
|
||||
# Summary should be truncated to ~500 chars
|
||||
assert len(hc.summary) <= 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestLLMSummary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMSummary:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_summary_calls_provider(self) -> None:
|
||||
llm = SpyLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("hi", "hello back"),
|
||||
("what now?", "we are done"),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
hc = ch.summarize_conversation(conv, node_id="llm_node")
|
||||
|
||||
assert llm.complete_called, "LLM complete() was never invoked"
|
||||
assert hc.summary == "This is a mock response for testing purposes."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_summary_includes_output_key_hint(self) -> None:
|
||||
llm = SpyLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("compute", '{"score": 95}'),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"])
|
||||
|
||||
assert llm.complete_call_args is not None
|
||||
system = llm.complete_call_args.get("system", "")
|
||||
assert "score" in system
|
||||
assert "grade" in system
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_fallback_on_error(self) -> None:
|
||||
llm = FailingLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("start", "First assistant message."),
|
||||
("end", "Last assistant message."),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
hc = ch.summarize_conversation(conv, node_id="fallback_node")
|
||||
|
||||
# Should fall back to extractive (first + last assistant messages)
|
||||
assert "First assistant message." in hc.summary
|
||||
assert "Last assistant message." in hc.summary
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestFormatAsInput
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatAsInput:
|
||||
def test_format_structure(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="analyzer",
|
||||
summary="Analysis complete.",
|
||||
key_outputs={"score": "95"},
|
||||
turn_count=5,
|
||||
total_tokens_used=2000,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "--- CONTEXT FROM: analyzer" in output
|
||||
assert "KEY OUTPUTS:" in output
|
||||
assert "SUMMARY:" in output
|
||||
assert "--- END CONTEXT ---" in output
|
||||
|
||||
def test_format_no_key_outputs(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="simple",
|
||||
summary="Done.",
|
||||
key_outputs={},
|
||||
turn_count=1,
|
||||
total_tokens_used=100,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "KEY OUTPUTS:" not in output
|
||||
assert "SUMMARY:" in output
|
||||
|
||||
def test_format_content_values(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="node_X",
|
||||
summary="Found 3 bugs.",
|
||||
key_outputs={"bugs": "3", "severity": "high"},
|
||||
turn_count=7,
|
||||
total_tokens_used=5000,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "node_X" in output
|
||||
assert "7 turns" in output
|
||||
assert "~5000 tokens" in output
|
||||
assert "- bugs: 3" in output
|
||||
assert "- severity: high" in output
|
||||
assert "Found 3 bugs." in output
|
||||
|
||||
def test_format_empty_summary(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="n",
|
||||
summary="",
|
||||
key_outputs={},
|
||||
turn_count=0,
|
||||
total_tokens_used=0,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "No summary available." in output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_as_input_usable_as_message(self) -> None:
|
||||
"""Formatted output can be fed into a NodeConversation as a user message."""
|
||||
hc = HandoffContext(
|
||||
source_node_id="prev_node",
|
||||
summary="Completed analysis.",
|
||||
key_outputs={"result": "42"},
|
||||
turn_count=3,
|
||||
total_tokens_used=900,
|
||||
)
|
||||
text = ContextHandoff.format_as_input(hc)
|
||||
|
||||
conv = NodeConversation()
|
||||
msg = await conv.add_user_message(text)
|
||||
|
||||
assert msg.role == "user"
|
||||
assert "CONTEXT FROM: prev_node" in msg.content
|
||||
assert conv.turn_count == 1
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,746 @@
|
||||
"""WP-8: Tests for EventLoopNode, OutputAccumulator, LoopConfig, JudgeProtocol.
|
||||
|
||||
Uses real FileConversationStore (no mocks for storage) and a MockStreamingLLM
|
||||
that yields pre-programmed StreamEvents to control the loop deterministically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.event_loop_node import (
|
||||
EventLoopNode,
|
||||
JudgeProtocol,
|
||||
JudgeVerdict,
|
||||
LoopConfig,
|
||||
OutputAccumulator,
|
||||
)
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeSpec, SharedMemory
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.runtime.event_bus import EventBus, EventType
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock LLM that yields pre-programmed stream events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockStreamingLLM(LLMProvider):
|
||||
"""Mock LLM that yields pre-programmed StreamEvent sequences.
|
||||
|
||||
Each call to stream() consumes the next scenario from the list.
|
||||
Cycles back to the beginning if more calls are made than scenarios.
|
||||
"""
|
||||
|
||||
def __init__(self, scenarios: list[list] | None = None):
|
||||
self.scenarios = scenarios or []
|
||||
self._call_index = 0
|
||||
self.stream_calls: list[dict] = []
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator:
|
||||
self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
|
||||
if not self.scenarios:
|
||||
return
|
||||
events = self.scenarios[self._call_index % len(self.scenarios)]
|
||||
self._call_index += 1
|
||||
for event in events:
|
||||
yield event
|
||||
|
||||
def complete(self, messages, system="", **kwargs) -> LLMResponse:
|
||||
return LLMResponse(content="Summary of conversation.", model="mock", stop_reason="stop")
|
||||
|
||||
def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
|
||||
return LLMResponse(content="", model="mock", stop_reason="stop")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: build a simple text-only scenario
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def text_scenario(text: str, input_tokens: int = 10, output_tokens: int = 5) -> list:
|
||||
"""Build a stream scenario that produces text and finishes."""
|
||||
return [
|
||||
TextDeltaEvent(content=text, snapshot=text),
|
||||
FinishEvent(
|
||||
stop_reason="stop", input_tokens=input_tokens, output_tokens=output_tokens, model="mock"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def tool_call_scenario(
|
||||
tool_name: str,
|
||||
tool_input: dict,
|
||||
tool_use_id: str = "call_1",
|
||||
text: str = "",
|
||||
) -> list:
|
||||
"""Build a stream scenario that produces a tool call."""
|
||||
events = []
|
||||
if text:
|
||||
events.append(TextDeltaEvent(content=text, snapshot=text))
|
||||
events.append(
|
||||
ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input)
|
||||
)
|
||||
events.append(
|
||||
FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock")
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runtime():
|
||||
rt = MagicMock(spec=Runtime)
|
||||
rt.start_run = MagicMock(return_value="run_1")
|
||||
rt.decide = MagicMock(return_value="dec_1")
|
||||
rt.record_outcome = MagicMock()
|
||||
rt.end_run = MagicMock()
|
||||
rt.report_problem = MagicMock()
|
||||
rt.set_node = MagicMock()
|
||||
return rt
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node_spec():
|
||||
return NodeSpec(
|
||||
id="test_loop",
|
||||
name="Test Loop",
|
||||
description="A test event loop node",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
system_prompt="You are a test assistant.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory():
|
||||
return SharedMemory()
|
||||
|
||||
|
||||
def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal_context=""):
|
||||
"""Build a NodeContext for testing."""
|
||||
return NodeContext(
|
||||
runtime=runtime,
|
||||
node_id=node_spec.id,
|
||||
node_spec=node_spec,
|
||||
memory=memory,
|
||||
input_data=input_data or {},
|
||||
llm=llm,
|
||||
available_tools=tools or [],
|
||||
goal_context=goal_context,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# NodeProtocol conformance
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestNodeProtocolConformance:
|
||||
def test_subclasses_node_protocol(self):
|
||||
"""EventLoopNode must be a subclass of NodeProtocol."""
|
||||
assert issubclass(EventLoopNode, NodeProtocol)
|
||||
|
||||
def test_has_execute_method(self):
|
||||
node = EventLoopNode()
|
||||
assert hasattr(node, "execute")
|
||||
assert asyncio.iscoroutinefunction(node.execute)
|
||||
|
||||
def test_has_validate_input(self):
|
||||
node = EventLoopNode()
|
||||
assert hasattr(node, "validate_input")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Basic loop execution
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestBasicLoop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_text_only_implicit_accept(self, runtime, node_spec, memory):
|
||||
"""No tools, no judge. LLM produces text, implicit accept on stop."""
|
||||
# Override to no output_keys so implicit judge accepts immediately
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("Hello world")])
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.tokens_used > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_llm_returns_failure(self, runtime, node_spec, memory):
|
||||
"""ctx.llm=None should return failure immediately."""
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm=None)
|
||||
|
||||
node = EventLoopNode()
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is False
|
||||
assert "LLM" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iterations_failure(self, runtime, node_spec, memory):
|
||||
"""When max_iterations is reached without acceptance, should fail."""
|
||||
# LLM always produces text but never calls set_output, so implicit
|
||||
# judge retries asking for missing keys
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=2))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is False
|
||||
assert "Max iterations" in result.error
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Judge integration
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestJudgeIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_accept(self, runtime, node_spec, memory):
|
||||
"""Mock judge ACCEPT -> success."""
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("Done!")])
|
||||
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
judge.evaluate.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_escalate(self, runtime, node_spec, memory):
|
||||
"""Mock judge ESCALATE -> failure."""
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("Attempt")])
|
||||
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
judge.evaluate = AsyncMock(
|
||||
return_value=JudgeVerdict(action="ESCALATE", feedback="Tone violation")
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is False
|
||||
assert "escalated" in result.error.lower()
|
||||
assert "Tone violation" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_retry_then_accept(self, runtime, node_spec, memory):
|
||||
"""RETRY twice, then ACCEPT. Should run 3 iterations."""
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
text_scenario("attempt 1"),
|
||||
text_scenario("attempt 2"),
|
||||
text_scenario("attempt 3"),
|
||||
]
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def evaluate_fn(context):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
return JudgeVerdict(action="RETRY", feedback="Try harder")
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
judge.evaluate = AsyncMock(side_effect=evaluate_fn)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=10))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert call_count == 3
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# set_output tool
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSetOutput:
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_output_accumulates(self, runtime, node_spec, memory):
|
||||
"""LLM calls set_output -> values appear in NodeResult.output."""
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
# Turn 1: call set_output
|
||||
tool_call_scenario("set_output", {"key": "result", "value": "42"}),
|
||||
# Turn 2: text response (triggers implicit judge)
|
||||
text_scenario("Done, result is 42"),
|
||||
]
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.output["result"] == "42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_output_rejects_invalid_key(self, runtime, node_spec, memory):
|
||||
"""set_output with key not in output_keys -> is_error=True."""
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
# Turn 1: call set_output with bad key
|
||||
tool_call_scenario("set_output", {"key": "bad_key", "value": "x"}),
|
||||
# Turn 2: call set_output with good key
|
||||
tool_call_scenario("set_output", {"key": "result", "value": "ok"}),
|
||||
# Turn 3: text done
|
||||
text_scenario("Done"),
|
||||
]
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.output["result"] == "ok"
|
||||
assert "bad_key" not in result.output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_keys_triggers_retry(self, runtime, node_spec, memory):
|
||||
"""Judge accepts but output keys are missing -> retry with hint."""
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))
|
||||
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
# Turn 1: text without set_output -> judge accepts but keys missing -> retry
|
||||
text_scenario("I'll get to it"),
|
||||
# Turn 2: set_output
|
||||
tool_call_scenario("set_output", {"key": "result", "value": "done"}),
|
||||
# Turn 3: text -> judge accepts, keys present -> success
|
||||
text_scenario("All done"),
|
||||
]
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.output["result"] == "done"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Stall detection
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStallDetection:
    """The loop must abort when the model keeps repeating itself."""

    @pytest.mark.asyncio
    async def test_stall_detection(self, runtime, node_spec, memory):
        """3 identical responses should trigger stall detection."""
        # With no required output keys the implicit judge would accept, so a
        # permanently RETRY-ing judge is installed to force the model to
        # produce the same response repeatedly.
        node_spec.output_keys = []
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY"))

        llm = MockStreamingLLM(scenarios=[text_scenario("same answer")])

        ctx = build_ctx(runtime, node_spec, memory, llm)
        loop_node = EventLoopNode(
            judge=judge,
            config=LoopConfig(max_iterations=10, stall_detection_threshold=3),
        )
        outcome = await loop_node.execute(ctx)

        # Stall ends the run as a failure with an explanatory error.
        assert outcome.success is False
        assert "stalled" in outcome.error.lower()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# EventBus lifecycle events
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestEventBusLifecycle:
    """Lifecycle and streaming events published on the EventBus."""

    @pytest.mark.asyncio
    async def test_lifecycle_events_published(self, runtime, node_spec, memory):
        """NODE_LOOP_STARTED, NODE_LOOP_ITERATION, NODE_LOOP_COMPLETED should be published."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("ok")])
        bus = EventBus()

        seen_types = []
        bus.subscribe(
            event_types=[
                EventType.NODE_LOOP_STARTED,
                EventType.NODE_LOOP_ITERATION,
                EventType.NODE_LOOP_COMPLETED,
            ],
            handler=lambda e: seen_types.append(e.type),
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        loop_node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        outcome = await loop_node.execute(ctx)

        assert outcome.success is True
        # Every lifecycle phase must have been observed at least once.
        for expected in (
            EventType.NODE_LOOP_STARTED,
            EventType.NODE_LOOP_ITERATION,
            EventType.NODE_LOOP_COMPLETED,
        ):
            assert expected in seen_types

    @pytest.mark.asyncio
    async def test_client_facing_uses_client_output_delta(self, runtime, memory):
        """client_facing=True should emit CLIENT_OUTPUT_DELTA instead of LLM_TEXT_DELTA."""
        spec = NodeSpec(
            id="ui_node",
            name="UI Node",
            description="Streams to user",
            node_type="event_loop",
            output_keys=[],
            client_facing=True,
        )
        llm = MockStreamingLLM(scenarios=[text_scenario("visible to user")])
        bus = EventBus()

        seen_types = []
        bus.subscribe(
            event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.LLM_TEXT_DELTA],
            handler=lambda e: seen_types.append(e.type),
        )

        ctx = build_ctx(runtime, spec, memory, llm)
        loop_node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        await loop_node.execute(ctx)

        # Client-facing nodes route text to the client channel exclusively.
        assert EventType.CLIENT_OUTPUT_DELTA in seen_types
        assert EventType.LLM_TEXT_DELTA not in seen_types
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tool execution
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestToolExecution:
    """Tool calls are executed and their results fed back into the loop."""

    @pytest.mark.asyncio
    async def test_tool_execution_feedback(self, runtime, node_spec, memory):
        """Tool call -> result fed back to conversation via stream loop."""
        node_spec.output_keys = []

        def echo_executor(tool_use: ToolUse) -> ToolResult:
            # Deterministic stub: always succeed and echo the tool name back.
            return ToolResult(
                tool_use_id=tool_use.id,
                content=f"Result for {tool_use.name}",
                is_error=False,
            )

        # Turn 1 calls a tool; turn 2 answers after seeing the tool result.
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("search", {"query": "test"}, tool_use_id="call_search"),
                text_scenario("Found the answer"),
            ]
        )

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            tools=[Tool(name="search", description="Search", parameters={})],
        )
        loop_node = EventLoopNode(
            tool_executor=echo_executor,
            config=LoopConfig(max_iterations=5),
        )
        outcome = await loop_node.execute(ctx)

        assert outcome.success is True
        # stream() must have run at least twice: the tool-call turn plus the
        # final text turn after the result was injected.
        assert llm._call_index >= 2
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Write-through persistence with real FileConversationStore
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestWriteThroughPersistence:
    """Write-through persistence against a real FileConversationStore."""

    @pytest.mark.asyncio
    async def test_messages_written_to_store(self, tmp_path, runtime, node_spec, memory):
        """Messages should be persisted immediately via write-through."""
        store = FileConversationStore(tmp_path / "conv")
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello")])

        ctx = build_ctx(runtime, node_spec, memory, llm)
        loop_node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        outcome = await loop_node.execute(ctx)

        assert outcome.success is True

        # At minimum the initial user message and the assistant reply
        # should already be on disk.
        parts = await store.read_parts()
        assert len(parts) >= 2

    @pytest.mark.asyncio
    async def test_output_accumulator_write_through(self, tmp_path, runtime, node_spec, memory):
        """set_output values should be persisted in cursor immediately."""
        store = FileConversationStore(tmp_path / "conv")
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("set_output", {"key": "result", "value": "persisted_value"}),
                text_scenario("Done"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        loop_node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        outcome = await loop_node.execute(ctx)

        assert outcome.success is True
        assert outcome.output["result"] == "persisted_value"

        # The on-disk cursor must already reflect the accumulated output.
        cursor = await store.read_cursor()
        assert cursor is not None
        assert cursor["outputs"]["result"] == "persisted_value"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Crash recovery (restore from real FileConversationStore)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestCrashRecovery:
    """Restoring loop state from a real FileConversationStore."""

    @pytest.mark.asyncio
    async def test_restore_from_checkpoint(self, tmp_path, runtime, node_spec, memory):
        """Populate a store with state, then verify EventLoopNode restores from it."""
        store = FileConversationStore(tmp_path / "conv")

        # Simulate a previous run that persisted conversation + cursor.
        conv = NodeConversation(
            system_prompt="You are a test assistant.",
            output_keys=["result"],
            store=store,
        )
        await conv.add_user_message("Initial input")
        await conv.add_assistant_message("Working on it...")

        await store.write_cursor(
            {
                "iteration": 1,
                "next_seq": conv.next_seq,
                "outputs": {"result": "partial_value"},
            }
        )

        # A fresh node pointed at the same store should pick the state back up.
        node_spec.output_keys = []  # no required keys so implicit accept works
        llm = MockStreamingLLM(scenarios=[text_scenario("Continuing...")])

        ctx = build_ctx(runtime, node_spec, memory, llm)
        loop_node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        outcome = await loop_node.execute(ctx)

        assert outcome.success is True
        # The pre-crash partial output must survive the restart.
        assert outcome.output.get("result") == "partial_value"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# External event injection
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestEventInjection:
    """External events injected into a running loop."""

    @pytest.mark.asyncio
    async def test_inject_event(self, runtime, node_spec, memory):
        """inject_event() content should appear as user message in next iteration."""
        node_spec.output_keys = []

        judge_calls = []

        async def evaluate_fn(context):
            # Retry once so the loop runs a second iteration, then accept.
            judge_calls.append(context)
            if len(judge_calls) >= 2:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY")

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(side_effect=evaluate_fn)

        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("iteration 1"),
                text_scenario("iteration 2"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        loop_node = EventLoopNode(
            judge=judge,
            config=LoopConfig(max_iterations=5),
        )

        # Queue the event before execute() so the loop sees it.
        await loop_node.inject_event("Priority: CEO wants meeting rescheduled")

        outcome = await loop_node.execute(ctx)
        assert outcome.success is True

        # The injected content must surface in the messages sent to the LLM.
        all_messages = []
        for call in llm.stream_calls:
            all_messages.extend(call["messages"])
        assert any(
            "[External event]" in str(m.get("content", "")) for m in all_messages
        )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Pause/resume
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestPauseResume:
    """Pause requests short-circuit the loop."""

    @pytest.mark.asyncio
    async def test_pause_returns_early(self, runtime, node_spec, memory):
        """pause_requested in input_data should trigger early return."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("should not run")])

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            input_data={"pause_requested": True},
        )
        loop_node = EventLoopNode(config=LoopConfig(max_iterations=10))
        outcome = await loop_node.execute(ctx)

        # Pausing is a successful outcome, not a failure...
        assert outcome.success is True
        # ...and the LLM was never consulted (paused before the first turn).
        assert llm._call_index == 0
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Stream errors
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStreamErrors:
    """Stream-level error handling."""

    @pytest.mark.asyncio
    async def test_non_recoverable_stream_error_raises(self, runtime, node_spec, memory):
        """Non-recoverable StreamErrorEvent should raise RuntimeError."""
        node_spec.output_keys = []
        # A single scenario whose only event is a fatal stream error.
        llm = MockStreamingLLM(
            scenarios=[
                [StreamErrorEvent(error="Connection lost", recoverable=False)],
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        loop_node = EventLoopNode(config=LoopConfig(max_iterations=5))

        with pytest.raises(RuntimeError, match="Stream error"):
            await loop_node.execute(ctx)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# OutputAccumulator unit tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOutputAccumulator:
    """Unit tests for OutputAccumulator, with and without a backing store."""

    @pytest.mark.asyncio
    async def test_set_and_get(self):
        acc = OutputAccumulator()
        await acc.set("key1", "value1")
        assert acc.get("key1") == "value1"
        # Unknown keys read back as None rather than raising.
        assert acc.get("nonexistent") is None

    @pytest.mark.asyncio
    async def test_to_dict(self):
        acc = OutputAccumulator()
        await acc.set("a", 1)
        await acc.set("b", 2)
        assert acc.to_dict() == {"a": 1, "b": 2}

    @pytest.mark.asyncio
    async def test_has_all_keys(self):
        acc = OutputAccumulator()
        # An empty requirement list is trivially satisfied.
        assert acc.has_all_keys([]) is True
        assert acc.has_all_keys(["x"]) is False
        await acc.set("x", "val")
        assert acc.has_all_keys(["x"]) is True

    @pytest.mark.asyncio
    async def test_write_through_to_real_store(self, tmp_path):
        """OutputAccumulator should write through to FileConversationStore cursor."""
        store = FileConversationStore(tmp_path / "acc_test")
        acc = OutputAccumulator(store=store)

        await acc.set("result", "hello")

        # The cursor on disk reflects the write immediately.
        cursor = await store.read_cursor()
        assert cursor["outputs"]["result"] == "hello"

    @pytest.mark.asyncio
    async def test_restore_from_real_store(self, tmp_path):
        """OutputAccumulator.restore() should rebuild from FileConversationStore."""
        store = FileConversationStore(tmp_path / "acc_restore")
        await store.write_cursor({"outputs": {"key1": "val1", "key2": "val2"}})

        acc = await OutputAccumulator.restore(store)
        assert acc.get("key1") == "val1"
        assert acc.get("key2") == "val2"
        assert acc.has_all_keys(["key1", "key2"]) is True
|
||||
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Tests for event_loop node type wiring (Issue #2513).
|
||||
|
||||
Covers:
|
||||
- NodeSpec.client_facing field
|
||||
- event_loop in VALID_NODE_TYPES
|
||||
- _get_node_implementation() event_loop branch
|
||||
- no-retry enforcement in serial execution path
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
|
||||
class AlwaysFailsNode(NodeProtocol):
    """Test double whose execute() always reports failure.

    Counts invocations so retry behavior can be asserted.
    """

    def __init__(self):
        # Number of times execute() has been called.
        self.attempt_count = 0

    async def execute(self, ctx: NodeContext) -> NodeResult:
        self.attempt_count += 1
        return NodeResult(success=False, error=f"Permanent error (attempt {self.attempt_count})")
|
||||
|
||||
|
||||
class SucceedsOnceNode(NodeProtocol):
    """Test double whose execute() always succeeds with a fixed output."""

    async def execute(self, ctx: NodeContext) -> NodeResult:
        return NodeResult(success=True, output={"result": "ok"})
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def fast_sleep(monkeypatch):
    """Replace asyncio.sleep with a no-op mock so retry backoff doesn't stall tests."""
    monkeypatch.setattr("asyncio.sleep", AsyncMock())
|
||||
|
||||
|
||||
@pytest.fixture
def runtime():
    """Create a mock Runtime for testing."""
    rt = MagicMock(spec=Runtime)
    # Stub the lifecycle/telemetry surface the executor touches.
    rt.start_run = MagicMock(return_value="test_run_id")
    rt.decide = MagicMock(return_value="test_decision_id")
    rt.record_outcome = MagicMock()
    rt.end_run = MagicMock()
    rt.report_problem = MagicMock()
    rt.set_node = MagicMock()
    return rt
|
||||
|
||||
|
||||
# --- NodeSpec.client_facing tests ---
|
||||
|
||||
|
||||
def test_client_facing_defaults_false():
    """NodeSpec without client_facing should default to False."""
    spec = NodeSpec(
        id="n1",
        name="Node 1",
        description="test",
        node_type="llm_generate",
    )
    assert spec.client_facing is False
|
||||
|
||||
|
||||
def test_client_facing_explicit_true():
    """NodeSpec with client_facing=True should retain the value."""
    spec = NodeSpec(
        id="n1",
        name="Node 1",
        description="test",
        node_type="event_loop",
        client_facing=True,
    )
    assert spec.client_facing is True
|
||||
|
||||
|
||||
# --- VALID_NODE_TYPES tests ---
|
||||
|
||||
|
||||
def test_event_loop_in_valid_node_types():
    """'event_loop' must be in GraphExecutor.VALID_NODE_TYPES."""
    assert "event_loop" in GraphExecutor.VALID_NODE_TYPES
|
||||
|
||||
|
||||
def test_event_loop_node_spec_accepted():
    """Creating a NodeSpec with node_type='event_loop' should not raise."""
    spec = NodeSpec(
        id="el1",
        name="Event Loop",
        description="test",
        node_type="event_loop",
    )
    assert spec.node_type == "event_loop"
|
||||
|
||||
|
||||
# --- _get_node_implementation() tests ---
|
||||
|
||||
|
||||
def test_unregistered_event_loop_raises(runtime):
    """An event_loop node not in the registry should raise RuntimeError."""
    spec = NodeSpec(
        id="el1",
        name="Event Loop",
        description="test",
        node_type="event_loop",
    )
    executor = GraphExecutor(runtime=runtime)

    # No register_node() call -> lookup must fail loudly.
    with pytest.raises(RuntimeError, match="not found in registry"):
        executor._get_node_implementation(spec)
|
||||
|
||||
|
||||
def test_registered_event_loop_returns_impl(runtime):
    """A registered event_loop node should be returned from the registry."""
    spec = NodeSpec(
        id="el1",
        name="Event Loop",
        description="test",
        node_type="event_loop",
    )
    impl = SucceedsOnceNode()
    executor = GraphExecutor(runtime=runtime)
    executor.register_node("el1", impl)

    # Identity, not equality: the exact registered object comes back.
    assert executor._get_node_implementation(spec) is impl
|
||||
|
||||
|
||||
# --- No-retry enforcement (serial path) ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_event_loop_max_retries_forced_zero(runtime):
    """An event_loop node with max_retries=3 should only execute once (no retry)."""
    node_spec = NodeSpec(
        id="el_fail",
        name="Failing Event Loop",
        description="event loop that fails",
        node_type="event_loop",
        max_retries=3,
        output_keys=["result"],
    )
    graph = GraphSpec(
        id="test_graph",
        goal_id="test_goal",
        name="Test Graph",
        entry_node="el_fail",
        nodes=[node_spec],
        edges=[],
        terminal_nodes=["el_fail"],
    )
    goal = Goal(id="test_goal", name="Test", description="test")

    executor = GraphExecutor(runtime=runtime)
    failing_node = AlwaysFailsNode()
    executor.register_node("el_fail", failing_node)

    result = await executor.execute(graph, goal, {})

    # event_loop nodes get max_retries forced to 0: one attempt, then fail.
    assert not result.success
    assert failing_node.attempt_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_event_loop_max_retries_zero_no_warning(runtime, caplog):
    """An event_loop node with max_retries=0 should not log a warning."""
    node_spec = NodeSpec(
        id="el_zero",
        name="Zero Retry Event Loop",
        description="event loop with 0 retries",
        node_type="event_loop",
        max_retries=0,
        output_keys=["result"],
    )
    graph = GraphSpec(
        id="test_graph",
        goal_id="test_goal",
        name="Test Graph",
        entry_node="el_zero",
        nodes=[node_spec],
        edges=[],
        terminal_nodes=["el_zero"],
    )
    goal = Goal(id="test_goal", name="Test", description="test")

    executor = GraphExecutor(runtime=runtime)
    executor.register_node("el_zero", AlwaysFailsNode())

    import logging

    with caplog.at_level(logging.WARNING):
        await executor.execute(graph, goal, {})

    # max_retries already matches the forced value, so no override warning.
    assert "Overriding to 0" not in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_event_loop_max_retries_positive_logs_warning(runtime, caplog):
    """An event_loop node with max_retries=3 should log a warning about override."""
    node_spec = NodeSpec(
        id="el_warn",
        name="Warning Event Loop",
        description="event loop with retries",
        node_type="event_loop",
        max_retries=3,
        output_keys=["result"],
    )
    graph = GraphSpec(
        id="test_graph",
        goal_id="test_goal",
        name="Test Graph",
        entry_node="el_warn",
        nodes=[node_spec],
        edges=[],
        terminal_nodes=["el_warn"],
    )
    goal = Goal(id="test_goal", name="Test", description="test")

    executor = GraphExecutor(runtime=runtime)
    executor.register_node("el_warn", AlwaysFailsNode())

    import logging

    with caplog.at_level(logging.WARNING):
        await executor.execute(graph, goal, {})

    # The warning names both the override and the offending node id.
    assert "Overriding to 0" in caplog.text
    assert "el_warn" in caplog.text
|
||||
|
||||
|
||||
# --- Existing node types unaffected ---
|
||||
|
||||
|
||||
def test_existing_node_types_unchanged():
    """All pre-existing node types must still be in VALID_NODE_TYPES with defaults preserved."""
    legacy_types = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
    assert legacy_types.issubset(GraphExecutor.VALID_NODE_TYPES)

    # Defaults must be untouched by the event_loop addition.
    spec = NodeSpec(id="x", name="X", description="x")
    assert spec.node_type == "llm_tool_use"
    assert spec.max_retries == 3
    assert spec.client_facing is False
|
||||
@@ -0,0 +1,978 @@
|
||||
"""Tests for extending the stream event type system.
|
||||
|
||||
Validates that the StreamEvent discriminated union pattern supports:
|
||||
- Type-based dispatch (matching on event.type)
|
||||
- Pattern matching / isinstance branching
|
||||
- Custom event subclasses following the same frozen-dataclass convention
|
||||
- Serialization of mixed event sequences
|
||||
|
||||
WP-2 tests validate EventType enum extension and node-level event routing:
|
||||
- All 12 new EventType enum members with correct string values
|
||||
- node_id routing on AgentEvent
|
||||
- filter_node on Subscription
|
||||
- Backward compatibility with existing enum members
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import FrozenInstanceError, asdict, dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
ReasoningDeltaEvent,
|
||||
ReasoningStartEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
from framework.runtime.event_bus import AgentEvent, EventBus, EventType, Subscription
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers: type-based dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
def dispatch_event(event) -> str:
    """Route an event to a handler keyed on its ``type`` field, returning a label.

    Unknown types fall through to an ``unknown:<type>`` label rather than raising.
    """
    handlers = {
        "text_delta": lambda e: f"text:{e.content}",
        "text_end": lambda e: f"end:{len(e.full_text)}chars",
        "tool_call": lambda e: f"call:{e.tool_name}",
        "tool_result": lambda e: f"result:{e.tool_use_id}",
        "reasoning_start": lambda _: "reasoning:start",
        "reasoning_delta": lambda e: f"reasoning:{e.content[:20]}",
        "finish": lambda e: f"finish:{e.stop_reason}",
        "error": lambda e: f"error:{e.error}",
    }
    try:
        handler = handlers[event.type]
    except KeyError:
        return f"unknown:{event.type}"
    return handler(event)
|
||||
|
||||
|
||||
def collect_text(events: list) -> str:
    """Return the final text carried by a stream of events.

    Despite the original docstring, nothing is accumulated: the most recent
    TextEndEvent's ``full_text`` wins; failing that, the most recent
    TextDeltaEvent's running ``snapshot``; ``""`` when the stream produced
    no text events at all.
    """
    # Scan newest-first and stop at the first text-bearing event.
    for event in reversed(events):
        if isinstance(event, TextEndEvent):
            return event.full_text
        if isinstance(event, TextDeltaEvent):
            return event.snapshot
    return ""
|
||||
|
||||
|
||||
def extract_tool_calls(events: list) -> list[dict[str, Any]]:
    """Collect {id, name, input} dicts for every ToolCallEvent in the stream."""
    calls: list[dict[str, Any]] = []
    for ev in events:
        if isinstance(ev, ToolCallEvent):
            calls.append(
                {"id": ev.tool_use_id, "name": ev.tool_name, "input": ev.tool_input}
            )
    return calls
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type-based dispatch tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTypeDispatch:
    """Dispatch on event.type string for handler routing."""

    def test_dispatch_text_delta(self):
        assert dispatch_event(TextDeltaEvent(content="hello")) == "text:hello"

    def test_dispatch_text_end(self):
        # Label carries the length of the final text.
        assert dispatch_event(TextEndEvent(full_text="hello world")) == "end:11chars"

    def test_dispatch_tool_call(self):
        assert dispatch_event(ToolCallEvent(tool_name="web_search")) == "call:web_search"

    def test_dispatch_tool_result(self):
        assert dispatch_event(ToolResultEvent(tool_use_id="abc")) == "result:abc"

    def test_dispatch_reasoning_start(self):
        assert dispatch_event(ReasoningStartEvent()) == "reasoning:start"

    def test_dispatch_reasoning_delta(self):
        # Reasoning labels are truncated to the first 20 characters.
        event = ReasoningDeltaEvent(content="Let me think step by step")
        assert dispatch_event(event) == "reasoning:Let me think step by"

    def test_dispatch_finish(self):
        assert dispatch_event(FinishEvent(stop_reason="end_turn")) == "finish:end_turn"

    def test_dispatch_error(self):
        assert dispatch_event(StreamErrorEvent(error="timeout")) == "error:timeout"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# isinstance-based filtering
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestInstanceFiltering:
    """Filter event streams using isinstance for each event type."""

    @pytest.fixture
    def text_stream(self) -> list:
        """Simulate a text-only stream."""
        return [
            TextDeltaEvent(content="Hello", snapshot="Hello"),
            TextDeltaEvent(content=" world", snapshot="Hello world"),
            TextDeltaEvent(content="!", snapshot="Hello world!"),
            TextEndEvent(full_text="Hello world!"),
            FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=3, model="test"),
        ]

    @pytest.fixture
    def tool_stream(self) -> list:
        """Simulate a tool call stream."""
        return [
            ToolCallEvent(
                tool_use_id="call_1",
                tool_name="get_weather",
                tool_input={"city": "London"},
            ),
            ToolCallEvent(
                tool_use_id="call_2",
                tool_name="calculator",
                tool_input={"expression": "2+2"},
            ),
            FinishEvent(stop_reason="tool_calls"),
        ]

    @pytest.fixture
    def reasoning_stream(self) -> list:
        """Simulate a stream with reasoning blocks."""
        return [
            ReasoningStartEvent(),
            ReasoningDeltaEvent(content="Let me analyze this..."),
            ReasoningDeltaEvent(content="The answer is 42."),
            TextDeltaEvent(content="The answer is 42.", snapshot="The answer is 42."),
            TextEndEvent(full_text="The answer is 42."),
            FinishEvent(stop_reason="end_turn"),
        ]

    def test_collect_text(self, text_stream):
        assert collect_text(text_stream) == "Hello world!"

    def test_collect_text_from_tool_stream(self, tool_stream):
        # No text events at all -> empty string.
        assert collect_text(tool_stream) == ""

    def test_extract_tool_calls(self, tool_stream):
        calls = extract_tool_calls(tool_stream)
        assert len(calls) == 2
        assert calls[0]["name"] == "get_weather"
        assert calls[1]["name"] == "calculator"

    def test_extract_tool_calls_from_text_stream(self, text_stream):
        assert extract_tool_calls(text_stream) == []

    def test_filter_text_deltas(self, text_stream):
        deltas = [e for e in text_stream if isinstance(e, TextDeltaEvent)]
        assert len(deltas) == 3

    def test_filter_finish(self, text_stream):
        finishes = [e for e in text_stream if isinstance(e, FinishEvent)]
        assert len(finishes) == 1
        assert finishes[0].stop_reason == "stop"

    def test_reasoning_then_text(self, reasoning_stream):
        reasoning = [e for e in reasoning_stream if isinstance(e, ReasoningDeltaEvent)]
        assert len(reasoning) == 2
        assert collect_text(reasoning_stream) == "The answer is 42."

    def test_mixed_stream_type_counts(self, reasoning_stream):
        counts = {}
        for ev in reasoning_stream:
            counts[ev.type] = counts.get(ev.type, 0) + 1
        assert counts == {
            "reasoning_start": 1,
            "reasoning_delta": 2,
            "text_delta": 1,
            "text_end": 1,
            "finish": 1,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom event extension pattern
|
||||
# ---------------------------------------------------------------------------
|
||||
@dataclass(frozen=True)
class CustomMetricsEvent:
    """Example custom event following the frozen-dataclass stream-event pattern."""

    # Discriminator used by type-based dispatch.
    type: Literal["custom_metrics"] = "custom_metrics"
    latency_ms: float = 0.0
    tokens_per_second: float = 0.0
    # Free-form extra data; default_factory avoids a shared mutable default.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CustomCitationEvent:
    """Example citation event extending the frozen-dataclass pattern."""

    # Discriminator used by type-based dispatch.
    type: Literal["citation"] = "citation"
    source_url: str = ""
    quote: str = ""
    confidence: float = 0.0
|
||||
|
||||
|
||||
class TestCustomEventExtension:
    """Custom events should follow the same frozen-dataclass convention."""

    def test_custom_event_construction(self):
        e = CustomMetricsEvent(latency_ms=150.5, tokens_per_second=42.3)
        assert e.type == "custom_metrics"
        assert e.latency_ms == 150.5

    def test_custom_event_frozen(self):
        e = CustomMetricsEvent()
        with pytest.raises(FrozenInstanceError):
            e.type = "modified"

    def test_custom_event_serialization(self):
        e = CustomMetricsEvent(
            latency_ms=100.0,
            tokens_per_second=50.0,
            metadata={"provider": "anthropic"},
        )
        d = asdict(e)
        assert d["type"] == "custom_metrics"
        assert d["metadata"] == {"provider": "anthropic"}

    def test_custom_event_dispatch(self):
        """Custom events can extend the dispatch map."""
        e = CustomMetricsEvent(latency_ms=200.0)
        # Falls through to "unknown" in our dispatch_event
        assert dispatch_event(e) == "unknown:custom_metrics"

    def test_custom_event_in_mixed_stream(self):
        """Custom events can coexist with standard events in a list."""
        stream = [
            TextDeltaEvent(content="hi", snapshot="hi"),
            CustomMetricsEvent(latency_ms=50.0),
            TextEndEvent(full_text="hi"),
            CustomCitationEvent(source_url="https://example.com", quote="hi"),
            FinishEvent(stop_reason="stop"),
        ]
        # Fix: the original duplicated this 8-element set literal verbatim in
        # both filters below; define it once so the two can never drift apart.
        standard_types = {
            "text_delta",
            "text_end",
            "tool_call",
            "tool_result",
            "reasoning_start",
            "reasoning_delta",
            "finish",
            "error",
        }
        standard = [
            e for e in stream if hasattr(e, "type") and e.type in standard_types
        ]
        custom = [e for e in stream if e.type not in standard_types]
        assert len(standard) == 3
        assert len(custom) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization of full event sequences
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSequenceSerialization:
    """Serialize entire event sequences, as done by the dump tests."""

    def test_serialize_text_sequence(self):
        events = [
            TextDeltaEvent(content="Hello", snapshot="Hello"),
            TextDeltaEvent(content=" world", snapshot="Hello world"),
            TextEndEvent(full_text="Hello world"),
            FinishEvent(stop_reason="stop", model="test-model"),
        ]
        # Each serialized event carries its stream position under "index".
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert len(serialized) == 4
        assert serialized[0]["index"] == 0
        assert serialized[0]["type"] == "text_delta"
        assert serialized[-1]["type"] == "finish"
        assert serialized[-1]["model"] == "test-model"

    def test_serialize_tool_sequence(self):
        events = [
            ToolCallEvent(
                tool_use_id="call_1",
                tool_name="search",
                tool_input={"query": "test"},
            ),
            FinishEvent(stop_reason="tool_calls"),
        ]
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert serialized[0]["tool_input"] == {"query": "test"}
        assert serialized[1]["stop_reason"] == "tool_calls"

    def test_serialize_error_sequence(self):
        events = [
            TextDeltaEvent(content="partial"),
            StreamErrorEvent(error="connection reset", recoverable=True),
            FinishEvent(stop_reason="error"),
        ]
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert serialized[1]["type"] == "error"
        assert serialized[1]["recoverable"] is True

    def test_roundtrip_snapshot_accumulation(self):
        """Verify snapshot grows monotonically through serialization."""
        chunks = ["Hello", " beautiful", " world", "!"]
        events = []
        snapshot = ""
        for chunk in chunks:
            snapshot += chunk
            events.append(TextDeltaEvent(content=chunk, snapshot=snapshot))

        serialized = [asdict(e) for e in events]
        # Each successive snapshot must strictly grow on the previous one.
        for i in range(1, len(serialized)):
            assert len(serialized[i]["snapshot"]) > len(serialized[i - 1]["snapshot"])
        assert serialized[-1]["snapshot"] == "Hello beautiful world!"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# WP-2: EventType Enum Extension + Node-Level Event Routing
|
||||
# ===========================================================================
|
||||
|
||||
# The 12 new EventType members added by WP-2, mapped to the exact string
# value each member must expose.
WP2_EVENT_TYPES = {
    # Node event-loop lifecycle
    EventType.NODE_LOOP_STARTED: "node_loop_started",
    EventType.NODE_LOOP_ITERATION: "node_loop_iteration",
    EventType.NODE_LOOP_COMPLETED: "node_loop_completed",
    # LLM streaming observability
    EventType.LLM_TEXT_DELTA: "llm_text_delta",
    EventType.LLM_REASONING_DELTA: "llm_reasoning_delta",
    # Tool lifecycle
    EventType.TOOL_CALL_STARTED: "tool_call_started",
    EventType.TOOL_CALL_COMPLETED: "tool_call_completed",
    # Client I/O
    EventType.CLIENT_OUTPUT_DELTA: "client_output_delta",
    EventType.CLIENT_INPUT_REQUESTED: "client_input_requested",
    # Internal node observability
    EventType.NODE_INTERNAL_OUTPUT: "node_internal_output",
    EventType.NODE_INPUT_BLOCKED: "node_input_blocked",
    EventType.NODE_STALLED: "node_stalled",
}


# Pre-existing enum members that must remain unchanged (backward-compat
# guard: WP-2 may only add members, never rename or repurpose these).
ORIGINAL_EVENT_TYPES = {
    EventType.EXECUTION_STARTED: "execution_started",
    EventType.EXECUTION_COMPLETED: "execution_completed",
    EventType.EXECUTION_FAILED: "execution_failed",
    EventType.EXECUTION_PAUSED: "execution_paused",
    EventType.EXECUTION_RESUMED: "execution_resumed",
    EventType.STATE_CHANGED: "state_changed",
    EventType.STATE_CONFLICT: "state_conflict",
    EventType.GOAL_PROGRESS: "goal_progress",
    EventType.GOAL_ACHIEVED: "goal_achieved",
    EventType.CONSTRAINT_VIOLATION: "constraint_violation",
    EventType.STREAM_STARTED: "stream_started",
    EventType.STREAM_STOPPED: "stream_stopped",
    EventType.CUSTOM: "custom",
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2 Part A: EventType enum members
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2EventTypeEnumMembers:
    """All 12 new EventType members exist with correct string values."""

    @pytest.mark.parametrize(
        "member,expected_value",
        WP2_EVENT_TYPES.items(),
        # Use the enum member's name as the test id for readable output.
        ids=lambda x: x.name if isinstance(x, EventType) else x,
    )
    def test_new_member_value(self, member, expected_value):
        assert member.value == expected_value

    def test_all_12_new_members_exist(self):
        assert len(WP2_EVENT_TYPES) == 12

    def test_new_member_string_values_are_unique(self):
        values = list(WP2_EVENT_TYPES.values())
        assert len(values) == len(set(values))

    def test_no_collision_with_original_members(self):
        new_values = set(WP2_EVENT_TYPES.values())
        old_values = set(ORIGINAL_EVENT_TYPES.values())
        overlap = new_values & old_values
        assert overlap == set(), f"Colliding values: {overlap}"

    @pytest.mark.parametrize(
        "member,expected_value",
        ORIGINAL_EVENT_TYPES.items(),
        ids=lambda x: x.name if isinstance(x, EventType) else x,
    )
    def test_original_members_unchanged(self, member, expected_value):
        assert member.value == expected_value

    def test_event_type_is_str_enum(self):
        """EventType members compare equal to their string values."""
        assert EventType.NODE_LOOP_STARTED == "node_loop_started"
        assert EventType.LLM_TEXT_DELTA == "llm_text_delta"
        assert EventType.LLM_TEXT_DELTA.value == "llm_text_delta"

    def test_event_type_accessible_by_name(self):
        assert EventType["NODE_LOOP_STARTED"] is EventType.NODE_LOOP_STARTED
        assert EventType["TOOL_CALL_COMPLETED"] is EventType.TOOL_CALL_COMPLETED

    def test_event_type_accessible_by_value(self):
        assert EventType("node_loop_started") is EventType.NODE_LOOP_STARTED
        assert EventType("tool_call_completed") is EventType.TOOL_CALL_COMPLETED
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2 Part B: AgentEvent.node_id and Subscription.filter_node
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2AgentEventNodeId:
    """AgentEvent supports node_id as a first-class field."""

    def test_node_id_defaults_to_none(self):
        # Constructed without node_id: the field must default to None.
        evt = AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream-1")
        assert evt.node_id is None

    def test_node_id_can_be_set(self):
        evt = AgentEvent(
            type=EventType.LLM_TEXT_DELTA,
            stream_id="stream-1",
            node_id="email_composer",
        )
        assert evt.node_id == "email_composer"

    def test_node_id_in_to_dict(self):
        evt = AgentEvent(
            type=EventType.TOOL_CALL_STARTED,
            stream_id="stream-1",
            node_id="search_node",
        )
        payload = evt.to_dict()
        assert payload["node_id"] == "search_node"

    def test_node_id_none_in_to_dict(self):
        # Even an unset node_id must still appear (as None) in the dict form.
        evt = AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream-1")
        payload = evt.to_dict()
        assert "node_id" in payload
        assert payload["node_id"] is None
|
||||
|
||||
|
||||
class TestWP2SubscriptionFilterNode:
    """Subscription supports filter_node for node-level routing."""

    @staticmethod
    async def _noop_handler(event: AgentEvent) -> None:
        # Handler that ignores its event; only the Subscription fields matter.
        pass

    def test_filter_node_defaults_to_none(self):
        sub = Subscription(
            id="sub_1",
            event_types={EventType.LLM_TEXT_DELTA},
            handler=self._noop_handler,
        )
        assert sub.filter_node is None

    def test_filter_node_can_be_set(self):
        sub = Subscription(
            id="sub_1",
            event_types={EventType.LLM_TEXT_DELTA},
            handler=self._noop_handler,
            filter_node="email_composer",
        )
        assert sub.filter_node == "email_composer"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2 Part B: Node-level event routing integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2NodeLevelRouting:
    """EventBus routes events by node_id using filter_node."""

    @pytest.fixture
    def bus(self):
        # Fresh bus per test so subscriptions never leak between tests.
        return EventBus()

    @pytest.mark.asyncio
    async def test_filter_node_receives_matching_events(self, bus):
        """Subscriber with filter_node='node-A' receives events from node-A."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
            filter_node="node-A",
        )

        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
                data={"content": "hello"},
            )
        )

        assert len(received) == 1
        assert received[0].node_id == "node-A"
        assert received[0].data["content"] == "hello"

    @pytest.mark.asyncio
    async def test_filter_node_rejects_non_matching_events(self, bus):
        """Subscriber with filter_node='node-B' does NOT receive node-A events."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
            filter_node="node-B",
        )

        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
                data={"content": "hello"},
            )
        )

        assert len(received) == 0

    @pytest.mark.asyncio
    async def test_no_filter_node_receives_all_events(self, bus):
        """Subscriber with no filter_node receives events from all nodes."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
        )

        # Events from two named nodes plus one with no node at all —
        # an unfiltered subscriber must see all three.
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-B",
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id=None,
            )
        )

        assert len(received) == 3

    @pytest.mark.asyncio
    async def test_interleaved_nodes_separated_by_filter(self, bus):
        """Two subscribers on different nodes get only their node's events."""
        node_a_events = []
        node_b_events = []

        async def handler_a(event):
            node_a_events.append(event)

        async def handler_b(event):
            node_b_events.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler_a,
            filter_node="email_sender",
        )
        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler_b,
            filter_node="inbox_scanner",
        )

        # Interleaved events from both nodes
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="email_sender",
                data={"content": "Dear Jo"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="inbox_scanner",
                data={"content": "RE: Meeting conf"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="email_sender",
                data={"content": "hn, Thank you for"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="inbox_scanner",
                data={"content": "irmed for Thursday"},
            )
        )

        # Each subscriber sees its own node's deltas, in publish order.
        assert len(node_a_events) == 2
        assert len(node_b_events) == 2
        assert node_a_events[0].data["content"] == "Dear Jo"
        assert node_a_events[1].data["content"] == "hn, Thank you for"
        assert node_b_events[0].data["content"] == "RE: Meeting conf"
        assert node_b_events[1].data["content"] == "irmed for Thursday"

    @pytest.mark.asyncio
    async def test_filter_node_combined_with_filter_stream(self, bus):
        """filter_node and filter_stream work together."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.TOOL_CALL_STARTED],
            handler=handler,
            filter_stream="webhook",
            filter_node="search_node",
        )

        # Matching both filters
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="webhook",
                node_id="search_node",
            )
        )
        # Wrong stream
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="api",
                node_id="search_node",
            )
        )
        # Wrong node
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="webhook",
                node_id="other_node",
            )
        )

        assert len(received) == 1
        assert received[0].stream_id == "webhook"
        assert received[0].node_id == "search_node"

    @pytest.mark.asyncio
    async def test_wait_for_with_node_id(self, bus):
        """wait_for() accepts node_id parameter for filtering."""

        async def publish_later():
            # Small sleep so wait_for() is already waiting when we publish.
            await asyncio.sleep(0.01)
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="target_node",
                    data={"iterations": 3},
                )
            )

        task = asyncio.create_task(publish_later())
        event = await bus.wait_for(
            event_type=EventType.NODE_LOOP_COMPLETED,
            node_id="target_node",
            timeout=2.0,
        )
        await task

        assert event is not None
        assert event.node_id == "target_node"
        assert event.data["iterations"] == 3

    @pytest.mark.asyncio
    async def test_wait_for_ignores_wrong_node(self, bus):
        """wait_for() with node_id ignores events from other nodes."""

        async def publish_wrong_then_right():
            await asyncio.sleep(0.01)
            # Wrong node — should be ignored
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="wrong_node",
                )
            )
            await asyncio.sleep(0.01)
            # Right node
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="target_node",
                    data={"iterations": 5},
                )
            )

        task = asyncio.create_task(publish_wrong_then_right())
        event = await bus.wait_for(
            event_type=EventType.NODE_LOOP_COMPLETED,
            node_id="target_node",
            timeout=2.0,
        )
        await task

        assert event is not None
        assert event.node_id == "target_node"
        assert event.data["iterations"] == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2: Convenience publisher methods
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2ConveniencePublishers:
    """EventBus convenience methods for new WP-2 event types."""

    @pytest.fixture
    def bus(self):
        # Fresh bus per test so subscriptions never leak between tests.
        return EventBus()

    @pytest.mark.asyncio
    async def test_emit_node_loop_started(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_STARTED], handler=handler)
        await bus.emit_node_loop_started(
            stream_id="s1",
            node_id="n1",
            max_iterations=10,
        )

        assert len(received) == 1
        assert received[0].node_id == "n1"
        assert received[0].data["max_iterations"] == 10

    @pytest.mark.asyncio
    async def test_emit_node_loop_iteration(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_ITERATION], handler=handler)
        await bus.emit_node_loop_iteration(
            stream_id="s1",
            node_id="n1",
            iteration=3,
        )

        assert len(received) == 1
        assert received[0].data["iteration"] == 3

    @pytest.mark.asyncio
    async def test_emit_node_loop_completed(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_COMPLETED], handler=handler)
        await bus.emit_node_loop_completed(
            stream_id="s1",
            node_id="n1",
            iterations=5,
        )

        assert len(received) == 1
        assert received[0].data["iterations"] == 5

    @pytest.mark.asyncio
    async def test_emit_llm_text_delta(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.LLM_TEXT_DELTA], handler=handler)
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="n1",
            content="hello",
            snapshot="hello world",
        )

        assert len(received) == 1
        assert received[0].data["content"] == "hello"
        assert received[0].data["snapshot"] == "hello world"

    @pytest.mark.asyncio
    async def test_emit_tool_call_started(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.TOOL_CALL_STARTED], handler=handler)
        await bus.emit_tool_call_started(
            stream_id="s1",
            node_id="n1",
            tool_use_id="call_1",
            tool_name="web_search",
            tool_input={"query": "test"},
        )

        assert len(received) == 1
        assert received[0].data["tool_name"] == "web_search"
        assert received[0].data["tool_input"] == {"query": "test"}

    @pytest.mark.asyncio
    async def test_emit_tool_call_completed(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.TOOL_CALL_COMPLETED], handler=handler)
        await bus.emit_tool_call_completed(
            stream_id="s1",
            node_id="n1",
            tool_use_id="call_1",
            tool_name="web_search",
            result="3 results found",
        )

        assert len(received) == 1
        assert received[0].data["result"] == "3 results found"
        # is_error was not passed, so the publisher's default must be False.
        assert received[0].data["is_error"] is False

    @pytest.mark.asyncio
    async def test_emit_client_output_delta(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.CLIENT_OUTPUT_DELTA], handler=handler)
        await bus.emit_client_output_delta(
            stream_id="s1",
            node_id="n1",
            content="chunk",
            snapshot="full chunk",
        )

        assert len(received) == 1
        assert received[0].data["content"] == "chunk"

    @pytest.mark.asyncio
    async def test_emit_node_stalled(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_STALLED], handler=handler)
        await bus.emit_node_stalled(
            stream_id="s1",
            node_id="n1",
            reason="no progress after 10 iterations",
        )

        assert len(received) == 1
        assert received[0].data["reason"] == "no progress after 10 iterations"

    @pytest.mark.asyncio
    async def test_convenience_publishers_set_node_id(self, bus):
        """All WP-2 convenience publishers set node_id on the emitted event."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA, EventType.TOOL_CALL_STARTED],
            handler=handler,
            filter_node="my_node",
        )

        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="my_node",
            content="hi",
            snapshot="hi",
        )
        await bus.emit_tool_call_started(
            stream_id="s1",
            node_id="my_node",
            tool_use_id="c1",
            tool_name="calc",
        )
        # Wrong node — should not be received
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="other_node",
            content="bye",
            snapshot="bye",
        )

        assert len(received) == 2
        assert all(e.node_id == "my_node" for e in received)
|
||||
@@ -0,0 +1,389 @@
|
||||
"""Real-API streaming tests for LiteLLM provider.
|
||||
|
||||
Calls live LLM APIs and dumps stream events to JSON files for review.
|
||||
Results are saved to core/tests/stream_event_dumps/{provider}_{model}_{scenario}.json
|
||||
|
||||
Run with:
|
||||
cd core && python -m pytest tests/test_litellm_streaming.py -v -s -k "RealAPI"
|
||||
|
||||
Requires API keys set in environment:
|
||||
ANTHROPIC_API_KEY, OPENAI_API_KEY, GEMINI_API_KEY (or via credential store)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
from framework.llm.provider import Tool
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DUMP_DIR = Path(__file__).parent / "stream_event_dumps"
|
||||
|
||||
|
||||
def _serialize_event(index: int, event: StreamEvent) -> dict:
|
||||
"""Serialize a StreamEvent to a JSON-safe dict."""
|
||||
d = asdict(event) # type: ignore[arg-type]
|
||||
d["index"] = index
|
||||
# Move index to front for readability
|
||||
return {"index": index, **{k: v for k, v in d.items() if k != "index"}}
|
||||
|
||||
|
||||
def _dump_events(events: list[StreamEvent], filename: str) -> Path:
    """Write stream events to a JSON file in the dump directory.

    Args:
        events: Events to serialize, in stream order.
        filename: Bare filename (no directory) inside DUMP_DIR.

    Returns:
        The path of the file that was written.
    """
    DUMP_DIR.mkdir(parents=True, exist_ok=True)
    filepath = DUMP_DIR / filename
    serialized = [_serialize_event(i, e) for i, e in enumerate(events)]
    filepath.write_text(json.dumps(serialized, indent=2) + "\n")
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("Dumped %d events to %s", len(events), filepath)
    return filepath
|
||||
|
||||
|
||||
async def _collect_stream(provider: LiteLLMProvider, **kwargs) -> list[StreamEvent]:
    """Collect all stream events from a provider.stream() call.

    Args:
        provider: The LLM provider whose stream() async generator to drain.
        **kwargs: Passed straight through to provider.stream().

    Returns:
        Every event yielded by the stream, in arrival order.
    """
    events: list[StreamEvent] = []
    async for event in provider.stream(**kwargs):
        events.append(event)
        # Lazy %-args: avoids formatting every event when DEBUG is off,
        # which matters inside a per-chunk streaming loop.
        logger.debug(" [%d] %s: %s", len(events) - 1, event.type, event)
    return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Test matrix: (model_id, dump_prefix, env_var_for_skip)
# ---------------------------------------------------------------------------
# model_id is passed to LiteLLMProvider; dump_prefix names the JSON dump
# files; env_var_for_skip gates each parametrized test on key availability.
MODELS = [
    (
        "anthropic/claude-haiku-4-5-20251001",
        "anthropic_claude-haiku-4-5-20251001",
        "ANTHROPIC_API_KEY",
    ),
    ("gpt-4.1-nano", "gpt-4.1-nano", "OPENAI_API_KEY"),
    ("gemini/gemini-2.0-flash", "gemini_gemini-2.0-flash", "GEMINI_API_KEY"),
]

# Three dummy tools (JSON-schema parameters) offered to the model so the
# streaming tests can provoke single and parallel tool calls.
WEATHER_TOOL = Tool(
    name="get_weather",
    description="Get the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {
            "city": {
                "type": "string",
                "description": "City name, e.g. 'Tokyo'",
            }
        },
        "required": ["city"],
    },
)

SEARCH_TOOL = Tool(
    name="web_search",
    description="Search the web for information.",
    parameters={
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query",
            },
            "num_results": {
                "type": "integer",
                "description": "Number of results to return (1-10)",
            },
        },
        "required": ["query"],
    },
)

CALCULATOR_TOOL = Tool(
    name="calculator",
    description="Perform arithmetic calculations.",
    parameters={
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "Math expression to evaluate, e.g. '2 + 2'",
            }
        },
        "required": ["expression"],
    },
)
|
||||
|
||||
|
||||
def _has_api_key(env_var: str) -> bool:
|
||||
"""Check if an API key is available (env var or credential store)."""
|
||||
if os.environ.get(env_var):
|
||||
return True
|
||||
# Try credential store
|
||||
try:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
creds = CredentialStoreAdapter.with_env_storage()
|
||||
provider_name = env_var.replace("_API_KEY", "").lower()
|
||||
return creds.is_available(provider_name)
|
||||
except (ImportError, Exception):
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real API tests — text streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRealAPITextStreaming:
    """Stream a simple text response from each provider and dump events."""

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_text_stream(self, model: str, prefix: str, env_var: str):
        """Stream a multi-paragraph response to exercise chunked delivery."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")

        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "Explain in 3 numbered paragraphs how a CPU executes an instruction. "
                        "Cover fetch, decode, and execute stages. Be concise but thorough."
                    ),
                }
            ],
            system="You are a computer science teacher. Give clear, structured explanations.",
            max_tokens=512,
        )

        # Dump to file
        _dump_events(events, f"{prefix}_text.json")

        # Basic structural assertions
        assert len(events) >= 4, f"Expected at least 4 events, got {len(events)}"

        # Must have multiple text deltas for a longer response
        text_deltas = [e for e in events if isinstance(e, TextDeltaEvent)]
        assert len(text_deltas) >= 3, f"Expected 3+ TextDeltaEvents, got {len(text_deltas)}"

        # Snapshot must accumulate monotonically
        for i in range(1, len(text_deltas)):
            assert len(text_deltas[i].snapshot) > len(text_deltas[i - 1].snapshot), (
                f"Snapshot did not grow at index {i}"
            )

        # Must end with TextEndEvent then FinishEvent
        text_ends = [e for e in events if isinstance(e, TextEndEvent)]
        assert len(text_ends) == 1, f"Expected 1 TextEndEvent, got {len(text_ends)}"

        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1, f"Expected 1 FinishEvent, got {len(finish_events)}"
        # Providers disagree on the terminal reason string — accept both.
        assert finish_events[0].stop_reason in ("stop", "end_turn")

        # TextEndEvent.full_text should match last snapshot
        assert text_ends[0].full_text == text_deltas[-1].snapshot

        # Response should actually contain multi-paragraph content
        full_text = text_ends[0].full_text
        assert len(full_text) > 200, f"Response too short ({len(full_text)} chars)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real API tests — tool call streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRealAPIToolCallStreaming:
    """Stream a tool call response from each provider and dump events."""

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_tool_call_stream(self, model: str, prefix: str, env_var: str):
        """Stream a single tool call with complex arguments."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")

        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": "Search the web for 'Python 3.13 release notes'.",
                }
            ],
            system="You have access to tools. Use the appropriate tool.",
            # Offer all three tools so the model must pick web_search itself.
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            max_tokens=512,
        )

        # Dump to file
        _dump_events(events, f"{prefix}_tool_call.json")

        # Basic structural assertions
        assert len(events) >= 2, f"Expected at least 2 events, got {len(events)}"

        # Must have a tool call event
        tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
        assert len(tool_calls) >= 1, "No ToolCallEvent received"

        tc = tool_calls[0]
        assert tc.tool_name == "web_search"
        assert "query" in tc.tool_input
        assert tc.tool_use_id != ""

        # Must end with FinishEvent
        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1
        # Providers disagree on the tool-call stop reason — accept all three.
        assert finish_events[0].stop_reason in ("tool_calls", "tool_use", "stop")

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_multi_tool_call_stream(self, model: str, prefix: str, env_var: str):
        """Stream a response that should invoke multiple tool calls."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")

        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "I need three things done in parallel: "
                        "1) Get the weather in London, "
                        "2) Get the weather in New York, "
                        "3) Calculate 1337 * 42. "
                        "Use the tools for all three."
                    ),
                }
            ],
            system=(
                "You have access to tools. When the user asks for multiple things, "
                "call all the needed tools. Always use tools, never guess results."
            ),
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            max_tokens=512,
        )

        # Dump to file
        _dump_events(events, f"{prefix}_multi_tool.json")

        # Must have multiple tool call events
        tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
        assert len(tool_calls) >= 2, (
            f"Expected 2+ ToolCallEvents for parallel requests, got {len(tool_calls)}"
        )

        # Verify tool names used
        tool_names = {tc.tool_name for tc in tool_calls}
        assert "get_weather" in tool_names, "Expected get_weather tool call"

        # All tool calls should have non-empty IDs
        for tc in tool_calls:
            assert tc.tool_use_id != "", f"Empty tool_use_id on {tc.tool_name}"
            assert tc.tool_input, f"Empty tool_input on {tc.tool_name}"

        # Must end with FinishEvent
        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience runner for manual invocation
|
||||
# ---------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
    """Run all streaming tests and dump results. Usage: python tests/test_litellm_streaming.py"""

    ALL_TOOLS = [WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL]

    def _dump_and_report(events, filename):
        """Persist *events* to *filename* via _dump_events, then print a
        one-line summary followed by one line per event.

        Extracted because this dump/print sequence was repeated verbatim
        after each of the three stream collections below.
        """
        path = _dump_events(events, filename)
        print(f"  {len(events)} events -> {path}")
        for i, e in enumerate(events):
            print(f"    [{i}] {e.type}: {e}")

    async def _run_all():
        """Run text, tool-call, and multi-tool streams against every MODELS entry."""
        for model, prefix, env_var in MODELS:
            # Skip providers whose API key is not configured.
            if not _has_api_key(env_var):
                print(f"SKIP {prefix}: {env_var} not set")
                continue

            provider = LiteLLMProvider(model=model)

            # Text streaming (multi-paragraph)
            print(f"\n--- {prefix} text ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": (
                            "Explain in 3 numbered paragraphs how a CPU executes an instruction. "
                            "Cover fetch, decode, and execute stages. Be concise but thorough."
                        ),
                    }
                ],
                system="You are a computer science teacher. Give clear, structured explanations.",
                max_tokens=512,
            )
            _dump_and_report(events, f"{prefix}_text.json")

            # Tool call streaming
            print(f"\n--- {prefix} tool_call ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": "Search the web for 'Python 3.13 release notes'.",
                    }
                ],
                system="You have access to tools. Use the appropriate tool.",
                tools=ALL_TOOLS,
                max_tokens=512,
            )
            _dump_and_report(events, f"{prefix}_tool_call.json")

            # Multi-tool call streaming
            print(f"\n--- {prefix} multi_tool ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": (
                            "I need three things done in parallel: "
                            "1) Get the weather in London, "
                            "2) Get the weather in New York, "
                            "3) Calculate 1337 * 42. "
                            "Use the tools for all three."
                        ),
                    }
                ],
                system=(
                    "You have access to tools. When the user asks for multiple things, "
                    "call all the needed tools. Always use tools, never guess results."
                ),
                tools=ALL_TOOLS,
                max_tokens=512,
            )
            _dump_and_report(events, f"{prefix}_multi_tool.json")

    logging.basicConfig(level=logging.DEBUG)
    asyncio.run(_run_all())
|
||||
@@ -168,6 +168,68 @@ class TestNodeConversation:
|
||||
await conv.add_user_message("a" * 400)
|
||||
assert conv.estimate_tokens() == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_token_count_overrides_estimate(self):
|
||||
"""When actual API token count is provided, estimate_tokens uses it."""
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("a" * 400)
|
||||
assert conv.estimate_tokens() == 100 # chars/4 fallback
|
||||
|
||||
conv.update_token_count(500)
|
||||
assert conv.estimate_tokens() == 500 # actual API value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_resets_token_count(self):
|
||||
"""After compaction, actual token count is cleared (recalibrates on next LLM call)."""
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("a" * 400)
|
||||
conv.update_token_count(500)
|
||||
assert conv.estimate_tokens() == 500
|
||||
|
||||
await conv.compact("summary", keep_recent=0)
|
||||
# Falls back to chars/4 for the summary message
|
||||
assert conv.estimate_tokens() == len("summary") // 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_resets_token_count(self):
|
||||
"""clear() also resets the actual token count."""
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("hello")
|
||||
conv.update_token_count(1000)
|
||||
assert conv.estimate_tokens() == 1000
|
||||
|
||||
await conv.clear()
|
||||
assert conv.estimate_tokens() == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_ratio(self):
|
||||
"""usage_ratio returns estimate / max_history_tokens."""
|
||||
conv = NodeConversation(max_history_tokens=1000)
|
||||
await conv.add_user_message("a" * 400)
|
||||
assert conv.usage_ratio() == pytest.approx(0.1) # 100/1000
|
||||
|
||||
conv.update_token_count(800)
|
||||
assert conv.usage_ratio() == pytest.approx(0.8) # 800/1000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_ratio_zero_budget(self):
|
||||
"""usage_ratio returns 0 when max_history_tokens is 0 (unlimited)."""
|
||||
conv = NodeConversation(max_history_tokens=0)
|
||||
await conv.add_user_message("a" * 400)
|
||||
assert conv.usage_ratio() == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_needs_compaction_with_actual_tokens(self):
|
||||
"""needs_compaction uses actual API token count when available."""
|
||||
conv = NodeConversation(max_history_tokens=1000, compaction_threshold=0.8)
|
||||
await conv.add_user_message("a" * 100) # chars/4 = 25, well under 800
|
||||
|
||||
assert conv.needs_compaction() is False
|
||||
|
||||
# Simulate API reporting much higher actual token usage
|
||||
conv.update_token_count(850)
|
||||
assert conv.needs_compaction() is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_needs_compaction(self):
|
||||
conv = NodeConversation(max_history_tokens=100, compaction_threshold=0.8)
|
||||
@@ -558,3 +620,313 @@ class TestFileConversationStore:
|
||||
assert (base / "cursor.json").exists()
|
||||
assert (base / "parts" / "0000000000.json").exists()
|
||||
assert (base / "parts" / "0000000001.json").exists()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Integration tests — real FileConversationStore, no mocks
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestConversationIntegration:
    """End-to-end tests using real FileConversationStore on disk.

    Every test creates a fresh directory, writes real JSON files,
    and restores from a *new* store instance (simulating process restart).
    """

    @pytest.mark.asyncio
    async def test_multi_turn_agent_conversation(self, tmp_path):
        """Simulate a realistic agent conversation with multiple turns,
        tool calls, and tool results — then restore from disk."""
        base = tmp_path / "agent_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(
            system_prompt="You are a helpful travel agent.",
            max_history_tokens=16000,
            store=store,
        )

        # Turn 1: user asks, assistant responds with tool call
        await conv.add_user_message("Find me flights from NYC to London next Friday.")
        await conv.add_assistant_message(
            "Let me search for flights.",
            tool_calls=[
                {
                    "id": "call_flight_1",
                    "type": "function",
                    "function": {
                        "name": "search_flights",
                        "arguments": '{"origin":"JFK","destination":"LHR","date":"2025-06-13"}',
                    },
                }
            ],
        )
        await conv.add_tool_result(
            "call_flight_1",
            '{"flights":[{"airline":"BA","price":450,"departure":"08:00"},{"airline":"AA","price":520,"departure":"14:30"}]}',
        )

        # Turn 2: assistant presents results, user picks one
        await conv.add_assistant_message(
            "I found 2 flights:\n"
            "1. British Airways at $450, departing 08:00\n"
            "2. American Airlines at $520, departing 14:30\n"
            "Which one would you like?"
        )
        await conv.add_user_message("Book the British Airways one.")
        await conv.add_assistant_message(
            "Booking the BA flight now.",
            tool_calls=[
                {
                    "id": "call_book_1",
                    "type": "function",
                    "function": {
                        "name": "book_flight",
                        "arguments": '{"flight_id":"BA-JFK-LHR-0800","passenger":"user"}',
                    },
                }
            ],
        )
        await conv.add_tool_result(
            "call_book_1",
            '{"confirmation":"BA-12345","status":"confirmed"}',
        )
        await conv.add_assistant_message("Your flight is booked! Confirmation: BA-12345.")

        # Verify in-memory state: 2 user turns, 8 messages total
        # (2 user + 4 assistant + 2 tool results), seqs 0-7 consumed.
        assert conv.turn_count == 2
        assert conv.message_count == 8
        assert conv.next_seq == 8

        # --- Simulate process restart: new store, same path ---
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)

        assert restored is not None
        assert restored.system_prompt == "You are a helpful travel agent."
        assert restored.turn_count == 2
        assert restored.message_count == 8
        assert restored.next_seq == 8

        # Verify message content integrity
        msgs = restored.messages
        assert msgs[0].role == "user"
        assert "NYC to London" in msgs[0].content
        assert msgs[1].role == "assistant"
        assert msgs[1].tool_calls[0]["id"] == "call_flight_1"
        assert msgs[2].role == "tool"
        assert msgs[2].tool_use_id == "call_flight_1"
        assert "BA" in msgs[2].content
        assert msgs[7].content == "Your flight is booked! Confirmation: BA-12345."

        # Verify LLM-format output
        llm_msgs = restored.to_llm_messages()
        assert llm_msgs[0] == {"role": "user", "content": msgs[0].content}
        assert llm_msgs[2]["role"] == "tool"
        assert llm_msgs[2]["tool_call_id"] == "call_flight_1"

    @pytest.mark.asyncio
    async def test_compaction_and_restore_preserves_continuity(self, tmp_path):
        """Build up a long conversation, compact it, continue adding
        messages, then restore — verifying seq continuity and content."""
        base = tmp_path / "compact_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(
            system_prompt="research assistant",
            store=store,
        )

        # Build 10 messages (5 turns)
        for i in range(5):
            await conv.add_user_message(f"question {i}")
            await conv.add_assistant_message(f"answer {i}")

        assert conv.message_count == 10
        assert conv.next_seq == 10

        # Compact: keep last 2 messages (question 4, answer 4)
        await conv.compact("Summary of questions 0-3 and their answers.", keep_recent=2)

        assert conv.message_count == 3  # summary + 2 recent
        assert conv.messages[0].content == "Summary of questions 0-3 and their answers."
        assert conv.messages[1].content == "question 4"
        assert conv.messages[2].content == "answer 4"

        # Continue the conversation post-compaction
        await conv.add_user_message("question 5")
        await conv.add_assistant_message("answer 5")
        assert conv.next_seq == 12

        # Verify disk: old part files (seq 0-7) should be deleted
        parts_dir = base / "parts"
        part_files = sorted(parts_dir.glob("*.json"))
        part_seqs = [int(f.stem) for f in part_files]
        # Should have: summary (seq 7), question 4 (seq 8), answer 4 (seq 9),
        # question 5 (seq 10), answer 5 (seq 11)
        assert all(s >= 7 for s in part_seqs), f"Stale parts found: {part_seqs}"

        # Restore from fresh store
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)

        assert restored is not None
        assert restored.next_seq == 12
        assert restored.message_count == 5
        assert "Summary of questions 0-3" in restored.messages[0].content
        assert restored.messages[-1].content == "answer 5"

        # Verify seq monotonicity across all restored messages
        seqs = [m.seq for m in restored.messages]
        assert seqs == sorted(seqs), f"Seqs not monotonic: {seqs}"

    @pytest.mark.asyncio
    async def test_output_key_preservation_through_compact_and_restore(self, tmp_path):
        """Output keys in compacted messages survive disk persistence."""
        base = tmp_path / "output_key_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(
            system_prompt="classifier",
            output_keys=["classification", "confidence"],
            store=store,
        )

        # Three classification turns; the first two will be discarded by compact.
        await conv.add_user_message("Classify this email: 'You won a prize!'")
        await conv.add_assistant_message('{"classification": "spam", "confidence": "0.97"}')
        await conv.add_user_message("What about: 'Meeting at 3pm'")
        await conv.add_assistant_message('{"classification": "ham", "confidence": "0.99"}')
        await conv.add_user_message("And: 'Buy cheap meds now'")
        await conv.add_assistant_message('{"classification": "spam", "confidence": "0.95"}')

        # Compact keeping only the last 2 messages
        await conv.compact("Classified 3 emails.", keep_recent=2)

        # The summary should contain preserved output keys from discarded messages
        summary_content = conv.messages[0].content
        assert "PRESERVED VALUES" in summary_content
        # Most recent values from discarded messages (msgs 0-3) are "ham"/"0.99"
        assert "ham" in summary_content or "spam" in summary_content

        # Restore and verify the preserved values survived
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)
        assert restored is not None
        assert "PRESERVED VALUES" in restored.messages[0].content

    @pytest.mark.asyncio
    async def test_tool_error_roundtrip(self, tmp_path):
        """Tool errors persist and restore with ERROR: prefix in LLM output."""
        base = tmp_path / "error_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(store=store)

        await conv.add_user_message("Calculate 1/0")
        await conv.add_assistant_message(
            "Let me calculate that.",
            tool_calls=[
                {
                    "id": "call_calc",
                    "type": "function",
                    "function": {"name": "calculator", "arguments": '{"expr":"1/0"}'},
                }
            ],
        )
        await conv.add_tool_result(
            "call_calc", "ZeroDivisionError: division by zero", is_error=True
        )
        await conv.add_assistant_message("The calculation failed: division by zero is undefined.")

        # Restore
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)
        assert restored is not None

        # The error flag survives the round trip...
        tool_msg = restored.messages[2]
        assert tool_msg.role == "tool"
        assert tool_msg.is_error is True
        assert tool_msg.tool_use_id == "call_calc"

        # ...and surfaces as an "ERROR: " prefix in the LLM-facing dict.
        llm_dict = tool_msg.to_llm_dict()
        assert llm_dict["content"].startswith("ERROR: ")
        assert "ZeroDivisionError" in llm_dict["content"]
        assert llm_dict["tool_call_id"] == "call_calc"

    @pytest.mark.asyncio
    async def test_concurrent_conversations_isolated(self, tmp_path):
        """Two conversations in separate directories don't interfere."""
        store_a = FileConversationStore(tmp_path / "conv_a")
        store_b = FileConversationStore(tmp_path / "conv_b")

        conv_a = NodeConversation(system_prompt="Agent A", store=store_a)
        conv_b = NodeConversation(system_prompt="Agent B", store=store_b)

        # Interleave writes to both conversations to stress isolation.
        await conv_a.add_user_message("Hello from A")
        await conv_b.add_user_message("Hello from B")
        await conv_a.add_assistant_message("Response A")
        await conv_b.add_assistant_message("Response B")
        await conv_b.add_user_message("Follow-up B")

        # Restore independently
        restored_a = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_a"))
        restored_b = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_b"))

        assert restored_a.system_prompt == "Agent A"
        assert restored_b.system_prompt == "Agent B"
        assert restored_a.message_count == 2
        assert restored_b.message_count == 3
        assert restored_a.messages[0].content == "Hello from A"
        assert restored_b.messages[2].content == "Follow-up B"

    @pytest.mark.asyncio
    async def test_destroy_removes_all_files(self, tmp_path):
        """destroy() wipes the entire conversation directory."""
        base = tmp_path / "doomed_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(system_prompt="temp", store=store)
        await conv.add_user_message("ephemeral")
        await conv.add_assistant_message("gone soon")

        # Writing messages creates the on-disk layout (meta.json + parts/).
        assert base.exists()
        assert (base / "meta.json").exists()
        assert (base / "parts").exists()

        await store.destroy()

        assert not base.exists()

    @pytest.mark.asyncio
    async def test_restore_empty_store_returns_none(self, tmp_path):
        """Restoring from a path that was never written to returns None."""
        store = FileConversationStore(tmp_path / "empty")
        result = await NodeConversation.restore(store)
        assert result is None

    @pytest.mark.asyncio
    async def test_clear_then_continue_then_restore(self, tmp_path):
        """clear() removes messages but preserves seq counter for new messages."""
        base = tmp_path / "clear_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(system_prompt="s", store=store)

        await conv.add_user_message("old msg 0")
        await conv.add_assistant_message("old msg 1")
        assert conv.next_seq == 2

        await conv.clear()
        assert conv.message_count == 0
        assert conv.next_seq == 2  # seq counter preserved

        # Continue with new messages — seqs should start at 2
        await conv.add_user_message("new msg")
        await conv.add_assistant_message("new response")
        assert conv.next_seq == 4
        assert conv.messages[0].seq == 2
        assert conv.messages[1].seq == 3

        # Restore
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)
        assert restored is not None
        assert restored.message_count == 2
        assert restored.next_seq == 4
        assert restored.messages[0].content == "new msg"
        assert restored.messages[0].seq == 2
|
||||
|
||||
@@ -0,0 +1,318 @@
|
||||
"""Tests for stream event dataclasses.
|
||||
|
||||
Validates construction, defaults, immutability, serialization, and the
|
||||
StreamEvent discriminated union type.
|
||||
"""
|
||||
|
||||
from dataclasses import FrozenInstanceError, asdict, fields
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
ReasoningDeltaEvent,
|
||||
ReasoningStartEvent,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
|
||||
# All concrete event classes in the union; used to parametrize the
# structural tests below (defaults, immutability, serialization, union).
ALL_EVENT_CLASSES = [
    TextDeltaEvent,
    TextEndEvent,
    ToolCallEvent,
    ToolResultEvent,
    ReasoningStartEvent,
    ReasoningDeltaEvent,
    FinishEvent,
    StreamErrorEvent,
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Construction & defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventDefaults:
    """Every event class must construct cleanly with zero arguments."""

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_default_construction(self, cls):
        # A zero-argument instance always carries a non-empty type tag.
        assert cls().type != ""

    def test_text_delta_defaults(self):
        ev = TextDeltaEvent()
        assert ev.type == "text_delta"
        assert ev.content == ""
        assert ev.snapshot == ""

    def test_text_end_defaults(self):
        ev = TextEndEvent()
        assert ev.type == "text_end"
        assert ev.full_text == ""

    def test_tool_call_defaults(self):
        ev = ToolCallEvent()
        assert ev.type == "tool_call"
        assert ev.tool_use_id == ""
        assert ev.tool_name == ""
        assert ev.tool_input == {}

    def test_tool_result_defaults(self):
        ev = ToolResultEvent()
        assert ev.type == "tool_result"
        assert ev.tool_use_id == ""
        assert ev.content == ""
        assert ev.is_error is False

    def test_reasoning_start_defaults(self):
        assert ReasoningStartEvent().type == "reasoning_start"

    def test_reasoning_delta_defaults(self):
        ev = ReasoningDeltaEvent()
        assert ev.type == "reasoning_delta"
        assert ev.content == ""

    def test_finish_defaults(self):
        ev = FinishEvent()
        assert ev.type == "finish"
        assert ev.stop_reason == ""
        assert ev.input_tokens == 0
        assert ev.output_tokens == 0
        assert ev.model == ""

    def test_stream_error_defaults(self):
        ev = StreamErrorEvent()
        assert ev.type == "error"
        assert ev.error == ""
        assert ev.recoverable is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Construction with values
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventConstruction:
    """Field values supplied at construction are stored unchanged."""

    def test_text_delta_with_values(self):
        ev = TextDeltaEvent(content="hello", snapshot="hello world")
        assert ev.content == "hello"
        assert ev.snapshot == "hello world"

    def test_text_end_with_values(self):
        ev = TextEndEvent(full_text="the complete response")
        assert ev.full_text == "the complete response"

    def test_tool_call_with_values(self):
        payload = {"query": "python", "num_results": 5}
        ev = ToolCallEvent(tool_use_id="call_abc123", tool_name="web_search", tool_input=payload)
        assert ev.tool_use_id == "call_abc123"
        assert ev.tool_name == "web_search"
        assert ev.tool_input == {"query": "python", "num_results": 5}

    def test_tool_result_with_values(self):
        ev = ToolResultEvent(tool_use_id="call_abc123", content="search results here", is_error=False)
        assert ev.tool_use_id == "call_abc123"
        assert ev.content == "search results here"
        assert ev.is_error is False

    def test_tool_result_error(self):
        # A failed tool invocation is flagged, not raised.
        ev = ToolResultEvent(tool_use_id="call_fail", content="timeout", is_error=True)
        assert ev.is_error is True

    def test_reasoning_delta_with_content(self):
        ev = ReasoningDeltaEvent(content="Let me think about this...")
        assert ev.content == "Let me think about this..."

    def test_finish_with_values(self):
        ev = FinishEvent(
            stop_reason="end_turn",
            input_tokens=150,
            output_tokens=300,
            model="claude-haiku-4-5",
        )
        assert ev.stop_reason == "end_turn"
        assert ev.input_tokens == 150
        assert ev.output_tokens == 300
        assert ev.model == "claude-haiku-4-5"

    def test_stream_error_with_values(self):
        ev = StreamErrorEvent(error="rate limit exceeded", recoverable=True)
        assert ev.error == "rate limit exceeded"
        assert ev.recoverable is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frozen immutability
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventImmutability:
    """Events are frozen dataclasses: attribute reassignment must raise."""

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_frozen(self, cls):
        ev = cls()
        with pytest.raises(FrozenInstanceError):
            ev.type = "modified"

    def test_text_delta_frozen_content(self):
        ev = TextDeltaEvent(content="hello")
        with pytest.raises(FrozenInstanceError):
            ev.content = "modified"

    def test_tool_call_frozen_input(self):
        # Even a mutable-typed field cannot be rebound on a frozen instance.
        ev = ToolCallEvent(tool_input={"key": "value"})
        with pytest.raises(FrozenInstanceError):
            ev.tool_input = {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type literal values
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTypeLiterals:
    """The `type` discriminator of each event matches its Literal annotation."""

    # Expected discriminator value per concrete event class.
    EXPECTED_TYPES = {
        TextDeltaEvent: "text_delta",
        TextEndEvent: "text_end",
        ToolCallEvent: "tool_call",
        ToolResultEvent: "tool_result",
        ReasoningStartEvent: "reasoning_start",
        ReasoningDeltaEvent: "reasoning_delta",
        FinishEvent: "finish",
        StreamErrorEvent: "error",
    }

    @pytest.mark.parametrize(
        "cls,expected_type",
        EXPECTED_TYPES.items(),
        # Classes get their name as the test id; the string params are used as-is.
        ids=lambda x: x.__name__ if isinstance(x, type) else x,
    )
    def test_type_value(self, cls, expected_type):
        assert cls().type == expected_type

    def test_all_types_unique(self):
        types = [klass().type for klass in ALL_EVENT_CLASSES]
        # Discriminators must be collision-free across the whole union.
        assert len(set(types)) == len(types), f"Duplicate type values: {types}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization via dataclasses.asdict
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventSerialization:
    """asdict() turns events into plain dicts ready for JSON encoding."""

    def test_text_delta_asdict(self):
        dumped = asdict(TextDeltaEvent(content="chunk", snapshot="full chunk"))
        assert dumped == {"type": "text_delta", "content": "chunk", "snapshot": "full chunk"}

    def test_tool_call_asdict(self):
        ev = ToolCallEvent(tool_use_id="id_1", tool_name="calc", tool_input={"expression": "2+2"})
        dumped = asdict(ev)
        assert dumped["tool_name"] == "calc"
        assert dumped["tool_input"] == {"expression": "2+2"}

    def test_finish_asdict(self):
        ev = FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=20, model="gpt-4")
        expected = {
            "type": "finish",
            "stop_reason": "stop",
            "input_tokens": 10,
            "output_tokens": 20,
            "model": "gpt-4",
        }
        assert asdict(ev) == expected

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_asdict_contains_type(self, cls):
        # The discriminator must survive serialization for every class.
        assert "type" in asdict(cls())

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_asdict_keys_match_fields(self, cls):
        expected_keys = {f.name for f in fields(cls)}
        assert set(asdict(cls()).keys()) == expected_keys
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StreamEvent union type
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestStreamEventUnion:
    """StreamEvent (a PEP 604 union) covers exactly the concrete event classes."""

    def test_union_contains_all_classes(self):
        members = StreamEvent.__args__  # type: ignore[attr-defined]
        for cls in ALL_EVENT_CLASSES:
            assert cls in members, f"{cls.__name__} not in StreamEvent union"

    def test_union_has_exactly_expected_members(self):
        # Neither missing nor extra members are allowed.
        members = set(StreamEvent.__args__)  # type: ignore[attr-defined]
        assert members == set(ALL_EVENT_CLASSES)

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_isinstance_check(self, cls):
        """Basic sanity: an instance is an instance of its own class."""
        assert isinstance(cls(), cls)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Equality & hashing (frozen dataclasses support both)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventEquality:
    """Frozen dataclasses compare by value and hash consistently."""

    def test_equal_events(self):
        first = TextDeltaEvent(content="hi", snapshot="hi")
        second = TextDeltaEvent(content="hi", snapshot="hi")
        assert first == second

    def test_unequal_events(self):
        assert TextDeltaEvent(content="hi") != TextDeltaEvent(content="bye")

    def test_different_types_not_equal(self):
        # Same field values but different classes: never equal.
        assert TextDeltaEvent(content="hi") != ReasoningDeltaEvent(content="hi")

    def test_hashable(self):
        ev = FinishEvent(stop_reason="stop", model="gpt-4")
        # Frozen dataclasses without mutable fields are hashable.
        assert ev in {ev}

    def test_equal_events_same_hash(self):
        first = FinishEvent(stop_reason="stop", model="gpt-4")
        second = FinishEvent(stop_reason="stop", model="gpt-4")
        assert hash(first) == hash(second)

    def test_events_with_dict_not_hashable(self):
        """Events containing dict fields (e.g. tool_input) are not hashable."""
        ev = ToolCallEvent(tool_use_id="x", tool_name="y", tool_input={"key": "val"})
        with pytest.raises(TypeError, match="unhashable type"):
            hash(ev)
|
||||
Reference in New Issue
Block a user