Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9034d1dc71 | |||
| 537172d8ce | |||
| a86043a2ec | |||
| 3947da2cf1 | |||
| 17caab6563 | |||
| 94d31743b0 | |||
| 70db618c6e | |||
| c52ce6bb49 | |||
| bcddd4ce77 | |||
| 017872f71b | |||
| 7e670ce0a8 |
+4
-1
@@ -69,4 +69,7 @@ exports/*
|
||||
|
||||
.agent-builder-sessions/*
|
||||
|
||||
.venv
|
||||
.venv
|
||||
|
||||
docs/github-issues/*
|
||||
core/tests/*
|
||||
|
||||
@@ -0,0 +1,751 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
EventLoopNode WebSocket Demo
|
||||
|
||||
Real LLM, real FileConversationStore, real EventBus.
|
||||
Streams EventLoopNode execution to a browser via WebSocket.
|
||||
|
||||
Usage:
|
||||
cd /home/timothy/oss/hive/core
|
||||
python demos/event_loop_wss_demo.py
|
||||
|
||||
Then open http://localhost:8765 in your browser.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import tempfile
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from bs4 import BeautifulSoup
|
||||
from websockets.http11 import Request, Response
|
||||
|
||||
# Add core, tools, and hive root to path
|
||||
_CORE_DIR = Path(__file__).resolve().parent.parent
|
||||
_HIVE_DIR = _CORE_DIR.parent
|
||||
sys.path.insert(0, str(_CORE_DIR)) # framework.*
|
||||
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
|
||||
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
|
||||
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
|
||||
from core.framework.credentials import CredentialStore # noqa: E402
|
||||
|
||||
from framework.credentials.storage import ( # noqa: E402
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
from framework.graph.event_loop_node import EventLoopNode, JudgeVerdict, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.provider import Tool # noqa: E402
|
||||
from framework.runner.tool_registry import ToolRegistry # noqa: E402
|
||||
from framework.runtime.core import Runtime # noqa: E402
|
||||
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
|
||||
from framework.storage.conversation_store import FileConversationStore # noqa: E402
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
||||
logger = logging.getLogger("demo")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Persistent state (shared across WebSocket connections)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_demo_"))
|
||||
STORE = FileConversationStore(STORE_DIR / "conversation")
|
||||
RUNTIME = Runtime(STORE_DIR / "runtime")
|
||||
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tool Registry — real tools via ToolRegistry (same pattern as GraphExecutor)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
# Composite credential store: encrypted files (primary) + env vars (fallback)
|
||||
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
|
||||
_composite = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
|
||||
)
|
||||
CREDENTIALS = CredentialStoreAdapter(CredentialStore(storage=_composite))
|
||||
|
||||
# Debug: log which credentials resolved
|
||||
for _name in ["brave_search", "hubspot", "anthropic"]:
|
||||
_val = CREDENTIALS.get(_name)
|
||||
if _val:
|
||||
logger.debug("credential %s: OK (len=%d)", _name, len(_val))
|
||||
else:
|
||||
logger.debug("credential %s: not found", _name)
|
||||
|
||||
# --- web_search (Brave Search API) ---
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_search",
|
||||
tool=Tool(
|
||||
name="web_search",
|
||||
description=(
|
||||
"Search the web for current information. "
|
||||
"Returns titles, URLs, and snippets from search results."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query (1-500 characters)",
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (1-20, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_search(inputs),
|
||||
)
|
||||
|
||||
|
||||
def _exec_web_search(inputs: dict) -> dict:
|
||||
api_key = CREDENTIALS.get("brave_search")
|
||||
if not api_key:
|
||||
return {"error": "brave_search credential not configured"}
|
||||
query = inputs.get("query", "")
|
||||
num_results = min(inputs.get("num_results", 10), 20)
|
||||
resp = httpx.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
params={"q": query, "count": num_results},
|
||||
headers={"X-Subscription-Token": api_key, "Accept": "application/json"},
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"Brave API HTTP {resp.status_code}"}
|
||||
data = resp.json()
|
||||
results = [
|
||||
{
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("description", ""),
|
||||
}
|
||||
for item in data.get("web", {}).get("results", [])[:num_results]
|
||||
]
|
||||
return {"query": query, "results": results, "total": len(results)}
|
||||
|
||||
|
||||
# --- web_scrape (httpx + BeautifulSoup, no playwright for sync compat) ---
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_scrape",
|
||||
tool=Tool(
|
||||
name="web_scrape",
|
||||
description=(
|
||||
"Scrape and extract text content from a webpage URL. "
|
||||
"Returns the page title and main text content."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the webpage to scrape",
|
||||
},
|
||||
"max_length": {
|
||||
"type": "integer",
|
||||
"description": "Maximum text length (default 50000)",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_scrape(inputs),
|
||||
)
|
||||
|
||||
_SCRAPE_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "text/html,application/xhtml+xml",
|
||||
}
|
||||
|
||||
|
||||
def _exec_web_scrape(inputs: dict) -> dict:
|
||||
url = inputs.get("url", "")
|
||||
max_length = max(1000, min(inputs.get("max_length", 50000), 500000))
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
try:
|
||||
resp = httpx.get(url, timeout=30.0, follow_redirects=True, headers=_SCRAPE_HEADERS)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HTTP {resp.status_code}"}
|
||||
soup = BeautifulSoup(resp.text, "html.parser")
|
||||
for tag in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
|
||||
tag.decompose()
|
||||
title = soup.title.get_text(strip=True) if soup.title else ""
|
||||
main = (
|
||||
soup.find("article")
|
||||
or soup.find("main")
|
||||
or soup.find(attrs={"role": "main"})
|
||||
or soup.find("body")
|
||||
)
|
||||
text = main.get_text(separator=" ", strip=True) if main else ""
|
||||
text = " ".join(text.split())
|
||||
if len(text) > max_length:
|
||||
text = text[:max_length] + "..."
|
||||
return {"url": url, "title": title, "content": text, "length": len(text)}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"Scrape failed: {e}"}
|
||||
|
||||
|
||||
# --- HubSpot CRM tools (optional, requires HUBSPOT_ACCESS_TOKEN) ---
|
||||
|
||||
_HUBSPOT_API = "https://api.hubapi.com"
|
||||
|
||||
|
||||
def _hubspot_headers() -> dict | None:
|
||||
token = CREDENTIALS.get("hubspot")
|
||||
if token:
|
||||
logger.debug("HubSpot token: %s...%s (len=%d)", token[:8], token[-4:], len(token))
|
||||
else:
|
||||
logger.debug("HubSpot token: not found")
|
||||
if not token:
|
||||
return None
|
||||
return {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def _exec_hubspot_search(inputs: dict) -> dict:
|
||||
headers = _hubspot_headers()
|
||||
if not headers:
|
||||
return {"error": "HUBSPOT_ACCESS_TOKEN not set"}
|
||||
object_type = inputs.get("object_type", "contacts")
|
||||
query = inputs.get("query", "")
|
||||
limit = min(inputs.get("limit", 10), 100)
|
||||
body: dict = {"limit": limit}
|
||||
if query:
|
||||
body["query"] = query
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{_HUBSPOT_API}/crm/v3/objects/{object_type}/search",
|
||||
headers=headers,
|
||||
json=body,
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HubSpot API HTTP {resp.status_code}: {resp.text[:200]}"}
|
||||
return resp.json()
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"HubSpot error: {e}"}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="hubspot_search",
|
||||
tool=Tool(
|
||||
name="hubspot_search",
|
||||
description=(
|
||||
"Search HubSpot CRM objects (contacts, companies, or deals). "
|
||||
"Returns matching records with their properties."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"object_type": {
|
||||
"type": "string",
|
||||
"description": "CRM object type: 'contacts', 'companies', or 'deals'",
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query (name, email, domain, etc.)",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results (1-100, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["object_type"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_hubspot_search(inputs),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ToolRegistry loaded: %s",
|
||||
", ".join(TOOL_REGISTRY.get_registered_names()),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# ChatJudge — keeps the event loop alive between user messages
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ChatJudge:
|
||||
"""Judge that blocks between user messages, keeping the loop alive.
|
||||
|
||||
After the LLM finishes responding, the judge awaits a signal indicating
|
||||
a new user message has been injected, then returns RETRY to continue.
|
||||
"""
|
||||
|
||||
def __init__(self, on_ready=None):
|
||||
self._message_ready = asyncio.Event()
|
||||
self._shutdown = False
|
||||
self._on_ready = on_ready # async callback fired when waiting for input
|
||||
|
||||
async def evaluate(self, context: dict) -> JudgeVerdict:
|
||||
# Notify client that the LLM is done — ready for next input
|
||||
if self._on_ready:
|
||||
await self._on_ready()
|
||||
|
||||
# Block until next user message (or shutdown)
|
||||
self._message_ready.clear()
|
||||
await self._message_ready.wait()
|
||||
|
||||
if self._shutdown:
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
|
||||
return JudgeVerdict(action="RETRY")
|
||||
|
||||
def signal_message(self):
|
||||
"""Unblock the judge — a new user message has been injected."""
|
||||
self._message_ready.set()
|
||||
|
||||
def signal_shutdown(self):
|
||||
"""Unblock the judge and let the loop exit cleanly."""
|
||||
self._shutdown = True
|
||||
self._message_ready.set()
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTML page (embedded)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
HTML_PAGE = ( # noqa: E501
|
||||
"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>EventLoopNode Live Demo</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body {
|
||||
font-family: 'SF Mono', 'Fira Code', monospace;
|
||||
background: #0d1117; color: #c9d1d9;
|
||||
height: 100vh; display: flex; flex-direction: column;
|
||||
}
|
||||
header {
|
||||
background: #161b22; padding: 12px 20px;
|
||||
border-bottom: 1px solid #30363d;
|
||||
display: flex; align-items: center; gap: 16px;
|
||||
}
|
||||
header h1 { font-size: 16px; color: #58a6ff; font-weight: 600; }
|
||||
.status {
|
||||
font-size: 12px; padding: 3px 10px; border-radius: 12px;
|
||||
background: #21262d; color: #8b949e;
|
||||
}
|
||||
.status.running { background: #1a4b2e; color: #3fb950; }
|
||||
.status.done { background: #1a3a5c; color: #58a6ff; }
|
||||
.status.error { background: #4b1a1a; color: #f85149; }
|
||||
.chat { flex: 1; overflow-y: auto; padding: 16px; }
|
||||
.msg {
|
||||
margin: 8px 0; padding: 10px 14px; border-radius: 8px;
|
||||
line-height: 1.6; white-space: pre-wrap; word-wrap: break-word;
|
||||
}
|
||||
.msg.user { background: #1a3a5c; color: #58a6ff; }
|
||||
.msg.assistant { background: #161b22; color: #c9d1d9; }
|
||||
.msg.event {
|
||||
background: transparent; color: #8b949e; font-size: 11px;
|
||||
padding: 4px 14px; border-left: 3px solid #30363d;
|
||||
}
|
||||
.msg.event.loop { border-left-color: #58a6ff; }
|
||||
.msg.event.tool { border-left-color: #d29922; }
|
||||
.msg.event.stall { border-left-color: #f85149; }
|
||||
.input-bar {
|
||||
padding: 12px 16px; background: #161b22;
|
||||
border-top: 1px solid #30363d; display: flex; gap: 8px;
|
||||
}
|
||||
.input-bar input {
|
||||
flex: 1; background: #0d1117; border: 1px solid #30363d;
|
||||
color: #c9d1d9; padding: 8px 12px; border-radius: 6px;
|
||||
font-family: inherit; font-size: 14px; outline: none;
|
||||
}
|
||||
.input-bar input:focus { border-color: #58a6ff; }
|
||||
.input-bar button {
|
||||
background: #238636; color: #fff; border: none;
|
||||
padding: 8px 20px; border-radius: 6px; cursor: pointer;
|
||||
font-family: inherit; font-weight: 600;
|
||||
}
|
||||
.input-bar button:hover { background: #2ea043; }
|
||||
.input-bar button:disabled {
|
||||
background: #21262d; color: #484f58; cursor: not-allowed;
|
||||
}
|
||||
.input-bar button.clear { background: #da3633; }
|
||||
.input-bar button.clear:hover { background: #f85149; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>EventLoopNode Live</h1>
|
||||
<span id="status" class="status">Idle</span>
|
||||
<span id="iter" class="status" style="display:none">Step 0</span>
|
||||
</header>
|
||||
<div id="chat" class="chat"></div>
|
||||
<div class="input-bar">
|
||||
<input id="input" type="text"
|
||||
placeholder="Ask anything..." autofocus />
|
||||
<button id="go" onclick="run()">Send</button>
|
||||
<button class="clear"
|
||||
onclick="clearConversation()">Clear</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let currentAssistantEl = null;
|
||||
let iterCount = 0;
|
||||
const chat = document.getElementById('chat');
|
||||
const status = document.getElementById('status');
|
||||
const iterEl = document.getElementById('iter');
|
||||
const goBtn = document.getElementById('go');
|
||||
const inputEl = document.getElementById('input');
|
||||
|
||||
inputEl.addEventListener('keydown', e => {
|
||||
if (e.key === 'Enter') run();
|
||||
});
|
||||
|
||||
function setStatus(text, cls) {
|
||||
status.textContent = text;
|
||||
status.className = 'status ' + cls;
|
||||
}
|
||||
|
||||
function addMsg(text, cls) {
|
||||
const el = document.createElement('div');
|
||||
el.className = 'msg ' + cls;
|
||||
el.textContent = text;
|
||||
chat.appendChild(el);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
return el;
|
||||
}
|
||||
|
||||
function connect() {
|
||||
ws = new WebSocket('ws://' + location.host + '/ws');
|
||||
ws.onopen = () => {
|
||||
setStatus('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
};
|
||||
ws.onmessage = handleEvent;
|
||||
ws.onerror = () => { setStatus('Error', 'error'); };
|
||||
ws.onclose = () => {
|
||||
setStatus('Reconnecting...', '');
|
||||
goBtn.disabled = true;
|
||||
setTimeout(connect, 2000);
|
||||
};
|
||||
}
|
||||
|
||||
function handleEvent(msg) {
|
||||
const evt = JSON.parse(msg.data);
|
||||
|
||||
if (evt.type === 'llm_text_delta') {
|
||||
if (currentAssistantEl) {
|
||||
currentAssistantEl.textContent += evt.content;
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'ready') {
|
||||
setStatus('Ready', 'done');
|
||||
if (currentAssistantEl && !currentAssistantEl.textContent)
|
||||
currentAssistantEl.remove();
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_loop_iteration') {
|
||||
iterCount = evt.iteration || (iterCount + 1);
|
||||
iterEl.textContent = 'Step ' + iterCount;
|
||||
iterEl.style.display = '';
|
||||
}
|
||||
else if (evt.type === 'tool_call_started') {
|
||||
var info = evt.tool_name + '('
|
||||
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
|
||||
addMsg('TOOL ' + info, 'event tool');
|
||||
}
|
||||
else if (evt.type === 'tool_call_completed') {
|
||||
var preview = (evt.result || '').slice(0, 200);
|
||||
var cls = evt.is_error ? 'stall' : 'tool';
|
||||
addMsg('RESULT ' + evt.tool_name + ': ' + preview,
|
||||
'event ' + cls);
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
}
|
||||
else if (evt.type === 'result') {
|
||||
setStatus('Session ended', evt.success ? 'done' : 'error');
|
||||
if (evt.error) addMsg('ERROR ' + evt.error, 'event stall');
|
||||
if (currentAssistantEl && !currentAssistantEl.textContent)
|
||||
currentAssistantEl.remove();
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_stalled') {
|
||||
addMsg('STALLED ' + evt.reason, 'event stall');
|
||||
}
|
||||
else if (evt.type === 'cleared') {
|
||||
chat.innerHTML = '';
|
||||
iterCount = 0;
|
||||
iterEl.textContent = 'Step 0';
|
||||
iterEl.style.display = 'none';
|
||||
setStatus('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
function run() {
|
||||
const text = inputEl.value.trim();
|
||||
if (!text || !ws || ws.readyState !== 1) return;
|
||||
addMsg(text, 'user');
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
inputEl.value = '';
|
||||
setStatus('Running', 'running');
|
||||
goBtn.disabled = true;
|
||||
ws.send(JSON.stringify({ topic: text }));
|
||||
}
|
||||
|
||||
function clearConversation() {
|
||||
if (ws && ws.readyState === 1) {
|
||||
ws.send(JSON.stringify({ command: 'clear' }));
|
||||
}
|
||||
}
|
||||
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# WebSocket handler
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_ws(websocket):
|
||||
"""Persistent WebSocket: long-lived EventLoopNode kept alive by ChatJudge."""
|
||||
global STORE
|
||||
|
||||
# -- Event forwarding (WebSocket ← EventBus) ----------------------------
|
||||
bus = EventBus()
|
||||
|
||||
async def forward_event(event):
|
||||
try:
|
||||
payload = {"type": event.type.value, **event.data}
|
||||
if event.node_id:
|
||||
payload["node_id"] = event.node_id
|
||||
await websocket.send(json.dumps(payload))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[
|
||||
EventType.NODE_LOOP_STARTED,
|
||||
EventType.NODE_LOOP_ITERATION,
|
||||
EventType.NODE_LOOP_COMPLETED,
|
||||
EventType.LLM_TEXT_DELTA,
|
||||
EventType.TOOL_CALL_STARTED,
|
||||
EventType.TOOL_CALL_COMPLETED,
|
||||
EventType.NODE_STALLED,
|
||||
],
|
||||
handler=forward_event,
|
||||
)
|
||||
|
||||
# -- Ready callback (tells browser the LLM is done, waiting for input) --
|
||||
async def send_ready():
|
||||
try:
|
||||
await websocket.send(json.dumps({"type": "ready"}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -- Per-connection state -----------------------------------------------
|
||||
judge = ChatJudge(on_ready=send_ready)
|
||||
node = None
|
||||
loop_task = None
|
||||
|
||||
tools = list(TOOL_REGISTRY.get_tools().values())
|
||||
tool_executor = TOOL_REGISTRY.get_executor()
|
||||
|
||||
node_spec = NodeSpec(
|
||||
id="assistant",
|
||||
name="Chat Assistant",
|
||||
description="A conversational assistant that remembers context across messages",
|
||||
node_type="event_loop",
|
||||
system_prompt=(
|
||||
"You are a helpful assistant with access to tools. "
|
||||
"You can search the web, scrape webpages, and query HubSpot CRM. "
|
||||
"Use tools when the user asks for current information or external data. "
|
||||
"You have full conversation history, so you can reference previous messages."
|
||||
),
|
||||
)
|
||||
|
||||
async def start_loop(first_message: str):
|
||||
"""Create an EventLoopNode and run it as a background task."""
|
||||
nonlocal node, loop_task
|
||||
|
||||
memory = SharedMemory()
|
||||
ctx = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="assistant",
|
||||
node_spec=node_spec,
|
||||
memory=memory,
|
||||
input_data={},
|
||||
llm=LLM,
|
||||
available_tools=tools,
|
||||
)
|
||||
node = EventLoopNode(
|
||||
event_bus=bus,
|
||||
judge=judge,
|
||||
config=LoopConfig(max_iterations=10_000, max_history_tokens=32_000),
|
||||
conversation_store=STORE,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
await node.inject_event(first_message)
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
result = await node.execute(ctx)
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "result",
|
||||
"success": result.success,
|
||||
"output": result.output,
|
||||
"error": result.error,
|
||||
"tokens": result.tokens_used,
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"Loop ended: success={result.success}, tokens={result.tokens_used}")
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("Loop stopped: WebSocket closed")
|
||||
except Exception as e:
|
||||
logger.exception("Loop error")
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "result",
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"output": {},
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
loop_task = asyncio.create_task(_run())
|
||||
|
||||
async def stop_loop():
|
||||
"""Signal the judge and wait for the loop task to finish."""
|
||||
nonlocal node, loop_task
|
||||
if loop_task and not loop_task.done():
|
||||
judge.signal_shutdown()
|
||||
try:
|
||||
await asyncio.wait_for(loop_task, timeout=5.0)
|
||||
except (TimeoutError, asyncio.CancelledError):
|
||||
loop_task.cancel()
|
||||
node = None
|
||||
loop_task = None
|
||||
|
||||
# -- Message loop (runs for the lifetime of this WebSocket) -------------
|
||||
try:
|
||||
async for raw in websocket:
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Clear command
|
||||
if msg.get("command") == "clear":
|
||||
import shutil
|
||||
|
||||
await stop_loop()
|
||||
await STORE.close()
|
||||
conv_dir = STORE_DIR / "conversation"
|
||||
if conv_dir.exists():
|
||||
shutil.rmtree(conv_dir)
|
||||
STORE = FileConversationStore(conv_dir)
|
||||
# Reset judge for next session
|
||||
judge = ChatJudge(on_ready=send_ready)
|
||||
await websocket.send(json.dumps({"type": "cleared"}))
|
||||
logger.info("Conversation cleared")
|
||||
continue
|
||||
|
||||
topic = msg.get("topic", "")
|
||||
if not topic:
|
||||
continue
|
||||
|
||||
if node is None:
|
||||
# First message — spin up the loop
|
||||
logger.info(f"Starting persistent loop: {topic}")
|
||||
await start_loop(topic)
|
||||
else:
|
||||
# Subsequent message — inject and unblock the judge
|
||||
logger.info(f"Injecting message: {topic}")
|
||||
await node.inject_event(topic)
|
||||
judge.signal_message()
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
await stop_loop()
|
||||
logger.info("WebSocket closed, loop stopped")
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP handler for serving the HTML page
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_request(connection, request: Request):
|
||||
"""Serve HTML on GET /, upgrade to WebSocket on /ws."""
|
||||
if request.path == "/ws":
|
||||
return None # let websockets handle the upgrade
|
||||
# Serve the HTML page for any other path
|
||||
return Response(
|
||||
HTTPStatus.OK,
|
||||
"OK",
|
||||
websockets.Headers({"Content-Type": "text/html; charset=utf-8"}),
|
||||
HTML_PAGE.encode(),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Main
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
|
||||
port = 8765
|
||||
async with websockets.serve(
|
||||
handle_ws,
|
||||
"0.0.0.0",
|
||||
port,
|
||||
process_request=process_request,
|
||||
):
|
||||
logger.info(f"Demo running at http://localhost:{port}")
|
||||
logger.info("Open in your browser and enter a topic to research.")
|
||||
await asyncio.Future() # run forever
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,930 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Two-Node ContextHandoff Demo
|
||||
|
||||
Demonstrates ContextHandoff between two EventLoopNode instances:
|
||||
Node A (Researcher) → ContextHandoff → Node B (Analyst)
|
||||
|
||||
Real LLM, real FileConversationStore, real EventBus.
|
||||
Streams both nodes to a browser via WebSocket.
|
||||
|
||||
Usage:
|
||||
cd /home/timothy/oss/hive/core
|
||||
python demos/handoff_demo.py
|
||||
|
||||
Then open http://localhost:8766 in your browser.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import tempfile
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from bs4 import BeautifulSoup
|
||||
from websockets.http11 import Request, Response
|
||||
|
||||
# Add core, tools, and hive root to path
|
||||
_CORE_DIR = Path(__file__).resolve().parent.parent
|
||||
_HIVE_DIR = _CORE_DIR.parent
|
||||
sys.path.insert(0, str(_CORE_DIR)) # framework.*
|
||||
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
|
||||
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
|
||||
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
|
||||
from core.framework.credentials import CredentialStore # noqa: E402
|
||||
|
||||
from framework.credentials.storage import ( # noqa: E402
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
from framework.graph.context_handoff import ContextHandoff # noqa: E402
|
||||
from framework.graph.conversation import NodeConversation # noqa: E402
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.provider import Tool # noqa: E402
|
||||
from framework.runner.tool_registry import ToolRegistry # noqa: E402
|
||||
from framework.runtime.core import Runtime # noqa: E402
|
||||
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
|
||||
from framework.storage.conversation_store import FileConversationStore # noqa: E402
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
||||
logger = logging.getLogger("handoff_demo")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Persistent state
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_handoff_"))
|
||||
RUNTIME = Runtime(STORE_DIR / "runtime")
|
||||
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Credentials
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Composite credential store: encrypted files (primary) + env vars (fallback)
|
||||
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
|
||||
_composite = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
|
||||
)
|
||||
CREDENTIALS = CredentialStoreAdapter(CredentialStore(storage=_composite))
|
||||
|
||||
for _name in ["brave_search", "hubspot"]:
|
||||
_val = CREDENTIALS.get(_name)
|
||||
if _val:
|
||||
logger.debug("credential %s: OK (len=%d)", _name, len(_val))
|
||||
else:
|
||||
logger.debug("credential %s: not found", _name)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tool Registry — web_search + web_scrape for Node A (Researcher)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
|
||||
def _exec_web_search(inputs: dict) -> dict:
|
||||
api_key = CREDENTIALS.get("brave_search")
|
||||
if not api_key:
|
||||
return {"error": "brave_search credential not configured"}
|
||||
query = inputs.get("query", "")
|
||||
num_results = min(inputs.get("num_results", 10), 20)
|
||||
resp = httpx.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
params={"q": query, "count": num_results},
|
||||
headers={
|
||||
"X-Subscription-Token": api_key,
|
||||
"Accept": "application/json",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"Brave API HTTP {resp.status_code}"}
|
||||
data = resp.json()
|
||||
results = [
|
||||
{
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("description", ""),
|
||||
}
|
||||
for item in data.get("web", {}).get("results", [])[:num_results]
|
||||
]
|
||||
return {"query": query, "results": results, "total": len(results)}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_search",
|
||||
tool=Tool(
|
||||
name="web_search",
|
||||
description=(
|
||||
"Search the web for current information. "
|
||||
"Returns titles, URLs, and snippets from search results."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query (1-500 characters)",
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (1-20, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_search(inputs),
|
||||
)
|
||||
|
||||
_SCRAPE_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "text/html,application/xhtml+xml",
|
||||
}
|
||||
|
||||
|
||||
def _exec_web_scrape(inputs: dict) -> dict:
|
||||
url = inputs.get("url", "")
|
||||
max_length = max(1000, min(inputs.get("max_length", 50000), 500000))
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
try:
|
||||
resp = httpx.get(
|
||||
url,
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
headers=_SCRAPE_HEADERS,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HTTP {resp.status_code}"}
|
||||
soup = BeautifulSoup(resp.text, "html.parser")
|
||||
for tag in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
|
||||
tag.decompose()
|
||||
title = soup.title.get_text(strip=True) if soup.title else ""
|
||||
main = (
|
||||
soup.find("article")
|
||||
or soup.find("main")
|
||||
or soup.find(attrs={"role": "main"})
|
||||
or soup.find("body")
|
||||
)
|
||||
text = main.get_text(separator=" ", strip=True) if main else ""
|
||||
text = " ".join(text.split())
|
||||
if len(text) > max_length:
|
||||
text = text[:max_length] + "..."
|
||||
return {
|
||||
"url": url,
|
||||
"title": title,
|
||||
"content": text,
|
||||
"length": len(text),
|
||||
}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"Scrape failed: {e}"}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_scrape",
|
||||
tool=Tool(
|
||||
name="web_scrape",
|
||||
description=(
|
||||
"Scrape and extract text content from a webpage URL. "
|
||||
"Returns the page title and main text content."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the webpage to scrape",
|
||||
},
|
||||
"max_length": {
|
||||
"type": "integer",
|
||||
"description": "Maximum text length (default 50000)",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_scrape(inputs),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ToolRegistry loaded: %s",
|
||||
", ".join(TOOL_REGISTRY.get_registered_names()),
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Node Specs
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
RESEARCHER_SPEC = NodeSpec(
|
||||
id="researcher",
|
||||
name="Researcher",
|
||||
description="Researches a topic using web search and scraping tools",
|
||||
node_type="event_loop",
|
||||
input_keys=["topic"],
|
||||
output_keys=["research_summary"],
|
||||
system_prompt=(
|
||||
"You are a thorough research assistant. Your job is to research "
|
||||
"the given topic using the web_search and web_scrape tools.\n\n"
|
||||
"1. Search for relevant information on the topic\n"
|
||||
"2. Scrape 1-2 of the most promising URLs for details\n"
|
||||
"3. Synthesize your findings into a comprehensive summary\n"
|
||||
"4. Use set_output with key='research_summary' to save your "
|
||||
"findings\n\n"
|
||||
"Be thorough but efficient. Aim for 2-4 search/scrape calls, "
|
||||
"then summarize and set_output."
|
||||
),
|
||||
)
|
||||
|
||||
ANALYST_SPEC = NodeSpec(
|
||||
id="analyst",
|
||||
name="Analyst",
|
||||
description="Analyzes research findings and provides insights",
|
||||
node_type="event_loop",
|
||||
input_keys=["context"],
|
||||
output_keys=["analysis"],
|
||||
system_prompt=(
|
||||
"You are a strategic analyst. You receive research findings from "
|
||||
"a previous researcher and must:\n\n"
|
||||
"1. Identify key themes and patterns\n"
|
||||
"2. Assess the reliability and significance of the findings\n"
|
||||
"3. Provide actionable insights and recommendations\n"
|
||||
"4. Use set_output with key='analysis' to save your analysis\n\n"
|
||||
"Be concise but insightful. Focus on what matters most."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTML page
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
HTML_PAGE = ( # noqa: E501
|
||||
"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>ContextHandoff Demo</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
body {
|
||||
font-family: 'SF Mono', 'Fira Code', monospace;
|
||||
background: #0d1117;
|
||||
color: #c9d1d9;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
header {
|
||||
background: #161b22;
|
||||
padding: 12px 20px;
|
||||
border-bottom: 1px solid #30363d;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 16px;
|
||||
}
|
||||
header h1 {
|
||||
font-size: 16px;
|
||||
color: #58a6ff;
|
||||
font-weight: 600;
|
||||
}
|
||||
.badge {
|
||||
font-size: 12px;
|
||||
padding: 3px 10px;
|
||||
border-radius: 12px;
|
||||
background: #21262d;
|
||||
color: #8b949e;
|
||||
}
|
||||
.badge.researcher {
|
||||
background: #1a3a5c;
|
||||
color: #58a6ff;
|
||||
}
|
||||
.badge.analyst {
|
||||
background: #1a4b2e;
|
||||
color: #3fb950;
|
||||
}
|
||||
.badge.handoff {
|
||||
background: #3d1f00;
|
||||
color: #d29922;
|
||||
}
|
||||
.badge.done {
|
||||
background: #21262d;
|
||||
color: #8b949e;
|
||||
}
|
||||
.badge.error {
|
||||
background: #4b1a1a;
|
||||
color: #f85149;
|
||||
}
|
||||
.chat {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 16px;
|
||||
}
|
||||
.msg {
|
||||
margin: 8px 0;
|
||||
padding: 10px 14px;
|
||||
border-radius: 8px;
|
||||
line-height: 1.6;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
.msg.user {
|
||||
background: #1a3a5c;
|
||||
color: #58a6ff;
|
||||
}
|
||||
.msg.assistant {
|
||||
background: #161b22;
|
||||
color: #c9d1d9;
|
||||
}
|
||||
.msg.assistant.analyst-msg {
|
||||
border-left: 3px solid #3fb950;
|
||||
}
|
||||
.msg.event {
|
||||
background: transparent;
|
||||
color: #8b949e;
|
||||
font-size: 11px;
|
||||
padding: 4px 14px;
|
||||
border-left: 3px solid #30363d;
|
||||
}
|
||||
.msg.event.loop {
|
||||
border-left-color: #58a6ff;
|
||||
}
|
||||
.msg.event.tool {
|
||||
border-left-color: #d29922;
|
||||
}
|
||||
.msg.event.stall {
|
||||
border-left-color: #f85149;
|
||||
}
|
||||
.handoff-banner {
|
||||
margin: 16px 0;
|
||||
padding: 16px;
|
||||
background: #1c1200;
|
||||
border: 1px solid #d29922;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.handoff-banner h3 {
|
||||
color: #d29922;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.handoff-banner p, .result-banner p {
|
||||
color: #8b949e;
|
||||
font-size: 12px;
|
||||
line-height: 1.5;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap;
|
||||
text-align: left;
|
||||
}
|
||||
.result-banner {
|
||||
margin: 16px 0;
|
||||
padding: 16px;
|
||||
background: #0a2614;
|
||||
border: 1px solid #3fb950;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.result-banner h3 {
|
||||
color: #3fb950;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.result-banner .label {
|
||||
color: #58a6ff;
|
||||
font-size: 11px;
|
||||
font-weight: 600;
|
||||
margin-top: 10px;
|
||||
margin-bottom: 2px;
|
||||
}
|
||||
.result-banner .tokens {
|
||||
color: #484f58;
|
||||
font-size: 11px;
|
||||
text-align: center;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.input-bar {
|
||||
padding: 12px 16px;
|
||||
background: #161b22;
|
||||
border-top: 1px solid #30363d;
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
}
|
||||
.input-bar input {
|
||||
flex: 1;
|
||||
background: #0d1117;
|
||||
border: 1px solid #30363d;
|
||||
color: #c9d1d9;
|
||||
padding: 8px 12px;
|
||||
border-radius: 6px;
|
||||
font-family: inherit;
|
||||
font-size: 14px;
|
||||
outline: none;
|
||||
}
|
||||
.input-bar input:focus {
|
||||
border-color: #58a6ff;
|
||||
}
|
||||
.input-bar button {
|
||||
background: #238636;
|
||||
color: #fff;
|
||||
border: none;
|
||||
padding: 8px 20px;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
font-weight: 600;
|
||||
}
|
||||
.input-bar button:hover {
|
||||
background: #2ea043;
|
||||
}
|
||||
.input-bar button:disabled {
|
||||
background: #21262d;
|
||||
color: #484f58;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>ContextHandoff Demo</h1>
|
||||
<span id="phase" class="badge">Idle</span>
|
||||
<span id="iter" class="badge" style="display:none">Step 0</span>
|
||||
</header>
|
||||
<div id="chat" class="chat"></div>
|
||||
<div class="input-bar">
|
||||
<input id="input" type="text"
|
||||
placeholder="Enter a research topic..." autofocus />
|
||||
<button id="go" onclick="run()">Research</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let currentAssistantEl = null;
|
||||
let iterCount = 0;
|
||||
let currentPhase = 'idle';
|
||||
const chat = document.getElementById('chat');
|
||||
const phase = document.getElementById('phase');
|
||||
const iterEl = document.getElementById('iter');
|
||||
const goBtn = document.getElementById('go');
|
||||
const inputEl = document.getElementById('input');
|
||||
|
||||
inputEl.addEventListener('keydown', e => {
|
||||
if (e.key === 'Enter') run();
|
||||
});
|
||||
|
||||
function setPhase(text, cls) {
|
||||
phase.textContent = text;
|
||||
phase.className = 'badge ' + cls;
|
||||
currentPhase = cls;
|
||||
}
|
||||
|
||||
function addMsg(text, cls) {
|
||||
const el = document.createElement('div');
|
||||
el.className = 'msg ' + cls;
|
||||
el.textContent = text;
|
||||
chat.appendChild(el);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
return el;
|
||||
}
|
||||
|
||||
function addHandoffBanner(summary) {
|
||||
const banner = document.createElement('div');
|
||||
banner.className = 'handoff-banner';
|
||||
const h3 = document.createElement('h3');
|
||||
h3.textContent = 'Context Handoff: Researcher -> Analyst';
|
||||
const p = document.createElement('p');
|
||||
p.textContent = summary || 'Passing research context...';
|
||||
banner.appendChild(h3);
|
||||
banner.appendChild(p);
|
||||
chat.appendChild(banner);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
|
||||
function addResultBanner(researcher, analyst, tokens) {
|
||||
const banner = document.createElement('div');
|
||||
banner.className = 'result-banner';
|
||||
const h3 = document.createElement('h3');
|
||||
h3.textContent = 'Pipeline Complete';
|
||||
banner.appendChild(h3);
|
||||
|
||||
if (researcher && researcher.research_summary) {
|
||||
const lbl = document.createElement('div');
|
||||
lbl.className = 'label';
|
||||
lbl.textContent = 'RESEARCH SUMMARY';
|
||||
banner.appendChild(lbl);
|
||||
const p = document.createElement('p');
|
||||
p.textContent = researcher.research_summary;
|
||||
banner.appendChild(p);
|
||||
}
|
||||
|
||||
if (analyst && analyst.analysis) {
|
||||
const lbl = document.createElement('div');
|
||||
lbl.className = 'label';
|
||||
lbl.textContent = 'ANALYSIS';
|
||||
lbl.style.color = '#3fb950';
|
||||
banner.appendChild(lbl);
|
||||
const p = document.createElement('p');
|
||||
p.textContent = analyst.analysis;
|
||||
banner.appendChild(p);
|
||||
}
|
||||
|
||||
if (tokens) {
|
||||
const t = document.createElement('div');
|
||||
t.className = 'tokens';
|
||||
t.textContent = 'Total tokens: ' + tokens.toLocaleString();
|
||||
banner.appendChild(t);
|
||||
}
|
||||
|
||||
chat.appendChild(banner);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
|
||||
function connect() {
|
||||
ws = new WebSocket('ws://' + location.host + '/ws');
|
||||
ws.onopen = () => {
|
||||
setPhase('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
};
|
||||
ws.onmessage = handleEvent;
|
||||
ws.onerror = () => { setPhase('Error', 'error'); };
|
||||
ws.onclose = () => {
|
||||
setPhase('Reconnecting...', '');
|
||||
goBtn.disabled = true;
|
||||
setTimeout(connect, 2000);
|
||||
};
|
||||
}
|
||||
|
||||
function handleEvent(msg) {
|
||||
const evt = JSON.parse(msg.data);
|
||||
|
||||
if (evt.type === 'phase') {
|
||||
if (evt.phase === 'researcher') {
|
||||
setPhase('Researcher', 'researcher');
|
||||
} else if (evt.phase === 'handoff') {
|
||||
setPhase('Handoff', 'handoff');
|
||||
} else if (evt.phase === 'analyst') {
|
||||
setPhase('Analyst', 'analyst');
|
||||
}
|
||||
iterCount = 0;
|
||||
iterEl.style.display = 'none';
|
||||
}
|
||||
else if (evt.type === 'llm_text_delta') {
|
||||
if (currentAssistantEl) {
|
||||
currentAssistantEl.textContent += evt.content;
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'node_loop_iteration') {
|
||||
iterCount = evt.iteration || (iterCount + 1);
|
||||
iterEl.textContent = 'Step ' + iterCount;
|
||||
iterEl.style.display = '';
|
||||
}
|
||||
else if (evt.type === 'tool_call_started') {
|
||||
var info = evt.tool_name + '('
|
||||
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
|
||||
addMsg('TOOL ' + info, 'event tool');
|
||||
}
|
||||
else if (evt.type === 'tool_call_completed') {
|
||||
var preview = (evt.result || '').slice(0, 200);
|
||||
var cls = evt.is_error ? 'stall' : 'tool';
|
||||
addMsg(
|
||||
'RESULT ' + evt.tool_name + ': ' + preview,
|
||||
'event ' + cls
|
||||
);
|
||||
var assistCls = currentPhase === 'analyst'
|
||||
? 'assistant analyst-msg' : 'assistant';
|
||||
currentAssistantEl = addMsg('', assistCls);
|
||||
}
|
||||
else if (evt.type === 'handoff_context') {
|
||||
addHandoffBanner(evt.summary);
|
||||
var assistCls = 'assistant analyst-msg';
|
||||
currentAssistantEl = addMsg('', assistCls);
|
||||
}
|
||||
else if (evt.type === 'node_result') {
|
||||
if (evt.node_id === 'researcher') {
|
||||
if (currentAssistantEl
|
||||
&& !currentAssistantEl.textContent) {
|
||||
currentAssistantEl.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'done') {
|
||||
setPhase('Done', 'done');
|
||||
iterEl.style.display = 'none';
|
||||
if (currentAssistantEl
|
||||
&& !currentAssistantEl.textContent) {
|
||||
currentAssistantEl.remove();
|
||||
}
|
||||
currentAssistantEl = null;
|
||||
addResultBanner(
|
||||
evt.researcher, evt.analyst, evt.total_tokens
|
||||
);
|
||||
goBtn.disabled = false;
|
||||
inputEl.placeholder = 'Enter another topic...';
|
||||
}
|
||||
else if (evt.type === 'error') {
|
||||
setPhase('Error', 'error');
|
||||
addMsg('ERROR ' + evt.message, 'event stall');
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_stalled') {
|
||||
addMsg('STALLED ' + evt.reason, 'event stall');
|
||||
}
|
||||
}
|
||||
|
||||
function run() {
|
||||
const text = inputEl.value.trim();
|
||||
if (!text || !ws || ws.readyState !== 1) return;
|
||||
chat.innerHTML = '';
|
||||
addMsg(text, 'user');
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
inputEl.value = '';
|
||||
goBtn.disabled = true;
|
||||
ws.send(JSON.stringify({ topic: text }));
|
||||
}
|
||||
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# WebSocket handler — sequential Node A → Handoff → Node B
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_ws(websocket):
|
||||
"""Run the two-node handoff pipeline per user message."""
|
||||
try:
|
||||
async for raw in websocket:
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
topic = msg.get("topic", "")
|
||||
if not topic:
|
||||
continue
|
||||
|
||||
logger.info(f"Starting handoff pipeline for: {topic}")
|
||||
|
||||
try:
|
||||
await _run_pipeline(websocket, topic)
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("WebSocket closed during pipeline")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Pipeline error")
|
||||
try:
|
||||
await websocket.send(json.dumps({"type": "error", "message": str(e)}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
|
||||
|
||||
async def _run_pipeline(websocket, topic: str):
|
||||
"""Execute: Node A (research) → ContextHandoff → Node B (analysis)."""
|
||||
import shutil
|
||||
|
||||
# Fresh stores for each run
|
||||
run_dir = Path(tempfile.mkdtemp(prefix="hive_run_", dir=STORE_DIR))
|
||||
store_a = FileConversationStore(run_dir / "node_a")
|
||||
store_b = FileConversationStore(run_dir / "node_b")
|
||||
|
||||
# Shared event bus
|
||||
bus = EventBus()
|
||||
|
||||
async def forward_event(event):
|
||||
try:
|
||||
payload = {"type": event.type.value, **event.data}
|
||||
if event.node_id:
|
||||
payload["node_id"] = event.node_id
|
||||
await websocket.send(json.dumps(payload))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[
|
||||
EventType.NODE_LOOP_STARTED,
|
||||
EventType.NODE_LOOP_ITERATION,
|
||||
EventType.NODE_LOOP_COMPLETED,
|
||||
EventType.LLM_TEXT_DELTA,
|
||||
EventType.TOOL_CALL_STARTED,
|
||||
EventType.TOOL_CALL_COMPLETED,
|
||||
EventType.NODE_STALLED,
|
||||
],
|
||||
handler=forward_event,
|
||||
)
|
||||
|
||||
tools = list(TOOL_REGISTRY.get_tools().values())
|
||||
tool_executor = TOOL_REGISTRY.get_executor()
|
||||
|
||||
# ---- Phase 1: Researcher ------------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "researcher"}))
|
||||
|
||||
node_a = EventLoopNode(
|
||||
event_bus=bus,
|
||||
judge=None, # implicit judge: accept when output_keys filled
|
||||
config=LoopConfig(
|
||||
max_iterations=20,
|
||||
max_tool_calls_per_turn=10,
|
||||
max_history_tokens=32_000,
|
||||
),
|
||||
conversation_store=store_a,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
|
||||
ctx_a = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="researcher",
|
||||
node_spec=RESEARCHER_SPEC,
|
||||
memory=SharedMemory(),
|
||||
input_data={"topic": topic},
|
||||
llm=LLM,
|
||||
available_tools=tools,
|
||||
)
|
||||
|
||||
result_a = await node_a.execute(ctx_a)
|
||||
logger.info(
|
||||
"Researcher done: success=%s, tokens=%s",
|
||||
result_a.success,
|
||||
result_a.tokens_used,
|
||||
)
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "node_result",
|
||||
"node_id": "researcher",
|
||||
"success": result_a.success,
|
||||
"output": result_a.output,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if not result_a.success:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": f"Researcher failed: {result_a.error}",
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# ---- Phase 2: Context Handoff -------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "handoff"}))
|
||||
|
||||
# Restore the researcher's conversation from store
|
||||
conversation_a = await NodeConversation.restore(store_a)
|
||||
if conversation_a is None:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Failed to restore researcher conversation",
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
handoff_engine = ContextHandoff(llm=LLM)
|
||||
handoff_context = handoff_engine.summarize_conversation(
|
||||
conversation=conversation_a,
|
||||
node_id="researcher",
|
||||
output_keys=["research_summary"],
|
||||
)
|
||||
|
||||
formatted_handoff = ContextHandoff.format_as_input(handoff_context)
|
||||
logger.info(
|
||||
"Handoff: %d turns, ~%d tokens, keys=%s",
|
||||
handoff_context.turn_count,
|
||||
handoff_context.total_tokens_used,
|
||||
list(handoff_context.key_outputs.keys()),
|
||||
)
|
||||
|
||||
# Send handoff context to browser
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "handoff_context",
|
||||
"summary": handoff_context.summary[:500],
|
||||
"turn_count": handoff_context.turn_count,
|
||||
"tokens": handoff_context.total_tokens_used,
|
||||
"key_outputs": handoff_context.key_outputs,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# ---- Phase 3: Analyst ---------------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "analyst"}))
|
||||
|
||||
node_b = EventLoopNode(
|
||||
event_bus=bus,
|
||||
judge=None, # implicit judge
|
||||
config=LoopConfig(
|
||||
max_iterations=10,
|
||||
max_tool_calls_per_turn=5,
|
||||
max_history_tokens=32_000,
|
||||
),
|
||||
conversation_store=store_b,
|
||||
)
|
||||
|
||||
ctx_b = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="analyst",
|
||||
node_spec=ANALYST_SPEC,
|
||||
memory=SharedMemory(),
|
||||
input_data={"context": formatted_handoff},
|
||||
llm=LLM,
|
||||
available_tools=[],
|
||||
)
|
||||
|
||||
result_b = await node_b.execute(ctx_b)
|
||||
logger.info(
|
||||
"Analyst done: success=%s, tokens=%s",
|
||||
result_b.success,
|
||||
result_b.tokens_used,
|
||||
)
|
||||
|
||||
# ---- Done ---------------------------------------------------------------
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "done",
|
||||
"researcher": result_a.output,
|
||||
"analyst": result_b.output,
|
||||
"total_tokens": ((result_a.tokens_used or 0) + (result_b.tokens_used or 0)),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Clean up temp stores
|
||||
try:
|
||||
shutil.rmtree(run_dir)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP handler
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_request(connection, request: Request):
|
||||
"""Serve HTML on GET /, upgrade to WebSocket on /ws."""
|
||||
if request.path == "/ws":
|
||||
return None
|
||||
return Response(
|
||||
HTTPStatus.OK,
|
||||
"OK",
|
||||
websockets.Headers({"Content-Type": "text/html; charset=utf-8"}),
|
||||
HTML_PAGE.encode(),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Main
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
|
||||
port = 8766
|
||||
async with websockets.serve(
|
||||
handle_ws,
|
||||
"0.0.0.0",
|
||||
port,
|
||||
process_request=process_request,
|
||||
):
|
||||
logger.info(f"Handoff demo at http://localhost:{port}")
|
||||
logger.info("Enter a research topic to start the pipeline.")
|
||||
await asyncio.Future()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,8 +1,22 @@
|
||||
"""Graph structures: Goals, Nodes, Edges, and Flexible Execution."""
|
||||
|
||||
from framework.graph.client_io import (
|
||||
ActiveNodeClientIO,
|
||||
ClientIOGateway,
|
||||
InertNodeClientIO,
|
||||
NodeClientIO,
|
||||
)
|
||||
from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec
|
||||
from framework.graph.context_handoff import ContextHandoff, HandoffContext
|
||||
from framework.graph.conversation import ConversationStore, Message, NodeConversation
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.event_loop_node import (
|
||||
EventLoopNode,
|
||||
JudgeProtocol,
|
||||
JudgeVerdict,
|
||||
LoopConfig,
|
||||
OutputAccumulator,
|
||||
)
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.flexible_executor import ExecutorConfig, FlexibleGraphExecutor
|
||||
from framework.graph.goal import Constraint, Goal, GoalStatus, SuccessCriterion
|
||||
@@ -77,4 +91,18 @@ __all__ = [
|
||||
"NodeConversation",
|
||||
"ConversationStore",
|
||||
"Message",
|
||||
# Event Loop
|
||||
"EventLoopNode",
|
||||
"LoopConfig",
|
||||
"OutputAccumulator",
|
||||
"JudgeProtocol",
|
||||
"JudgeVerdict",
|
||||
# Context Handoff
|
||||
"ContextHandoff",
|
||||
"HandoffContext",
|
||||
# Client I/O
|
||||
"NodeClientIO",
|
||||
"ActiveNodeClientIO",
|
||||
"InertNodeClientIO",
|
||||
"ClientIOGateway",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Client I/O gateway for graph nodes.
|
||||
|
||||
Provides the bridge between node code and external clients:
|
||||
- ActiveNodeClientIO: for client_facing=True nodes (streams output, accepts input)
|
||||
- InertNodeClientIO: for client_facing=False nodes (logs internally, redirects input)
|
||||
- ClientIOGateway: factory that creates the right variant per node
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeClientIO(ABC):
|
||||
"""Abstract base for node client I/O."""
|
||||
|
||||
@abstractmethod
|
||||
async def emit_output(self, content: str, is_final: bool = False) -> None:
|
||||
"""Emit output content. If is_final=True, signal end of stream."""
|
||||
|
||||
@abstractmethod
|
||||
async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
|
||||
"""Request input. Behavior depends on whether the node is client-facing."""
|
||||
|
||||
|
||||
class ActiveNodeClientIO(NodeClientIO):
|
||||
"""
|
||||
Client I/O for client_facing=True nodes.
|
||||
|
||||
- emit_output() queues content and publishes CLIENT_OUTPUT_DELTA.
|
||||
- request_input() publishes CLIENT_INPUT_REQUESTED, then awaits provide_input().
|
||||
- output_stream() yields queued content until the final sentinel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
event_bus: EventBus | None = None,
|
||||
) -> None:
|
||||
self.node_id = node_id
|
||||
self._event_bus = event_bus
|
||||
|
||||
self._output_queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
self._output_snapshot = ""
|
||||
|
||||
self._input_event: asyncio.Event | None = None
|
||||
self._input_result: str | None = None
|
||||
|
||||
async def emit_output(self, content: str, is_final: bool = False) -> None:
|
||||
self._output_snapshot += content
|
||||
await self._output_queue.put(content)
|
||||
|
||||
if self._event_bus is not None:
|
||||
await self._event_bus.emit_client_output_delta(
|
||||
stream_id=self.node_id,
|
||||
node_id=self.node_id,
|
||||
content=content,
|
||||
snapshot=self._output_snapshot,
|
||||
)
|
||||
|
||||
if is_final:
|
||||
await self._output_queue.put(None)
|
||||
|
||||
async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
|
||||
if self._input_event is not None:
|
||||
raise RuntimeError("request_input already pending for this node")
|
||||
|
||||
self._input_event = asyncio.Event()
|
||||
self._input_result = None
|
||||
|
||||
if self._event_bus is not None:
|
||||
await self._event_bus.emit_client_input_requested(
|
||||
stream_id=self.node_id,
|
||||
node_id=self.node_id,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
try:
|
||||
if timeout is not None:
|
||||
await asyncio.wait_for(self._input_event.wait(), timeout=timeout)
|
||||
else:
|
||||
await self._input_event.wait()
|
||||
finally:
|
||||
self._input_event = None
|
||||
|
||||
if self._input_result is None:
|
||||
raise RuntimeError("input event was set but no input was provided")
|
||||
result = self._input_result
|
||||
self._input_result = None
|
||||
return result
|
||||
|
||||
async def provide_input(self, content: str) -> None:
|
||||
"""Called externally to fulfill a pending request_input()."""
|
||||
if self._input_event is None:
|
||||
raise RuntimeError("no pending request_input to fulfill")
|
||||
self._input_result = content
|
||||
self._input_event.set()
|
||||
|
||||
async def output_stream(self) -> AsyncIterator[str]:
|
||||
"""Async iterator that yields output chunks until the final sentinel."""
|
||||
while True:
|
||||
chunk = await self._output_queue.get()
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
|
||||
class InertNodeClientIO(NodeClientIO):
|
||||
"""
|
||||
Client I/O for client_facing=False nodes.
|
||||
|
||||
- emit_output() publishes NODE_INTERNAL_OUTPUT (content is not discarded).
|
||||
- request_input() publishes NODE_INPUT_BLOCKED and returns a redirect string.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
event_bus: EventBus | None = None,
|
||||
) -> None:
|
||||
self.node_id = node_id
|
||||
self._event_bus = event_bus
|
||||
|
||||
async def emit_output(self, content: str, is_final: bool = False) -> None:
|
||||
if self._event_bus is not None:
|
||||
await self._event_bus.emit_node_internal_output(
|
||||
stream_id=self.node_id,
|
||||
node_id=self.node_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
|
||||
if self._event_bus is not None:
|
||||
await self._event_bus.emit_node_input_blocked(
|
||||
stream_id=self.node_id,
|
||||
node_id=self.node_id,
|
||||
prompt=prompt,
|
||||
)
|
||||
return (
|
||||
"You are an internal processing node. There is no user to interact with."
|
||||
" Work with the data provided in your inputs to complete your task."
|
||||
)
|
||||
|
||||
|
||||
class ClientIOGateway:
|
||||
"""Factory that creates the appropriate NodeClientIO for a node."""
|
||||
|
||||
def __init__(self, event_bus: EventBus | None = None) -> None:
|
||||
self._event_bus = event_bus
|
||||
|
||||
def create_io(self, node_id: str, client_facing: bool) -> NodeClientIO:
|
||||
if client_facing:
|
||||
return ActiveNodeClientIO(
|
||||
node_id=node_id,
|
||||
event_bus=self._event_bus,
|
||||
)
|
||||
return InertNodeClientIO(
|
||||
node_id=node_id,
|
||||
event_bus=self._event_bus,
|
||||
)
|
||||
@@ -0,0 +1,191 @@
|
||||
"""Context handoff: summarize a completed NodeConversation for the next graph node."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.graph.conversation import _try_extract_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.llm.provider import LLMProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TRUNCATE_CHARS = 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class HandoffContext:
|
||||
"""Structured summary of a completed node conversation."""
|
||||
|
||||
source_node_id: str
|
||||
summary: str
|
||||
key_outputs: dict[str, Any]
|
||||
turn_count: int
|
||||
total_tokens_used: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ContextHandoff
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ContextHandoff:
|
||||
"""Summarize a completed NodeConversation into a HandoffContext.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
llm : LLMProvider | None
|
||||
Optional LLM provider for abstractive summarization.
|
||||
When *None*, all summarization uses the extractive fallback.
|
||||
"""
|
||||
|
||||
def __init__(self, llm: LLMProvider | None = None) -> None:
|
||||
self.llm = llm
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def summarize_conversation(
|
||||
self,
|
||||
conversation: NodeConversation,
|
||||
node_id: str,
|
||||
output_keys: list[str] | None = None,
|
||||
) -> HandoffContext:
|
||||
"""Produce a HandoffContext from *conversation*.
|
||||
|
||||
1. Extracts turn_count & total_tokens_used (sync properties).
|
||||
2. Extracts key_outputs by scanning assistant messages most-recent-first.
|
||||
3. Builds a summary via the LLM (if available) or extractive fallback.
|
||||
"""
|
||||
turn_count = conversation.turn_count
|
||||
total_tokens_used = conversation.estimate_tokens()
|
||||
messages = conversation.messages # defensive copy
|
||||
|
||||
# --- key outputs ---------------------------------------------------
|
||||
key_outputs: dict[str, Any] = {}
|
||||
if output_keys:
|
||||
remaining = set(output_keys)
|
||||
for msg in reversed(messages):
|
||||
if msg.role != "assistant" or not remaining:
|
||||
continue
|
||||
for key in list(remaining):
|
||||
value = _try_extract_key(msg.content, key)
|
||||
if value is not None:
|
||||
key_outputs[key] = value
|
||||
remaining.discard(key)
|
||||
|
||||
# --- summary -------------------------------------------------------
|
||||
if self.llm is not None:
|
||||
try:
|
||||
summary = self._llm_summary(messages, output_keys or [])
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"LLM summarization failed; falling back to extractive.",
|
||||
exc_info=True,
|
||||
)
|
||||
summary = self._extractive_summary(messages)
|
||||
else:
|
||||
summary = self._extractive_summary(messages)
|
||||
|
||||
return HandoffContext(
|
||||
source_node_id=node_id,
|
||||
summary=summary,
|
||||
key_outputs=key_outputs,
|
||||
turn_count=turn_count,
|
||||
total_tokens_used=total_tokens_used,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_as_input(handoff: HandoffContext) -> str:
|
||||
"""Render *handoff* as structured plain text for the next node's input."""
|
||||
header = (
|
||||
f"--- CONTEXT FROM: {handoff.source_node_id} "
|
||||
f"({handoff.turn_count} turns, ~{handoff.total_tokens_used} tokens) ---"
|
||||
)
|
||||
|
||||
sections: list[str] = [header, ""]
|
||||
|
||||
if handoff.key_outputs:
|
||||
sections.append("KEY OUTPUTS:")
|
||||
for k, v in handoff.key_outputs.items():
|
||||
sections.append(f"- {k}: {v}")
|
||||
sections.append("")
|
||||
|
||||
summary_text = handoff.summary or "No summary available."
|
||||
sections.append("SUMMARY:")
|
||||
sections.append(summary_text)
|
||||
sections.append("")
|
||||
sections.append("--- END CONTEXT ---")
|
||||
|
||||
return "\n".join(sections)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _extractive_summary(messages: list) -> str:
|
||||
"""Build a summary from key assistant messages without an LLM.
|
||||
|
||||
Strategy:
|
||||
- Include the first assistant message (initial assessment).
|
||||
- Include the last assistant message (final conclusion).
|
||||
- Truncate each to ~500 chars.
|
||||
"""
|
||||
if not messages:
|
||||
return "Empty conversation."
|
||||
|
||||
assistant_msgs = [m for m in messages if m.role == "assistant"]
|
||||
if not assistant_msgs:
|
||||
return "No assistant responses."
|
||||
|
||||
parts: list[str] = []
|
||||
|
||||
first = assistant_msgs[0].content
|
||||
parts.append(first[:_TRUNCATE_CHARS])
|
||||
|
||||
if len(assistant_msgs) > 1:
|
||||
last = assistant_msgs[-1].content
|
||||
parts.append(last[:_TRUNCATE_CHARS])
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _llm_summary(self, messages: list, output_keys: list[str]) -> str:
|
||||
"""Produce a summary by calling the LLM provider."""
|
||||
if self.llm is None:
|
||||
raise ValueError("_llm_summary called without an LLM provider")
|
||||
|
||||
conversation_text = "\n".join(f"[{m.role}]: {m.content}" for m in messages)
|
||||
|
||||
key_hint = ""
|
||||
if output_keys:
|
||||
key_hint = (
|
||||
"\nThe following output keys are especially important: "
|
||||
+ ", ".join(output_keys)
|
||||
+ ".\n"
|
||||
)
|
||||
|
||||
system_prompt = (
|
||||
"You are a concise summarizer. Given the conversation below, "
|
||||
"produce a brief summary (at most ~500 tokens) that captures the "
|
||||
"key decisions, findings, and outcomes. Focus on what was concluded "
|
||||
"rather than the back-and-forth process." + key_hint
|
||||
)
|
||||
|
||||
response = self.llm.complete(
|
||||
messages=[{"role": "user", "content": conversation_text}],
|
||||
system=system_prompt,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
return response.content.strip()
|
||||
@@ -108,6 +108,50 @@ class ConversationStore(Protocol):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _try_extract_key(content: str, key: str) -> str | None:
|
||||
"""Try 4 strategies to extract a *key*'s value from message content.
|
||||
|
||||
Strategies (in order):
|
||||
1. Whole message is JSON — ``json.loads``, check for key.
|
||||
2. Embedded JSON via ``find_json_object`` helper.
|
||||
3. Colon format: ``key: value``.
|
||||
4. Equals format: ``key = value``.
|
||||
"""
|
||||
from framework.graph.node import find_json_object
|
||||
|
||||
# 1. Whole message is JSON
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict) and key in parsed:
|
||||
val = parsed[key]
|
||||
return json.dumps(val) if not isinstance(val, str) else val
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 2. Embedded JSON via find_json_object
|
||||
json_str = find_json_object(content)
|
||||
if json_str:
|
||||
try:
|
||||
parsed = json.loads(json_str)
|
||||
if isinstance(parsed, dict) and key in parsed:
|
||||
val = parsed[key]
|
||||
return json.dumps(val) if not isinstance(val, str) else val
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. Colon format: key: value
|
||||
match = re.search(rf"\b{re.escape(key)}\s*:\s*(.+)", content)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
# 4. Equals format: key = value
|
||||
match = re.search(rf"\b{re.escape(key)}\s*=\s*(.+)", content)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class NodeConversation:
|
||||
"""Message history for a graph node with optional write-through persistence.
|
||||
|
||||
@@ -205,8 +249,47 @@ class NodeConversation:
|
||||
# --- Query -------------------------------------------------------------
|
||||
|
||||
def to_llm_messages(self) -> list[dict[str, Any]]:
|
||||
"""Return messages as OpenAI-format dicts (system prompt excluded)."""
|
||||
return [m.to_llm_dict() for m in self._messages]
|
||||
"""Return messages as OpenAI-format dicts (system prompt excluded).
|
||||
|
||||
Automatically repairs orphaned tool_use blocks (assistant messages
|
||||
with tool_calls that lack corresponding tool-result messages). This
|
||||
can happen when a loop is cancelled mid-tool-execution.
|
||||
"""
|
||||
msgs = [m.to_llm_dict() for m in self._messages]
|
||||
return self._repair_orphaned_tool_calls(msgs)
|
||||
|
||||
@staticmethod
|
||||
def _repair_orphaned_tool_calls(
|
||||
msgs: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Ensure every tool_call has a matching tool-result message."""
|
||||
repaired: list[dict[str, Any]] = []
|
||||
for i, m in enumerate(msgs):
|
||||
repaired.append(m)
|
||||
tool_calls = m.get("tool_calls")
|
||||
if m.get("role") != "assistant" or not tool_calls:
|
||||
continue
|
||||
# Collect IDs of tool results that follow this assistant message
|
||||
answered: set[str] = set()
|
||||
for j in range(i + 1, len(msgs)):
|
||||
if msgs[j].get("role") == "tool":
|
||||
tid = msgs[j].get("tool_call_id")
|
||||
if tid:
|
||||
answered.add(tid)
|
||||
else:
|
||||
break # stop at first non-tool message
|
||||
# Patch any missing results
|
||||
for tc in tool_calls:
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in answered:
|
||||
repaired.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc_id,
|
||||
"content": "ERROR: Tool execution was interrupted.",
|
||||
}
|
||||
)
|
||||
return repaired
|
||||
|
||||
def estimate_tokens(self) -> int:
|
||||
"""Rough token estimate: total characters / 4."""
|
||||
@@ -244,39 +327,7 @@ class NodeConversation:
|
||||
|
||||
def _try_extract_key(self, content: str, key: str) -> str | None:
|
||||
"""Try 4 strategies to extract a key's value from message content."""
|
||||
from framework.graph.node import find_json_object
|
||||
|
||||
# 1. Whole message is JSON
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict) and key in parsed:
|
||||
val = parsed[key]
|
||||
return json.dumps(val) if not isinstance(val, str) else val
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 2. Embedded JSON via find_json_object
|
||||
json_str = find_json_object(content)
|
||||
if json_str:
|
||||
try:
|
||||
parsed = json.loads(json_str)
|
||||
if isinstance(parsed, dict) and key in parsed:
|
||||
val = parsed[key]
|
||||
return json.dumps(val) if not isinstance(val, str) else val
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. Colon format: key: value
|
||||
match = re.search(rf"\b{re.escape(key)}\s*:\s*(.+)", content)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
# 4. Equals format: key = value
|
||||
match = re.search(rf"\b{re.escape(key)}\s*=\s*(.+)", content)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
return None
|
||||
return _try_extract_key(content, key)
|
||||
|
||||
# --- Lifecycle ---------------------------------------------------------
|
||||
|
||||
|
||||
@@ -0,0 +1,832 @@
|
||||
"""EventLoopNode: Multi-turn LLM streaming loop with tool execution and judge evaluation.
|
||||
|
||||
Implements NodeProtocol and runs a streaming event loop:
|
||||
1. Calls LLMProvider.stream() to get streaming events
|
||||
2. Processes text deltas, tool calls, and finish events
|
||||
3. Executes tools and feeds results back to the conversation
|
||||
4. Uses judge evaluation (or implicit stop-reason) to decide loop termination
|
||||
5. Publishes lifecycle events to EventBus
|
||||
6. Persists conversation and outputs via write-through to ConversationStore
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from framework.graph.conversation import ConversationStore, NodeConversation
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult
|
||||
from framework.llm.provider import Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Judge protocol (simple 3-action interface for event loop evaluation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class JudgeVerdict:
|
||||
"""Result of judge evaluation for the event loop."""
|
||||
|
||||
action: Literal["ACCEPT", "RETRY", "ESCALATE"]
|
||||
feedback: str = ""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class JudgeProtocol(Protocol):
|
||||
"""Protocol for event-loop judges.
|
||||
|
||||
Implementations evaluate the current state of the event loop and
|
||||
decide whether to accept the output, retry with feedback, or escalate.
|
||||
"""
|
||||
|
||||
async def evaluate(self, context: dict[str, Any]) -> JudgeVerdict: ...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoopConfig:
|
||||
"""Configuration for the event loop."""
|
||||
|
||||
max_iterations: int = 50
|
||||
max_tool_calls_per_turn: int = 10
|
||||
judge_every_n_turns: int = 1
|
||||
stall_detection_threshold: int = 3
|
||||
max_history_tokens: int = 32_000
|
||||
store_prefix: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Output accumulator with write-through persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputAccumulator:
|
||||
"""Accumulates output key-value pairs with optional write-through persistence.
|
||||
|
||||
Values are stored in memory and optionally written through to a
|
||||
ConversationStore's cursor data for crash recovery.
|
||||
"""
|
||||
|
||||
values: dict[str, Any] = field(default_factory=dict)
|
||||
store: ConversationStore | None = None
|
||||
|
||||
async def set(self, key: str, value: Any) -> None:
|
||||
"""Set a key-value pair, persisting immediately if store is available."""
|
||||
self.values[key] = value
|
||||
if self.store:
|
||||
cursor = await self.store.read_cursor() or {}
|
||||
outputs = cursor.get("outputs", {})
|
||||
outputs[key] = value
|
||||
cursor["outputs"] = outputs
|
||||
await self.store.write_cursor(cursor)
|
||||
|
||||
def get(self, key: str) -> Any | None:
|
||||
"""Get a value by key, or None if not present."""
|
||||
return self.values.get(key)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Return a copy of all accumulated values."""
|
||||
return dict(self.values)
|
||||
|
||||
def has_all_keys(self, required: list[str]) -> bool:
|
||||
"""Check if all required keys have been set (non-None)."""
|
||||
return all(key in self.values and self.values[key] is not None for key in required)
|
||||
|
||||
@classmethod
|
||||
async def restore(cls, store: ConversationStore) -> OutputAccumulator:
|
||||
"""Restore an OutputAccumulator from a store's cursor data."""
|
||||
cursor = await store.read_cursor()
|
||||
values = {}
|
||||
if cursor and "outputs" in cursor:
|
||||
values = cursor["outputs"]
|
||||
return cls(values=values, store=store)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EventLoopNode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EventLoopNode(NodeProtocol):
|
||||
"""Multi-turn LLM streaming loop with tool execution and judge evaluation.
|
||||
|
||||
Lifecycle:
|
||||
1. Try to restore from durable state (crash recovery)
|
||||
2. If no prior state, init from NodeSpec.system_prompt + input_keys
|
||||
3. Loop: drain injection queue -> stream LLM -> execute tools -> judge
|
||||
(each add_* and set_output writes through to store immediately)
|
||||
4. Publish events to EventBus at each stage
|
||||
5. Write cursor after each iteration
|
||||
6. Terminate when judge returns ACCEPT (or max iterations)
|
||||
7. Build output dict from OutputAccumulator
|
||||
|
||||
Always returns NodeResult with retryable=False semantics. The executor
|
||||
must NOT retry event loop nodes -- retry is handled internally by the
|
||||
judge (RETRY action continues the loop). See WP-7 enforcement.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_bus: EventBus | None = None,
|
||||
judge: JudgeProtocol | None = None,
|
||||
config: LoopConfig | None = None,
|
||||
tool_executor: Callable[[ToolUse], ToolResult | Awaitable[ToolResult]] | None = None,
|
||||
conversation_store: ConversationStore | None = None,
|
||||
) -> None:
|
||||
self._event_bus = event_bus
|
||||
self._judge = judge
|
||||
self._config = config or LoopConfig()
|
||||
self._tool_executor = tool_executor
|
||||
self._conversation_store = conversation_store
|
||||
self._injection_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
|
||||
def validate_input(self, ctx: NodeContext) -> list[str]:
|
||||
"""Validate hard requirements only.
|
||||
|
||||
Event loop nodes are LLM-powered and can reason about flexible input,
|
||||
so input_keys are treated as hints — not strict requirements.
|
||||
Only the LLM provider is a hard dependency.
|
||||
"""
|
||||
errors = []
|
||||
if ctx.llm is None:
|
||||
errors.append("LLM provider is required for EventLoopNode")
|
||||
return errors
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Public API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
"""Run the event loop."""
|
||||
start_time = time.time()
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
stream_id = ctx.node_id
|
||||
node_id = ctx.node_id
|
||||
|
||||
# 1. Guard: LLM required
|
||||
if ctx.llm is None:
|
||||
return NodeResult(success=False, error="LLM provider not available")
|
||||
|
||||
# 2. Restore or create new conversation + accumulator
|
||||
conversation, accumulator, start_iteration = await self._restore(ctx)
|
||||
if conversation is None:
|
||||
conversation = NodeConversation(
|
||||
system_prompt=ctx.node_spec.system_prompt or "",
|
||||
max_history_tokens=self._config.max_history_tokens,
|
||||
output_keys=ctx.node_spec.output_keys or None,
|
||||
store=self._conversation_store,
|
||||
)
|
||||
accumulator = OutputAccumulator(store=self._conversation_store)
|
||||
start_iteration = 0
|
||||
|
||||
# Add initial user message from input data
|
||||
initial_message = self._build_initial_message(ctx)
|
||||
if initial_message:
|
||||
await conversation.add_user_message(initial_message)
|
||||
|
||||
# 3. Build tool list: node tools + synthetic set_output tool
|
||||
tools = list(ctx.available_tools)
|
||||
set_output_tool = self._build_set_output_tool(ctx.node_spec.output_keys)
|
||||
if set_output_tool:
|
||||
tools.append(set_output_tool)
|
||||
|
||||
# 4. Publish loop started
|
||||
await self._publish_loop_started(stream_id, node_id)
|
||||
|
||||
# 5. Stall detection state
|
||||
recent_responses: list[str] = []
|
||||
|
||||
# 6. Main loop
|
||||
for iteration in range(start_iteration, self._config.max_iterations):
|
||||
# 6a. Check pause
|
||||
if await self._check_pause(ctx, conversation, iteration):
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return NodeResult(
|
||||
success=True,
|
||||
output=accumulator.to_dict(),
|
||||
tokens_used=total_input_tokens + total_output_tokens,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
# 6b. Drain injection queue
|
||||
await self._drain_injection_queue(conversation)
|
||||
|
||||
# 6c. Publish iteration event
|
||||
await self._publish_iteration(stream_id, node_id, iteration)
|
||||
|
||||
# 6d. Compaction check
|
||||
if conversation.needs_compaction():
|
||||
summary = await self._generate_compaction_summary(ctx, conversation)
|
||||
await conversation.compact(summary, keep_recent=4)
|
||||
|
||||
# 6e. Run single LLM turn
|
||||
assistant_text, tool_results_list, turn_tokens = await self._run_single_turn(
|
||||
ctx, conversation, tools, iteration, accumulator
|
||||
)
|
||||
total_input_tokens += turn_tokens.get("input", 0)
|
||||
total_output_tokens += turn_tokens.get("output", 0)
|
||||
|
||||
# 6f. Stall detection
|
||||
recent_responses.append(assistant_text)
|
||||
if len(recent_responses) > self._config.stall_detection_threshold:
|
||||
recent_responses.pop(0)
|
||||
if self._is_stalled(recent_responses):
|
||||
await self._publish_stalled(stream_id, node_id)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return NodeResult(
|
||||
success=False,
|
||||
error=(
|
||||
f"Node stalled: {self._config.stall_detection_threshold} "
|
||||
"consecutive identical responses"
|
||||
),
|
||||
output=accumulator.to_dict(),
|
||||
tokens_used=total_input_tokens + total_output_tokens,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
# 6g. Write cursor checkpoint
|
||||
await self._write_cursor(ctx, conversation, accumulator, iteration)
|
||||
|
||||
# 6h. Judge evaluation
|
||||
should_judge = (
|
||||
(iteration + 1) % self._config.judge_every_n_turns == 0
|
||||
or not tool_results_list # no tool calls = natural stop
|
||||
)
|
||||
|
||||
if should_judge:
|
||||
verdict = await self._evaluate(
|
||||
ctx,
|
||||
conversation,
|
||||
accumulator,
|
||||
assistant_text,
|
||||
tool_results_list,
|
||||
iteration,
|
||||
)
|
||||
|
||||
if verdict.action == "ACCEPT":
|
||||
# Check for missing output keys
|
||||
missing = self._get_missing_output_keys(accumulator, ctx.node_spec.output_keys)
|
||||
if missing and self._judge is not None:
|
||||
hint = (
|
||||
f"Missing required output keys: {missing}. "
|
||||
"Use set_output to provide them."
|
||||
)
|
||||
await conversation.add_user_message(hint)
|
||||
continue
|
||||
|
||||
# Write outputs to shared memory
|
||||
for key, value in accumulator.to_dict().items():
|
||||
ctx.memory.write(key, value, validate=False)
|
||||
|
||||
await self._publish_loop_completed(stream_id, node_id, iteration + 1)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return NodeResult(
|
||||
success=True,
|
||||
output=accumulator.to_dict(),
|
||||
tokens_used=total_input_tokens + total_output_tokens,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
elif verdict.action == "ESCALATE":
|
||||
await self._publish_loop_completed(stream_id, node_id, iteration + 1)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return NodeResult(
|
||||
success=False,
|
||||
error=f"Judge escalated: {verdict.feedback}",
|
||||
output=accumulator.to_dict(),
|
||||
tokens_used=total_input_tokens + total_output_tokens,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
elif verdict.action == "RETRY":
|
||||
if verdict.feedback:
|
||||
await conversation.add_user_message(f"[Judge feedback]: {verdict.feedback}")
|
||||
continue
|
||||
|
||||
# 7. Max iterations exhausted
|
||||
await self._publish_loop_completed(stream_id, node_id, self._config.max_iterations)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return NodeResult(
|
||||
success=False,
|
||||
error=(f"Max iterations ({self._config.max_iterations}) reached without acceptance"),
|
||||
output=accumulator.to_dict(),
|
||||
tokens_used=total_input_tokens + total_output_tokens,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
async def inject_event(self, content: str) -> None:
|
||||
"""Inject an external event into the running loop.
|
||||
|
||||
The content becomes a user message prepended to the next iteration.
|
||||
Thread-safe via asyncio.Queue.
|
||||
"""
|
||||
await self._injection_queue.put(content)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Single LLM turn with caller-managed tool orchestration
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
async def _run_single_turn(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
tools: list[Tool],
|
||||
iteration: int,
|
||||
accumulator: OutputAccumulator,
|
||||
) -> tuple[str, list[dict], dict[str, int]]:
|
||||
"""Run a single LLM turn with streaming and tool execution.
|
||||
|
||||
Returns (assistant_text, tool_results, token_counts).
|
||||
"""
|
||||
stream_id = ctx.node_id
|
||||
node_id = ctx.node_id
|
||||
token_counts: dict[str, int] = {"input": 0, "output": 0}
|
||||
tool_call_count = 0
|
||||
final_text = ""
|
||||
|
||||
# Inner tool loop: stream may produce tool calls requiring re-invocation
|
||||
while True:
|
||||
messages = conversation.to_llm_messages()
|
||||
accumulated_text = ""
|
||||
tool_calls: list[ToolCallEvent] = []
|
||||
|
||||
# Stream LLM response
|
||||
async for event in ctx.llm.stream(
|
||||
messages=messages,
|
||||
system=conversation.system_prompt,
|
||||
tools=tools if tools else None,
|
||||
max_tokens=ctx.max_tokens,
|
||||
):
|
||||
if isinstance(event, TextDeltaEvent):
|
||||
accumulated_text = event.snapshot
|
||||
await self._publish_text_delta(
|
||||
stream_id, node_id, event.content, event.snapshot, ctx
|
||||
)
|
||||
|
||||
elif isinstance(event, ToolCallEvent):
|
||||
tool_calls.append(event)
|
||||
|
||||
elif isinstance(event, FinishEvent):
|
||||
token_counts["input"] += event.input_tokens
|
||||
token_counts["output"] += event.output_tokens
|
||||
|
||||
elif isinstance(event, StreamErrorEvent):
|
||||
if not event.recoverable:
|
||||
raise RuntimeError(f"Stream error: {event.error}")
|
||||
logger.warning(f"Recoverable stream error: {event.error}")
|
||||
|
||||
final_text = accumulated_text
|
||||
|
||||
# Record assistant message (write-through via conversation store)
|
||||
tc_dicts = None
|
||||
if tool_calls:
|
||||
tc_dicts = [
|
||||
{
|
||||
"id": tc.tool_use_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.tool_name,
|
||||
"arguments": json.dumps(tc.tool_input),
|
||||
},
|
||||
}
|
||||
for tc in tool_calls
|
||||
]
|
||||
await conversation.add_assistant_message(
|
||||
content=accumulated_text,
|
||||
tool_calls=tc_dicts,
|
||||
)
|
||||
|
||||
# If no tool calls, turn is complete
|
||||
if not tool_calls:
|
||||
return final_text, [], token_counts
|
||||
|
||||
# Execute tool calls
|
||||
tool_results: list[dict] = []
|
||||
for tc in tool_calls:
|
||||
tool_call_count += 1
|
||||
if tool_call_count > self._config.max_tool_calls_per_turn:
|
||||
logger.warning(
|
||||
f"Max tool calls per turn ({self._config.max_tool_calls_per_turn}) exceeded"
|
||||
)
|
||||
break
|
||||
|
||||
# Publish tool call started
|
||||
await self._publish_tool_started(
|
||||
stream_id, node_id, tc.tool_use_id, tc.tool_name, tc.tool_input
|
||||
)
|
||||
|
||||
# Handle set_output synthetic tool
|
||||
if tc.tool_name == "set_output":
|
||||
result = self._handle_set_output(tc.tool_input, ctx.node_spec.output_keys)
|
||||
result = ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=result.content,
|
||||
is_error=result.is_error,
|
||||
)
|
||||
# Async write-through for set_output
|
||||
if not result.is_error:
|
||||
await accumulator.set(tc.tool_input["key"], tc.tool_input["value"])
|
||||
else:
|
||||
# Execute real tool
|
||||
result = await self._execute_tool(tc)
|
||||
|
||||
# Record tool result in conversation (write-through)
|
||||
await conversation.add_tool_result(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=result.content,
|
||||
is_error=result.is_error,
|
||||
)
|
||||
tool_results.append(
|
||||
{
|
||||
"tool_use_id": tc.tool_use_id,
|
||||
"tool_name": tc.tool_name,
|
||||
"content": result.content,
|
||||
"is_error": result.is_error,
|
||||
}
|
||||
)
|
||||
|
||||
# Publish tool call completed
|
||||
await self._publish_tool_completed(
|
||||
stream_id,
|
||||
node_id,
|
||||
tc.tool_use_id,
|
||||
tc.tool_name,
|
||||
result.content,
|
||||
result.is_error,
|
||||
)
|
||||
|
||||
# Tool calls processed -- loop back to stream with updated conversation
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# set_output synthetic tool
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def _build_set_output_tool(self, output_keys: list[str] | None) -> Tool | None:
|
||||
"""Build the synthetic set_output tool for explicit output declaration."""
|
||||
if not output_keys:
|
||||
return None
|
||||
return Tool(
|
||||
name="set_output",
|
||||
description=(
|
||||
"Set an output value for this node. Call once per output key. "
|
||||
f"Valid keys: {output_keys}"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": f"Output key. Must be one of: {output_keys}",
|
||||
"enum": output_keys,
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "The output value to store.",
|
||||
},
|
||||
},
|
||||
"required": ["key", "value"],
|
||||
},
|
||||
)
|
||||
|
||||
def _handle_set_output(
|
||||
self,
|
||||
tool_input: dict[str, Any],
|
||||
output_keys: list[str] | None,
|
||||
) -> ToolResult:
|
||||
"""Handle set_output tool call. Returns ToolResult (sync)."""
|
||||
key = tool_input.get("key", "")
|
||||
valid_keys = output_keys or []
|
||||
|
||||
if key not in valid_keys:
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=f"Invalid output key '{key}'. Valid keys: {valid_keys}",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=f"Output '{key}' set successfully.",
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Judge evaluation
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
async def _evaluate(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
accumulator: OutputAccumulator,
|
||||
assistant_text: str,
|
||||
tool_results: list[dict],
|
||||
iteration: int,
|
||||
) -> JudgeVerdict:
|
||||
"""Evaluate the current state using judge or implicit logic."""
|
||||
if self._judge is not None:
|
||||
context = {
|
||||
"assistant_text": assistant_text,
|
||||
"tool_calls": tool_results,
|
||||
"output_accumulator": accumulator.to_dict(),
|
||||
"iteration": iteration,
|
||||
"conversation_summary": conversation.export_summary(),
|
||||
"output_keys": ctx.node_spec.output_keys,
|
||||
"missing_keys": self._get_missing_output_keys(
|
||||
accumulator, ctx.node_spec.output_keys
|
||||
),
|
||||
}
|
||||
return await self._judge.evaluate(context)
|
||||
|
||||
# Implicit judge: accept when no tool calls and all output keys present
|
||||
if not tool_results:
|
||||
missing = self._get_missing_output_keys(accumulator, ctx.node_spec.output_keys)
|
||||
if not missing:
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
else:
|
||||
return JudgeVerdict(
|
||||
action="RETRY",
|
||||
feedback=(
|
||||
f"Missing output keys: {missing}. Use set_output tool to provide them."
|
||||
),
|
||||
)
|
||||
|
||||
# Tool calls were made -- continue loop
|
||||
return JudgeVerdict(action="RETRY", feedback="")
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def _build_initial_message(self, ctx: NodeContext) -> str:
|
||||
"""Build the initial user message from input data and memory.
|
||||
|
||||
Includes ALL input_data (not just declared input_keys) so that
|
||||
upstream handoff data flows through regardless of key naming.
|
||||
Declared input_keys are also checked in shared memory as fallback.
|
||||
"""
|
||||
parts = []
|
||||
seen: set[str] = set()
|
||||
# Include everything from input_data (flexible handoff)
|
||||
for key, value in ctx.input_data.items():
|
||||
if value is not None:
|
||||
parts.append(f"{key}: {value}")
|
||||
seen.add(key)
|
||||
# Fallback: check memory for declared input_keys not already covered
|
||||
for key in ctx.node_spec.input_keys:
|
||||
if key not in seen:
|
||||
value = ctx.memory.read(key)
|
||||
if value is not None:
|
||||
parts.append(f"{key}: {value}")
|
||||
if ctx.goal_context:
|
||||
parts.append(f"\nGoal: {ctx.goal_context}")
|
||||
return "\n".join(parts) if parts else "Begin."
|
||||
|
||||
def _get_missing_output_keys(
|
||||
self,
|
||||
accumulator: OutputAccumulator,
|
||||
output_keys: list[str] | None,
|
||||
) -> list[str]:
|
||||
"""Return output keys that have not been set yet."""
|
||||
if not output_keys:
|
||||
return []
|
||||
return [k for k in output_keys if accumulator.get(k) is None]
|
||||
|
||||
def _is_stalled(self, recent_responses: list[str]) -> bool:
|
||||
"""Detect stall: N consecutive identical non-empty responses."""
|
||||
if len(recent_responses) < self._config.stall_detection_threshold:
|
||||
return False
|
||||
if not recent_responses[0]:
|
||||
return False
|
||||
return all(r == recent_responses[0] for r in recent_responses)
|
||||
|
||||
async def _execute_tool(self, tc: ToolCallEvent) -> ToolResult:
|
||||
"""Execute a tool call, handling both sync and async executors."""
|
||||
if self._tool_executor is None:
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=f"No tool executor configured for '{tc.tool_name}'",
|
||||
is_error=True,
|
||||
)
|
||||
tool_use = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
|
||||
result = self._tool_executor(tool_use)
|
||||
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
async def _generate_compaction_summary(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
) -> str:
|
||||
"""Use LLM to generate a conversation summary for compaction."""
|
||||
messages_text = "\n".join(
|
||||
f"[{m.role}]: {m.content[:200]}" for m in conversation.messages[-10:]
|
||||
)
|
||||
prompt = (
|
||||
"Summarize this conversation so far in 2-3 sentences, "
|
||||
"preserving key decisions and results:\n\n"
|
||||
f"{messages_text}"
|
||||
)
|
||||
try:
|
||||
response = ctx.llm.complete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
system="Summarize conversations concisely.",
|
||||
max_tokens=300,
|
||||
)
|
||||
return response.content
|
||||
except Exception as e:
|
||||
logger.warning(f"Compaction summary generation failed: {e}")
|
||||
return "Previous conversation context (summary unavailable)."
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Persistence: restore, cursor, injection, pause
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
async def _restore(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
) -> tuple[NodeConversation | None, OutputAccumulator | None, int]:
|
||||
"""Attempt to restore from a previous checkpoint."""
|
||||
if self._conversation_store is None:
|
||||
return None, None, 0
|
||||
|
||||
conversation = await NodeConversation.restore(self._conversation_store)
|
||||
if conversation is None:
|
||||
return None, None, 0
|
||||
|
||||
accumulator = await OutputAccumulator.restore(self._conversation_store)
|
||||
|
||||
cursor = await self._conversation_store.read_cursor()
|
||||
start_iteration = cursor.get("iteration", 0) + 1 if cursor else 0
|
||||
|
||||
logger.info(
|
||||
f"Restored event loop: iteration={start_iteration}, "
|
||||
f"messages={conversation.message_count}, "
|
||||
f"outputs={list(accumulator.values.keys())}"
|
||||
)
|
||||
return conversation, accumulator, start_iteration
|
||||
|
||||
async def _write_cursor(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
accumulator: OutputAccumulator,
|
||||
iteration: int,
|
||||
) -> None:
|
||||
"""Write checkpoint cursor for crash recovery."""
|
||||
if self._conversation_store:
|
||||
cursor = await self._conversation_store.read_cursor() or {}
|
||||
cursor.update(
|
||||
{
|
||||
"iteration": iteration,
|
||||
"node_id": ctx.node_id,
|
||||
"next_seq": conversation.next_seq,
|
||||
"outputs": accumulator.to_dict(),
|
||||
}
|
||||
)
|
||||
await self._conversation_store.write_cursor(cursor)
|
||||
|
||||
async def _drain_injection_queue(self, conversation: NodeConversation) -> int:
|
||||
"""Drain all pending injected events as user messages. Returns count."""
|
||||
count = 0
|
||||
while not self._injection_queue.empty():
|
||||
try:
|
||||
content = self._injection_queue.get_nowait()
|
||||
await conversation.add_user_message(f"[External event]: {content}")
|
||||
count += 1
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
return count
|
||||
|
||||
async def _check_pause(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
iteration: int,
|
||||
) -> bool:
|
||||
"""Check if pause has been requested. Returns True if paused."""
|
||||
pause_requested = ctx.input_data.get("pause_requested", False)
|
||||
if not pause_requested:
|
||||
pause_requested = ctx.memory.read("pause_requested") or False
|
||||
if pause_requested:
|
||||
logger.info(f"Pause requested at iteration {iteration}")
|
||||
return True
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# EventBus publishing helpers
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
async def _publish_loop_started(self, stream_id: str, node_id: str) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_node_loop_started(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
max_iterations=self._config.max_iterations,
|
||||
)
|
||||
|
||||
async def _publish_iteration(self, stream_id: str, node_id: str, iteration: int) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_node_loop_iteration(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
iteration=iteration,
|
||||
)
|
||||
|
||||
async def _publish_loop_completed(self, stream_id: str, node_id: str, iterations: int) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_node_loop_completed(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
async def _publish_stalled(self, stream_id: str, node_id: str) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_node_stalled(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
reason="Consecutive identical responses detected",
|
||||
)
|
||||
|
||||
async def _publish_text_delta(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
snapshot: str,
|
||||
ctx: NodeContext,
|
||||
) -> None:
|
||||
if self._event_bus:
|
||||
if ctx.node_spec.client_facing:
|
||||
await self._event_bus.emit_client_output_delta(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
content=content,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
else:
|
||||
await self._event_bus.emit_llm_text_delta(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
content=content,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
async def _publish_tool_started(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
tool_input: dict,
|
||||
) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_tool_call_started(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
tool_use_id=tool_use_id,
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
)
|
||||
|
||||
async def _publish_tool_completed(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
result: str,
|
||||
is_error: bool,
|
||||
) -> None:
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_tool_call_completed(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
tool_use_id=tool_use_id,
|
||||
tool_name=tool_name,
|
||||
result=result,
|
||||
is_error=is_error,
|
||||
)
|
||||
@@ -11,6 +11,7 @@ The executor:
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
@@ -380,6 +381,15 @@ class GraphExecutor:
|
||||
# [CORRECTED] Use node_spec.max_retries instead of hardcoded 3
|
||||
max_retries = getattr(node_spec, "max_retries", 3)
|
||||
|
||||
# Event loop nodes handle retry internally via judge —
|
||||
# executor retry is catastrophic (retry multiplication)
|
||||
if node_spec.node_type == "event_loop" and max_retries > 0:
|
||||
self.logger.warning(
|
||||
f"EventLoopNode '{node_spec.id}' has max_retries={max_retries}. "
|
||||
"Overriding to 0 — event loop nodes handle retry internally via judge."
|
||||
)
|
||||
max_retries = 0
|
||||
|
||||
if node_retry_counts[current_node_id] < max_retries:
|
||||
# Retry - don't increment steps for retries
|
||||
steps -= 1
|
||||
@@ -658,7 +668,15 @@ class GraphExecutor:
|
||||
)
|
||||
|
||||
# Valid node types - no ambiguous "llm" type allowed
|
||||
VALID_NODE_TYPES = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
|
||||
VALID_NODE_TYPES = {
|
||||
"llm_tool_use",
|
||||
"llm_generate",
|
||||
"router",
|
||||
"function",
|
||||
"human_input",
|
||||
"event_loop",
|
||||
}
|
||||
DEPRECATED_NODE_TYPES = {"llm_tool_use": "event_loop", "llm_generate": "event_loop"}
|
||||
|
||||
def _get_node_implementation(
|
||||
self, node_spec: NodeSpec, cleanup_llm_model: str | None = None
|
||||
@@ -676,6 +694,17 @@ class GraphExecutor:
|
||||
f"Use 'llm_tool_use' for nodes that call tools, 'llm_generate' for text generation."
|
||||
)
|
||||
|
||||
# Warn on deprecated node types
|
||||
if node_spec.node_type in self.DEPRECATED_NODE_TYPES:
|
||||
replacement = self.DEPRECATED_NODE_TYPES[node_spec.node_type]
|
||||
warnings.warn(
|
||||
f"Node type '{node_spec.node_type}' is deprecated. "
|
||||
f"Use '{replacement}' instead. "
|
||||
f"Node: '{node_spec.id}'",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Create based on type
|
||||
if node_spec.node_type == "llm_tool_use":
|
||||
if not node_spec.tools:
|
||||
@@ -713,6 +742,13 @@ class GraphExecutor:
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
)
|
||||
|
||||
if node_spec.node_type == "event_loop":
|
||||
# Event loop nodes must be pre-registered (like function nodes)
|
||||
raise RuntimeError(
|
||||
f"EventLoopNode '{node_spec.id}' not found in registry. "
|
||||
"Register it with executor.register_node() before execution."
|
||||
)
|
||||
|
||||
# Should never reach here due to validation above
|
||||
raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")
|
||||
|
||||
@@ -909,6 +945,17 @@ class GraphExecutor:
|
||||
branch.status = "failed"
|
||||
branch.error = f"Node {branch.node_id} not found in graph"
|
||||
return branch, RuntimeError(branch.error)
|
||||
|
||||
effective_max_retries = node_spec.max_retries
|
||||
if node_spec.node_type == "event_loop":
|
||||
if effective_max_retries > 1:
|
||||
self.logger.warning(
|
||||
f"EventLoopNode '{node_spec.id}' has "
|
||||
f"max_retries={effective_max_retries}. Overriding "
|
||||
"to 1 — event loop nodes handle retry internally."
|
||||
)
|
||||
effective_max_retries = 1
|
||||
|
||||
branch.status = "running"
|
||||
|
||||
try:
|
||||
@@ -942,7 +989,7 @@ class GraphExecutor:
|
||||
|
||||
# Execute with retries
|
||||
last_result = None
|
||||
for attempt in range(node_spec.max_retries):
|
||||
for attempt in range(effective_max_retries):
|
||||
branch.retry_count = attempt
|
||||
|
||||
# Build context for this branch
|
||||
@@ -970,7 +1017,7 @@ class GraphExecutor:
|
||||
|
||||
self.logger.warning(
|
||||
f" ↻ Branch {node_spec.name}: "
|
||||
f"retry {attempt + 1}/{node_spec.max_retries}"
|
||||
f"retry {attempt + 1}/{effective_max_retries}"
|
||||
)
|
||||
|
||||
# All retries exhausted
|
||||
@@ -979,7 +1026,7 @@ class GraphExecutor:
|
||||
branch.result = last_result
|
||||
self.logger.error(
|
||||
f" ✗ Branch {node_spec.name}: "
|
||||
f"failed after {node_spec.max_retries} attempts"
|
||||
f"failed after {effective_max_retries} attempts"
|
||||
)
|
||||
return branch, last_result
|
||||
|
||||
|
||||
@@ -153,7 +153,10 @@ class NodeSpec(BaseModel):
|
||||
# Node behavior type
|
||||
node_type: str = Field(
|
||||
default="llm_tool_use",
|
||||
description="Type: 'llm_tool_use', 'llm_generate', 'function', 'router', 'human_input'",
|
||||
description=(
|
||||
"Type: 'event_loop', 'function', 'router', 'human_input'. "
|
||||
"Deprecated: 'llm_tool_use', 'llm_generate' (use 'event_loop' instead)."
|
||||
),
|
||||
)
|
||||
|
||||
# Data flow
|
||||
@@ -218,6 +221,12 @@ class NodeSpec(BaseModel):
|
||||
description="Maximum retries when Pydantic validation fails (with feedback to LLM)",
|
||||
)
|
||||
|
||||
# Client-facing behavior
|
||||
client_facing: bool = Field(
|
||||
default=False,
|
||||
description="If True, this node streams output to the end user and can request input.",
|
||||
)
|
||||
|
||||
model_config = {"extra": "allow", "arbitrary_types_allowed": True}
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,31 @@
|
||||
"""LLM provider abstraction."""
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
ReasoningDeltaEvent,
|
||||
ReasoningStartEvent,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse"]
|
||||
__all__ = [
|
||||
"LLMProvider",
|
||||
"LLMResponse",
|
||||
"StreamEvent",
|
||||
"TextDeltaEvent",
|
||||
"TextEndEvent",
|
||||
"ToolCallEvent",
|
||||
"ToolResultEvent",
|
||||
"ReasoningStartEvent",
|
||||
"ReasoningDeltaEvent",
|
||||
"FinishEvent",
|
||||
"StreamErrorEvent",
|
||||
]
|
||||
|
||||
try:
|
||||
from framework.llm.anthropic import AnthropicProvider # noqa: F401
|
||||
|
||||
@@ -7,10 +7,11 @@ Groq, and local models.
|
||||
See: https://docs.litellm.ai/docs/providers
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -23,6 +24,7 @@ except ImportError:
|
||||
RateLimitError = Exception # type: ignore[assignment, misc]
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import StreamEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -425,3 +427,167 @@ class LiteLLMProvider(LLMProvider):
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream a completion via litellm.acompletion(stream=True).
|
||||
|
||||
Yields StreamEvent objects as chunks arrive from the provider.
|
||||
Tool call arguments are accumulated across chunks and yielded as
|
||||
a single ToolCallEvent with fully parsed JSON when complete.
|
||||
|
||||
Empty responses (e.g. Gemini stealth rate-limits that return 200
|
||||
with no content) are retried with exponential backoff, mirroring
|
||||
the retry behaviour of ``_completion_with_rate_limit_retry``.
|
||||
"""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if system:
|
||||
full_messages.append({"role": "system", "content": system})
|
||||
full_messages.extend(messages)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": full_messages,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if tools:
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
|
||||
for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
|
||||
buffered_events: list[StreamEvent] = []
|
||||
accumulated_text = ""
|
||||
tool_calls_acc: dict[int, dict[str, str]] = {}
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
try:
|
||||
response = await litellm.acompletion(**kwargs) # type: ignore[union-attr]
|
||||
|
||||
async for chunk in response:
|
||||
choice = chunk.choices[0] if chunk.choices else None
|
||||
if not choice:
|
||||
continue
|
||||
|
||||
delta = choice.delta
|
||||
|
||||
# --- Text content ---
|
||||
if delta and delta.content:
|
||||
accumulated_text += delta.content
|
||||
buffered_events.append(
|
||||
TextDeltaEvent(
|
||||
content=delta.content,
|
||||
snapshot=accumulated_text,
|
||||
)
|
||||
)
|
||||
|
||||
# --- Tool calls (accumulate across chunks) ---
|
||||
if delta and delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
|
||||
if idx not in tool_calls_acc:
|
||||
tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
|
||||
if tc.id:
|
||||
tool_calls_acc[idx]["id"] = tc.id
|
||||
if tc.function:
|
||||
if tc.function.name:
|
||||
tool_calls_acc[idx]["name"] = tc.function.name
|
||||
if tc.function.arguments:
|
||||
tool_calls_acc[idx]["arguments"] += tc.function.arguments
|
||||
|
||||
# --- Finish ---
|
||||
if choice.finish_reason:
|
||||
for _idx, tc_data in sorted(tool_calls_acc.items()):
|
||||
try:
|
||||
parsed_args = json.loads(tc_data["arguments"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
parsed_args = {"_raw": tc_data.get("arguments", "")}
|
||||
buffered_events.append(
|
||||
ToolCallEvent(
|
||||
tool_use_id=tc_data["id"],
|
||||
tool_name=tc_data["name"],
|
||||
tool_input=parsed_args,
|
||||
)
|
||||
)
|
||||
|
||||
if accumulated_text:
|
||||
buffered_events.append(TextEndEvent(full_text=accumulated_text))
|
||||
|
||||
usage = getattr(chunk, "usage", None)
|
||||
if usage:
|
||||
input_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
output_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
|
||||
buffered_events.append(
|
||||
FinishEvent(
|
||||
stop_reason=choice.finish_reason,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
model=self.model,
|
||||
)
|
||||
)
|
||||
|
||||
# Check whether the stream produced any real content.
|
||||
has_content = accumulated_text or tool_calls_acc
|
||||
if not has_content and attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
|
||||
token_count, token_method = _estimate_tokens(
|
||||
self.model,
|
||||
full_messages,
|
||||
)
|
||||
dump_path = _dump_failed_request(
|
||||
model=self.model,
|
||||
kwargs=kwargs,
|
||||
error_type="empty_stream",
|
||||
attempt=attempt,
|
||||
)
|
||||
logger.warning(
|
||||
f"[stream-retry] {self.model} returned empty stream — "
|
||||
f"~{token_count} tokens ({token_method}). "
|
||||
f"Request dumped to: {dump_path}. "
|
||||
f"Retrying in {wait}s "
|
||||
f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
|
||||
# Success (or final attempt) — flush buffered events.
|
||||
for event in buffered_events:
|
||||
yield event
|
||||
return
|
||||
|
||||
except RateLimitError as e:
|
||||
if attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
|
||||
logger.warning(
|
||||
f"[stream-retry] {self.model} rate limited (429): {e!s}. "
|
||||
f"Retrying in {wait}s "
|
||||
f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
return
|
||||
|
||||
@@ -2,10 +2,16 @@
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
)
|
||||
|
||||
|
||||
class MockLLMProvider(LLMProvider):
|
||||
@@ -175,3 +181,28 @@ class MockLLMProvider(LLMProvider):
|
||||
output_tokens=0,
|
||||
stop_reason="mock_complete",
|
||||
)
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream a mock completion as word-level TextDeltaEvents.
|
||||
|
||||
Splits the mock response into words and yields each as a separate
|
||||
TextDeltaEvent with an accumulating snapshot, exercising the full
|
||||
streaming pipeline without any API calls.
|
||||
"""
|
||||
content = self._generate_mock_response(system=system, json_mode=False)
|
||||
words = content.split(" ")
|
||||
accumulated = ""
|
||||
|
||||
for i, word in enumerate(words):
|
||||
chunk = word if i == 0 else " " + word
|
||||
accumulated += chunk
|
||||
yield TextDeltaEvent(content=chunk, snapshot=accumulated)
|
||||
|
||||
yield TextEndEvent(full_text=accumulated)
|
||||
yield FinishEvent(stop_reason="mock_complete", model=self.model)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""LLM Provider abstraction for pluggable LLM backends."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -108,3 +108,45 @@ class LLMProvider(ABC):
|
||||
Final LLMResponse after tool use completes
|
||||
"""
|
||||
pass
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator["StreamEvent"]:
|
||||
"""
|
||||
Stream a completion as an async iterator of StreamEvents.
|
||||
|
||||
Default implementation wraps complete() with synthetic events.
|
||||
Subclasses SHOULD override for true streaming.
|
||||
|
||||
Tool orchestration is the CALLER's responsibility:
|
||||
- Caller detects ToolCallEvent, executes tool, adds result
|
||||
to messages, calls stream() again.
|
||||
"""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
)
|
||||
|
||||
response = self.complete(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
yield TextDeltaEvent(content=response.content, snapshot=response.content)
|
||||
yield TextEndEvent(full_text=response.content)
|
||||
yield FinishEvent(
|
||||
stop_reason=response.stop_reason,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
|
||||
# Deferred import target for type annotation
|
||||
from framework.llm.stream_events import StreamEvent as StreamEvent # noqa: E402, F401
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Stream event types for LLM streaming responses.
|
||||
|
||||
Defines a discriminated union of frozen dataclasses representing every event
|
||||
a streaming LLM call can produce. These types form the contract between the
|
||||
LLM provider layer, EventLoopNode, event bus, persistence, and monitoring.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextDeltaEvent:
|
||||
"""A chunk of text produced by the LLM."""
|
||||
|
||||
type: Literal["text_delta"] = "text_delta"
|
||||
content: str = "" # this chunk's text
|
||||
snapshot: str = "" # accumulated text so far
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextEndEvent:
|
||||
"""Signals that text generation is complete."""
|
||||
|
||||
type: Literal["text_end"] = "text_end"
|
||||
full_text: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolCallEvent:
|
||||
"""The LLM has requested a tool call."""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
tool_use_id: str = ""
|
||||
tool_name: str = ""
|
||||
tool_input: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolResultEvent:
|
||||
"""Result of executing a tool call."""
|
||||
|
||||
type: Literal["tool_result"] = "tool_result"
|
||||
tool_use_id: str = ""
|
||||
content: str = ""
|
||||
is_error: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReasoningStartEvent:
|
||||
"""The LLM has started a reasoning/thinking block."""
|
||||
|
||||
type: Literal["reasoning_start"] = "reasoning_start"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReasoningDeltaEvent:
|
||||
"""A chunk of reasoning/thinking content."""
|
||||
|
||||
type: Literal["reasoning_delta"] = "reasoning_delta"
|
||||
content: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FinishEvent:
|
||||
"""The LLM has finished generating."""
|
||||
|
||||
type: Literal["finish"] = "finish"
|
||||
stop_reason: str = ""
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
model: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StreamErrorEvent:
|
||||
"""An error occurred during streaming."""
|
||||
|
||||
type: Literal["error"] = "error"
|
||||
error: str = ""
|
||||
recoverable: bool = False
|
||||
|
||||
|
||||
# Discriminated union of all stream event types
|
||||
StreamEvent = (
|
||||
TextDeltaEvent
|
||||
| TextEndEvent
|
||||
| ToolCallEvent
|
||||
| ToolResultEvent
|
||||
| ReasoningStartEvent
|
||||
| ReasoningDeltaEvent
|
||||
| FinishEvent
|
||||
| StreamErrorEvent
|
||||
)
|
||||
@@ -41,6 +41,28 @@ class EventType(str, Enum):
|
||||
STREAM_STARTED = "stream_started"
|
||||
STREAM_STOPPED = "stream_stopped"
|
||||
|
||||
# Node event-loop lifecycle
|
||||
NODE_LOOP_STARTED = "node_loop_started"
|
||||
NODE_LOOP_ITERATION = "node_loop_iteration"
|
||||
NODE_LOOP_COMPLETED = "node_loop_completed"
|
||||
|
||||
# LLM streaming observability
|
||||
LLM_TEXT_DELTA = "llm_text_delta"
|
||||
LLM_REASONING_DELTA = "llm_reasoning_delta"
|
||||
|
||||
# Tool lifecycle
|
||||
TOOL_CALL_STARTED = "tool_call_started"
|
||||
TOOL_CALL_COMPLETED = "tool_call_completed"
|
||||
|
||||
# Client I/O (client_facing=True nodes only)
|
||||
CLIENT_OUTPUT_DELTA = "client_output_delta"
|
||||
CLIENT_INPUT_REQUESTED = "client_input_requested"
|
||||
|
||||
# Internal node observability (client_facing=False nodes)
|
||||
NODE_INTERNAL_OUTPUT = "node_internal_output"
|
||||
NODE_INPUT_BLOCKED = "node_input_blocked"
|
||||
NODE_STALLED = "node_stalled"
|
||||
|
||||
# Custom events
|
||||
CUSTOM = "custom"
|
||||
|
||||
@@ -51,6 +73,7 @@ class AgentEvent:
|
||||
|
||||
type: EventType
|
||||
stream_id: str
|
||||
node_id: str | None = None # Which node emitted this event
|
||||
execution_id: str | None = None
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
@@ -61,6 +84,7 @@ class AgentEvent:
|
||||
return {
|
||||
"type": self.type.value,
|
||||
"stream_id": self.stream_id,
|
||||
"node_id": self.node_id,
|
||||
"execution_id": self.execution_id,
|
||||
"data": self.data,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
@@ -80,6 +104,7 @@ class Subscription:
|
||||
event_types: set[EventType]
|
||||
handler: EventHandler
|
||||
filter_stream: str | None = None # Only receive events from this stream
|
||||
filter_node: str | None = None # Only receive events from this node
|
||||
filter_execution: str | None = None # Only receive events from this execution
|
||||
|
||||
|
||||
@@ -138,6 +163,7 @@ class EventBus:
|
||||
event_types: list[EventType],
|
||||
handler: EventHandler,
|
||||
filter_stream: str | None = None,
|
||||
filter_node: str | None = None,
|
||||
filter_execution: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -147,6 +173,7 @@ class EventBus:
|
||||
event_types: Types of events to receive
|
||||
handler: Async function to call when event occurs
|
||||
filter_stream: Only receive events from this stream
|
||||
filter_node: Only receive events from this node
|
||||
filter_execution: Only receive events from this execution
|
||||
|
||||
Returns:
|
||||
@@ -160,6 +187,7 @@ class EventBus:
|
||||
event_types=set(event_types),
|
||||
handler=handler,
|
||||
filter_stream=filter_stream,
|
||||
filter_node=filter_node,
|
||||
filter_execution=filter_execution,
|
||||
)
|
||||
|
||||
@@ -218,6 +246,10 @@ class EventBus:
|
||||
if subscription.filter_stream and subscription.filter_stream != event.stream_id:
|
||||
return False
|
||||
|
||||
# Check node filter
|
||||
if subscription.filter_node and subscription.filter_node != event.node_id:
|
||||
return False
|
||||
|
||||
# Check execution filter
|
||||
if subscription.filter_execution and subscription.filter_execution != event.execution_id:
|
||||
return False
|
||||
@@ -359,6 +391,248 @@ class EventBus:
|
||||
)
|
||||
)
|
||||
|
||||
# === NODE EVENT-LOOP PUBLISHERS ===
|
||||
|
||||
async def emit_node_loop_started(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
execution_id: str | None = None,
|
||||
max_iterations: int | None = None,
|
||||
) -> None:
|
||||
"""Emit node loop started event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_STARTED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"max_iterations": max_iterations},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_loop_iteration(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
iteration: int,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node loop iteration event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_ITERATION,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"iteration": iteration},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_loop_completed(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
iterations: int,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node loop completed event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_COMPLETED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"iterations": iterations},
|
||||
)
|
||||
)
|
||||
|
||||
# === LLM STREAMING PUBLISHERS ===
|
||||
|
||||
async def emit_llm_text_delta(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
snapshot: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM text delta event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content, "snapshot": snapshot},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_llm_reasoning_delta(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM reasoning delta event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_REASONING_DELTA,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content},
|
||||
)
|
||||
)
|
||||
|
||||
# === TOOL LIFECYCLE PUBLISHERS ===
|
||||
|
||||
async def emit_tool_call_started(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any] | None = None,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit tool call started event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TOOL_CALL_STARTED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"tool_use_id": tool_use_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_input": tool_input or {},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_tool_call_completed(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
result: str = "",
|
||||
is_error: bool = False,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit tool call completed event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TOOL_CALL_COMPLETED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"tool_use_id": tool_use_id,
|
||||
"tool_name": tool_name,
|
||||
"result": result,
|
||||
"is_error": is_error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# === CLIENT I/O PUBLISHERS ===
|
||||
|
||||
async def emit_client_output_delta(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
snapshot: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit client output delta event (client_facing=True nodes)."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CLIENT_OUTPUT_DELTA,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content, "snapshot": snapshot},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_client_input_requested(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
prompt: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit client input requested event (client_facing=True nodes)."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CLIENT_INPUT_REQUESTED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"prompt": prompt},
|
||||
)
|
||||
)
|
||||
|
||||
# === INTERNAL NODE PUBLISHERS ===
|
||||
|
||||
async def emit_node_internal_output(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node internal output event (client_facing=False nodes)."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_INTERNAL_OUTPUT,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_stalled(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
reason: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node stalled event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_STALLED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"reason": reason},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_input_blocked(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
prompt: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node input blocked event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_INPUT_BLOCKED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"prompt": prompt},
|
||||
)
|
||||
)
|
||||
|
||||
# === QUERY OPERATIONS ===
|
||||
|
||||
def get_history(
|
||||
@@ -410,6 +684,7 @@ class EventBus:
|
||||
self,
|
||||
event_type: EventType,
|
||||
stream_id: str | None = None,
|
||||
node_id: str | None = None,
|
||||
execution_id: str | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> AgentEvent | None:
|
||||
@@ -419,6 +694,7 @@ class EventBus:
|
||||
Args:
|
||||
event_type: Type of event to wait for
|
||||
stream_id: Filter by stream
|
||||
node_id: Filter by node
|
||||
execution_id: Filter by execution
|
||||
timeout: Maximum time to wait (seconds)
|
||||
|
||||
@@ -438,6 +714,7 @@ class EventBus:
|
||||
event_types=[event_type],
|
||||
handler=handler,
|
||||
filter_stream=stream_id,
|
||||
filter_node=node_id,
|
||||
filter_execution=execution_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
"""Tests for ContextHandoff and HandoffContext."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.context_handoff import ContextHandoff, HandoffContext
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
from framework.llm.provider import LLMProvider, LLMResponse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SpyLLMProvider(MockLLMProvider):
|
||||
"""MockLLMProvider that records whether complete() was called."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.complete_called = False
|
||||
self.complete_call_args: dict[str, Any] | None = None
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
|
||||
self.complete_called = True
|
||||
self.complete_call_args = {"messages": messages, **kwargs}
|
||||
return super().complete(messages, **kwargs)
|
||||
|
||||
|
||||
class FailingLLMProvider(LLMProvider):
|
||||
"""LLM provider that always raises."""
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
|
||||
raise RuntimeError("LLM unavailable")
|
||||
|
||||
def complete_with_tools(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list,
|
||||
tool_executor: Any,
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
raise RuntimeError("LLM unavailable")
|
||||
|
||||
|
||||
async def _build_conversation(*pairs: tuple[str, str]) -> NodeConversation:
|
||||
"""Build a NodeConversation from (user, assistant) message pairs."""
|
||||
conv = NodeConversation()
|
||||
for user_msg, assistant_msg in pairs:
|
||||
await conv.add_user_message(user_msg)
|
||||
await conv.add_assistant_message(assistant_msg)
|
||||
return conv
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestHandoffContext
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandoffContext:
|
||||
def test_instantiation(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="node_A",
|
||||
summary="Summary text",
|
||||
key_outputs={"result": "42"},
|
||||
turn_count=3,
|
||||
total_tokens_used=1200,
|
||||
)
|
||||
assert hc.source_node_id == "node_A"
|
||||
assert hc.summary == "Summary text"
|
||||
assert hc.key_outputs == {"result": "42"}
|
||||
assert hc.turn_count == 3
|
||||
assert hc.total_tokens_used == 1200
|
||||
|
||||
def test_field_access(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="n1",
|
||||
summary="s",
|
||||
key_outputs={},
|
||||
turn_count=0,
|
||||
total_tokens_used=0,
|
||||
)
|
||||
assert hc.key_outputs == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestExtractiveSummary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractiveSummary:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_summary_includes_first_last(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("hello", "First response here."),
|
||||
("continue", "Middle response."),
|
||||
("finish", "Final conclusion."),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="test_node")
|
||||
|
||||
assert "First response here." in hc.summary
|
||||
assert "Final conclusion." in hc.summary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_summary_metadata(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("hi", "hello"),
|
||||
("bye", "goodbye"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="node_42")
|
||||
|
||||
assert hc.source_node_id == "node_42"
|
||||
assert hc.turn_count == 2
|
||||
assert hc.total_tokens_used > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_with_output_keys_colon(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("what is the answer?", "answer: 42"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["answer"])
|
||||
|
||||
assert hc.key_outputs["answer"] == "42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_with_output_keys_equals(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("compute", "result = success"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["result"])
|
||||
|
||||
assert hc.key_outputs["result"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_json_output_keys(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("give me json", '{"score": 95, "grade": "A"}'),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"])
|
||||
|
||||
assert hc.key_outputs["score"] == "95"
|
||||
assert hc.key_outputs["grade"] == "A"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_empty_conversation(self) -> None:
|
||||
conv = NodeConversation()
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="empty")
|
||||
|
||||
assert hc.summary == "Empty conversation."
|
||||
assert hc.turn_count == 0
|
||||
assert hc.key_outputs == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_no_assistant_messages(self) -> None:
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("hello?")
|
||||
await conv.add_user_message("anyone there?")
|
||||
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="silent")
|
||||
|
||||
assert hc.summary == "No assistant responses."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_most_recent_wins(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("first", "status: old_value"),
|
||||
("second", "status: new_value"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["status"])
|
||||
|
||||
assert hc.key_outputs["status"] == "new_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_truncation(self) -> None:
|
||||
long_text = "x" * 1000
|
||||
conv = await _build_conversation(
|
||||
("go", long_text),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n")
|
||||
|
||||
# Summary should be truncated to ~500 chars
|
||||
assert len(hc.summary) <= 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestLLMSummary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMSummary:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_summary_calls_provider(self) -> None:
|
||||
llm = SpyLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("hi", "hello back"),
|
||||
("what now?", "we are done"),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
hc = ch.summarize_conversation(conv, node_id="llm_node")
|
||||
|
||||
assert llm.complete_called, "LLM complete() was never invoked"
|
||||
assert hc.summary == "This is a mock response for testing purposes."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_summary_includes_output_key_hint(self) -> None:
|
||||
llm = SpyLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("compute", '{"score": 95}'),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"])
|
||||
|
||||
assert llm.complete_call_args is not None
|
||||
system = llm.complete_call_args.get("system", "")
|
||||
assert "score" in system
|
||||
assert "grade" in system
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_fallback_on_error(self) -> None:
|
||||
llm = FailingLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("start", "First assistant message."),
|
||||
("end", "Last assistant message."),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
hc = ch.summarize_conversation(conv, node_id="fallback_node")
|
||||
|
||||
# Should fall back to extractive (first + last assistant messages)
|
||||
assert "First assistant message." in hc.summary
|
||||
assert "Last assistant message." in hc.summary
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestFormatAsInput
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatAsInput:
|
||||
def test_format_structure(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="analyzer",
|
||||
summary="Analysis complete.",
|
||||
key_outputs={"score": "95"},
|
||||
turn_count=5,
|
||||
total_tokens_used=2000,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "--- CONTEXT FROM: analyzer" in output
|
||||
assert "KEY OUTPUTS:" in output
|
||||
assert "SUMMARY:" in output
|
||||
assert "--- END CONTEXT ---" in output
|
||||
|
||||
def test_format_no_key_outputs(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="simple",
|
||||
summary="Done.",
|
||||
key_outputs={},
|
||||
turn_count=1,
|
||||
total_tokens_used=100,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "KEY OUTPUTS:" not in output
|
||||
assert "SUMMARY:" in output
|
||||
|
||||
def test_format_content_values(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="node_X",
|
||||
summary="Found 3 bugs.",
|
||||
key_outputs={"bugs": "3", "severity": "high"},
|
||||
turn_count=7,
|
||||
total_tokens_used=5000,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "node_X" in output
|
||||
assert "7 turns" in output
|
||||
assert "~5000 tokens" in output
|
||||
assert "- bugs: 3" in output
|
||||
assert "- severity: high" in output
|
||||
assert "Found 3 bugs." in output
|
||||
|
||||
def test_format_empty_summary(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="n",
|
||||
summary="",
|
||||
key_outputs={},
|
||||
turn_count=0,
|
||||
total_tokens_used=0,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "No summary available." in output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_as_input_usable_as_message(self) -> None:
|
||||
"""Formatted output can be fed into a NodeConversation as a user message."""
|
||||
hc = HandoffContext(
|
||||
source_node_id="prev_node",
|
||||
summary="Completed analysis.",
|
||||
key_outputs={"result": "42"},
|
||||
turn_count=3,
|
||||
total_tokens_used=900,
|
||||
)
|
||||
text = ContextHandoff.format_as_input(hc)
|
||||
|
||||
conv = NodeConversation()
|
||||
msg = await conv.add_user_message(text)
|
||||
|
||||
assert msg.role == "user"
|
||||
assert "CONTEXT FROM: prev_node" in msg.content
|
||||
assert conv.turn_count == 1
|
||||
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Tests for event_loop node type wiring (Issue #2513).
|
||||
|
||||
Covers:
|
||||
- NodeSpec.client_facing field
|
||||
- event_loop in VALID_NODE_TYPES
|
||||
- _get_node_implementation() event_loop branch
|
||||
- no-retry enforcement in serial execution path
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
|
||||
class AlwaysFailsNode(NodeProtocol):
|
||||
"""A test node that always fails."""
|
||||
|
||||
def __init__(self):
|
||||
self.attempt_count = 0
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
self.attempt_count += 1
|
||||
return NodeResult(success=False, error=f"Permanent error (attempt {self.attempt_count})")
|
||||
|
||||
|
||||
class SucceedsOnceNode(NodeProtocol):
|
||||
"""A test node that always succeeds."""
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
return NodeResult(success=True, output={"result": "ok"})
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fast_sleep(monkeypatch):
|
||||
"""Mock asyncio.sleep to avoid real delays from exponential backoff."""
|
||||
monkeypatch.setattr("asyncio.sleep", AsyncMock())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runtime():
|
||||
"""Create a mock Runtime for testing."""
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime.start_run = MagicMock(return_value="test_run_id")
|
||||
runtime.decide = MagicMock(return_value="test_decision_id")
|
||||
runtime.record_outcome = MagicMock()
|
||||
runtime.end_run = MagicMock()
|
||||
runtime.report_problem = MagicMock()
|
||||
runtime.set_node = MagicMock()
|
||||
return runtime
|
||||
|
||||
|
||||
# --- NodeSpec.client_facing tests ---
|
||||
|
||||
|
||||
def test_client_facing_defaults_false():
|
||||
"""NodeSpec without client_facing should default to False."""
|
||||
spec = NodeSpec(
|
||||
id="n1",
|
||||
name="Node 1",
|
||||
description="test",
|
||||
node_type="llm_generate",
|
||||
)
|
||||
assert spec.client_facing is False
|
||||
|
||||
|
||||
def test_client_facing_explicit_true():
|
||||
"""NodeSpec with client_facing=True should retain the value."""
|
||||
spec = NodeSpec(
|
||||
id="n1",
|
||||
name="Node 1",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
)
|
||||
assert spec.client_facing is True
|
||||
|
||||
|
||||
# --- VALID_NODE_TYPES tests ---
|
||||
|
||||
|
||||
def test_event_loop_in_valid_node_types():
|
||||
"""'event_loop' must be in GraphExecutor.VALID_NODE_TYPES."""
|
||||
assert "event_loop" in GraphExecutor.VALID_NODE_TYPES
|
||||
|
||||
|
||||
def test_event_loop_node_spec_accepted():
|
||||
"""Creating a NodeSpec with node_type='event_loop' should not raise."""
|
||||
spec = NodeSpec(
|
||||
id="el1",
|
||||
name="Event Loop",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
)
|
||||
assert spec.node_type == "event_loop"
|
||||
|
||||
|
||||
# --- _get_node_implementation() tests ---
|
||||
|
||||
|
||||
def test_unregistered_event_loop_raises(runtime):
|
||||
"""An event_loop node not in the registry should raise RuntimeError."""
|
||||
spec = NodeSpec(
|
||||
id="el1",
|
||||
name="Event Loop",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
)
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
|
||||
with pytest.raises(RuntimeError, match="not found in registry"):
|
||||
executor._get_node_implementation(spec)
|
||||
|
||||
|
||||
def test_registered_event_loop_returns_impl(runtime):
|
||||
"""A registered event_loop node should be returned from the registry."""
|
||||
spec = NodeSpec(
|
||||
id="el1",
|
||||
name="Event Loop",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
)
|
||||
impl = SucceedsOnceNode()
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
executor.register_node("el1", impl)
|
||||
|
||||
result = executor._get_node_implementation(spec)
|
||||
assert result is impl
|
||||
|
||||
|
||||
# --- No-retry enforcement (serial path) ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_loop_max_retries_forced_zero(runtime):
|
||||
"""An event_loop node with max_retries=3 should only execute once (no retry)."""
|
||||
node_spec = NodeSpec(
|
||||
id="el_fail",
|
||||
name="Failing Event Loop",
|
||||
description="event loop that fails",
|
||||
node_type="event_loop",
|
||||
max_retries=3,
|
||||
output_keys=["result"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="test_graph",
|
||||
goal_id="test_goal",
|
||||
name="Test Graph",
|
||||
entry_node="el_fail",
|
||||
nodes=[node_spec],
|
||||
edges=[],
|
||||
terminal_nodes=["el_fail"],
|
||||
)
|
||||
|
||||
goal = Goal(id="test_goal", name="Test", description="test")
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
failing_node = AlwaysFailsNode()
|
||||
executor.register_node("el_fail", failing_node)
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
# Event loop nodes get max_retries overridden to 0, meaning execute once then fail
|
||||
assert not result.success
|
||||
assert failing_node.attempt_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_loop_max_retries_zero_no_warning(runtime, caplog):
|
||||
"""An event_loop node with max_retries=0 should not log a warning."""
|
||||
node_spec = NodeSpec(
|
||||
id="el_zero",
|
||||
name="Zero Retry Event Loop",
|
||||
description="event loop with 0 retries",
|
||||
node_type="event_loop",
|
||||
max_retries=0,
|
||||
output_keys=["result"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="test_graph",
|
||||
goal_id="test_goal",
|
||||
name="Test Graph",
|
||||
entry_node="el_zero",
|
||||
nodes=[node_spec],
|
||||
edges=[],
|
||||
terminal_nodes=["el_zero"],
|
||||
)
|
||||
|
||||
goal = Goal(id="test_goal", name="Test", description="test")
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
failing_node = AlwaysFailsNode()
|
||||
executor.register_node("el_zero", failing_node)
|
||||
|
||||
import logging
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
await executor.execute(graph, goal, {})
|
||||
|
||||
# max_retries=0 should not trigger the override warning
|
||||
assert "Overriding to 0" not in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_loop_max_retries_positive_logs_warning(runtime, caplog):
|
||||
"""An event_loop node with max_retries=3 should log a warning about override."""
|
||||
node_spec = NodeSpec(
|
||||
id="el_warn",
|
||||
name="Warning Event Loop",
|
||||
description="event loop with retries",
|
||||
node_type="event_loop",
|
||||
max_retries=3,
|
||||
output_keys=["result"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="test_graph",
|
||||
goal_id="test_goal",
|
||||
name="Test Graph",
|
||||
entry_node="el_warn",
|
||||
nodes=[node_spec],
|
||||
edges=[],
|
||||
terminal_nodes=["el_warn"],
|
||||
)
|
||||
|
||||
goal = Goal(id="test_goal", name="Test", description="test")
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
failing_node = AlwaysFailsNode()
|
||||
executor.register_node("el_warn", failing_node)
|
||||
|
||||
import logging
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
await executor.execute(graph, goal, {})
|
||||
|
||||
assert "Overriding to 0" in caplog.text
|
||||
assert "el_warn" in caplog.text
|
||||
|
||||
|
||||
# --- Existing node types unaffected ---
|
||||
|
||||
|
||||
def test_existing_node_types_unchanged():
|
||||
"""All pre-existing node types must still be in VALID_NODE_TYPES with defaults preserved."""
|
||||
expected = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
|
||||
assert expected.issubset(GraphExecutor.VALID_NODE_TYPES)
|
||||
|
||||
# Default node_type is still llm_tool_use
|
||||
spec = NodeSpec(id="x", name="X", description="x")
|
||||
assert spec.node_type == "llm_tool_use"
|
||||
|
||||
# Default max_retries is still 3
|
||||
assert spec.max_retries == 3
|
||||
|
||||
# Default client_facing is False
|
||||
assert spec.client_facing is False
|
||||
@@ -558,3 +558,313 @@ class TestFileConversationStore:
|
||||
assert (base / "cursor.json").exists()
|
||||
assert (base / "parts" / "0000000000.json").exists()
|
||||
assert (base / "parts" / "0000000001.json").exists()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Integration tests — real FileConversationStore, no mocks
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestConversationIntegration:
|
||||
"""End-to-end tests using real FileConversationStore on disk.
|
||||
|
||||
Every test creates a fresh directory, writes real JSON files,
|
||||
and restores from a *new* store instance (simulating process restart).
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_turn_agent_conversation(self, tmp_path):
|
||||
"""Simulate a realistic agent conversation with multiple turns,
|
||||
tool calls, and tool results — then restore from disk."""
|
||||
base = tmp_path / "agent_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(
|
||||
system_prompt="You are a helpful travel agent.",
|
||||
max_history_tokens=16000,
|
||||
store=store,
|
||||
)
|
||||
|
||||
# Turn 1: user asks, assistant responds with tool call
|
||||
await conv.add_user_message("Find me flights from NYC to London next Friday.")
|
||||
await conv.add_assistant_message(
|
||||
"Let me search for flights.",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_flight_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_flights",
|
||||
"arguments": '{"origin":"JFK","destination":"LHR","date":"2025-06-13"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
await conv.add_tool_result(
|
||||
"call_flight_1",
|
||||
'{"flights":[{"airline":"BA","price":450,"departure":"08:00"},{"airline":"AA","price":520,"departure":"14:30"}]}',
|
||||
)
|
||||
|
||||
# Turn 2: assistant presents results, user picks one
|
||||
await conv.add_assistant_message(
|
||||
"I found 2 flights:\n"
|
||||
"1. British Airways at $450, departing 08:00\n"
|
||||
"2. American Airlines at $520, departing 14:30\n"
|
||||
"Which one would you like?"
|
||||
)
|
||||
await conv.add_user_message("Book the British Airways one.")
|
||||
await conv.add_assistant_message(
|
||||
"Booking the BA flight now.",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_book_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "book_flight",
|
||||
"arguments": '{"flight_id":"BA-JFK-LHR-0800","passenger":"user"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
await conv.add_tool_result(
|
||||
"call_book_1",
|
||||
'{"confirmation":"BA-12345","status":"confirmed"}',
|
||||
)
|
||||
await conv.add_assistant_message("Your flight is booked! Confirmation: BA-12345.")
|
||||
|
||||
# Verify in-memory state
|
||||
assert conv.turn_count == 2
|
||||
assert conv.message_count == 8
|
||||
assert conv.next_seq == 8
|
||||
|
||||
# --- Simulate process restart: new store, same path ---
|
||||
store2 = FileConversationStore(base)
|
||||
restored = await NodeConversation.restore(store2)
|
||||
|
||||
assert restored is not None
|
||||
assert restored.system_prompt == "You are a helpful travel agent."
|
||||
assert restored.turn_count == 2
|
||||
assert restored.message_count == 8
|
||||
assert restored.next_seq == 8
|
||||
|
||||
# Verify message content integrity
|
||||
msgs = restored.messages
|
||||
assert msgs[0].role == "user"
|
||||
assert "NYC to London" in msgs[0].content
|
||||
assert msgs[1].role == "assistant"
|
||||
assert msgs[1].tool_calls[0]["id"] == "call_flight_1"
|
||||
assert msgs[2].role == "tool"
|
||||
assert msgs[2].tool_use_id == "call_flight_1"
|
||||
assert "BA" in msgs[2].content
|
||||
assert msgs[7].content == "Your flight is booked! Confirmation: BA-12345."
|
||||
|
||||
# Verify LLM-format output
|
||||
llm_msgs = restored.to_llm_messages()
|
||||
assert llm_msgs[0] == {"role": "user", "content": msgs[0].content}
|
||||
assert llm_msgs[2]["role"] == "tool"
|
||||
assert llm_msgs[2]["tool_call_id"] == "call_flight_1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_and_restore_preserves_continuity(self, tmp_path):
|
||||
"""Build up a long conversation, compact it, continue adding
|
||||
messages, then restore — verifying seq continuity and content."""
|
||||
base = tmp_path / "compact_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(
|
||||
system_prompt="research assistant",
|
||||
store=store,
|
||||
)
|
||||
|
||||
# Build 10 messages (5 turns)
|
||||
for i in range(5):
|
||||
await conv.add_user_message(f"question {i}")
|
||||
await conv.add_assistant_message(f"answer {i}")
|
||||
|
||||
assert conv.message_count == 10
|
||||
assert conv.next_seq == 10
|
||||
|
||||
# Compact: keep last 2 messages (question 4, answer 4)
|
||||
await conv.compact("Summary of questions 0-3 and their answers.", keep_recent=2)
|
||||
|
||||
assert conv.message_count == 3 # summary + 2 recent
|
||||
assert conv.messages[0].content == "Summary of questions 0-3 and their answers."
|
||||
assert conv.messages[1].content == "question 4"
|
||||
assert conv.messages[2].content == "answer 4"
|
||||
|
||||
# Continue the conversation post-compaction
|
||||
await conv.add_user_message("question 5")
|
||||
await conv.add_assistant_message("answer 5")
|
||||
assert conv.next_seq == 12
|
||||
|
||||
# Verify disk: old part files (seq 0-7) should be deleted
|
||||
parts_dir = base / "parts"
|
||||
part_files = sorted(parts_dir.glob("*.json"))
|
||||
part_seqs = [int(f.stem) for f in part_files]
|
||||
# Should have: summary (seq 7), question 4 (seq 8), answer 4 (seq 9),
|
||||
# question 5 (seq 10), answer 5 (seq 11)
|
||||
assert all(s >= 7 for s in part_seqs), f"Stale parts found: {part_seqs}"
|
||||
|
||||
# Restore from fresh store
|
||||
store2 = FileConversationStore(base)
|
||||
restored = await NodeConversation.restore(store2)
|
||||
|
||||
assert restored is not None
|
||||
assert restored.next_seq == 12
|
||||
assert restored.message_count == 5
|
||||
assert "Summary of questions 0-3" in restored.messages[0].content
|
||||
assert restored.messages[-1].content == "answer 5"
|
||||
|
||||
# Verify seq monotonicity across all restored messages
|
||||
seqs = [m.seq for m in restored.messages]
|
||||
assert seqs == sorted(seqs), f"Seqs not monotonic: {seqs}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_key_preservation_through_compact_and_restore(self, tmp_path):
|
||||
"""Output keys in compacted messages survive disk persistence."""
|
||||
base = tmp_path / "output_key_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(
|
||||
system_prompt="classifier",
|
||||
output_keys=["classification", "confidence"],
|
||||
store=store,
|
||||
)
|
||||
|
||||
await conv.add_user_message("Classify this email: 'You won a prize!'")
|
||||
await conv.add_assistant_message('{"classification": "spam", "confidence": "0.97"}')
|
||||
await conv.add_user_message("What about: 'Meeting at 3pm'")
|
||||
await conv.add_assistant_message('{"classification": "ham", "confidence": "0.99"}')
|
||||
await conv.add_user_message("And: 'Buy cheap meds now'")
|
||||
await conv.add_assistant_message('{"classification": "spam", "confidence": "0.95"}')
|
||||
|
||||
# Compact keeping only the last 2 messages
|
||||
await conv.compact("Classified 3 emails.", keep_recent=2)
|
||||
|
||||
# The summary should contain preserved output keys from discarded messages
|
||||
summary_content = conv.messages[0].content
|
||||
assert "PRESERVED VALUES" in summary_content
|
||||
# Most recent values from discarded messages (msgs 0-3) are "ham"/"0.99"
|
||||
assert "ham" in summary_content or "spam" in summary_content
|
||||
|
||||
# Restore and verify the preserved values survived
|
||||
store2 = FileConversationStore(base)
|
||||
restored = await NodeConversation.restore(store2)
|
||||
assert restored is not None
|
||||
assert "PRESERVED VALUES" in restored.messages[0].content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_error_roundtrip(self, tmp_path):
|
||||
"""Tool errors persist and restore with ERROR: prefix in LLM output."""
|
||||
base = tmp_path / "error_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(store=store)
|
||||
|
||||
await conv.add_user_message("Calculate 1/0")
|
||||
await conv.add_assistant_message(
|
||||
"Let me calculate that.",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_calc",
|
||||
"type": "function",
|
||||
"function": {"name": "calculator", "arguments": '{"expr":"1/0"}'},
|
||||
}
|
||||
],
|
||||
)
|
||||
await conv.add_tool_result(
|
||||
"call_calc", "ZeroDivisionError: division by zero", is_error=True
|
||||
)
|
||||
await conv.add_assistant_message("The calculation failed: division by zero is undefined.")
|
||||
|
||||
# Restore
|
||||
store2 = FileConversationStore(base)
|
||||
restored = await NodeConversation.restore(store2)
|
||||
assert restored is not None
|
||||
|
||||
tool_msg = restored.messages[2]
|
||||
assert tool_msg.role == "tool"
|
||||
assert tool_msg.is_error is True
|
||||
assert tool_msg.tool_use_id == "call_calc"
|
||||
|
||||
llm_dict = tool_msg.to_llm_dict()
|
||||
assert llm_dict["content"].startswith("ERROR: ")
|
||||
assert "ZeroDivisionError" in llm_dict["content"]
|
||||
assert llm_dict["tool_call_id"] == "call_calc"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_conversations_isolated(self, tmp_path):
|
||||
"""Two conversations in separate directories don't interfere."""
|
||||
store_a = FileConversationStore(tmp_path / "conv_a")
|
||||
store_b = FileConversationStore(tmp_path / "conv_b")
|
||||
|
||||
conv_a = NodeConversation(system_prompt="Agent A", store=store_a)
|
||||
conv_b = NodeConversation(system_prompt="Agent B", store=store_b)
|
||||
|
||||
await conv_a.add_user_message("Hello from A")
|
||||
await conv_b.add_user_message("Hello from B")
|
||||
await conv_a.add_assistant_message("Response A")
|
||||
await conv_b.add_assistant_message("Response B")
|
||||
await conv_b.add_user_message("Follow-up B")
|
||||
|
||||
# Restore independently
|
||||
restored_a = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_a"))
|
||||
restored_b = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_b"))
|
||||
|
||||
assert restored_a.system_prompt == "Agent A"
|
||||
assert restored_b.system_prompt == "Agent B"
|
||||
assert restored_a.message_count == 2
|
||||
assert restored_b.message_count == 3
|
||||
assert restored_a.messages[0].content == "Hello from A"
|
||||
assert restored_b.messages[2].content == "Follow-up B"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_removes_all_files(self, tmp_path):
|
||||
"""destroy() wipes the entire conversation directory."""
|
||||
base = tmp_path / "doomed_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(system_prompt="temp", store=store)
|
||||
await conv.add_user_message("ephemeral")
|
||||
await conv.add_assistant_message("gone soon")
|
||||
|
||||
assert base.exists()
|
||||
assert (base / "meta.json").exists()
|
||||
assert (base / "parts").exists()
|
||||
|
||||
await store.destroy()
|
||||
|
||||
assert not base.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_empty_store_returns_none(self, tmp_path):
|
||||
"""Restoring from a path that was never written to returns None."""
|
||||
store = FileConversationStore(tmp_path / "empty")
|
||||
result = await NodeConversation.restore(store)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_then_continue_then_restore(self, tmp_path):
|
||||
"""clear() removes messages but preserves seq counter for new messages."""
|
||||
base = tmp_path / "clear_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(system_prompt="s", store=store)
|
||||
|
||||
await conv.add_user_message("old msg 0")
|
||||
await conv.add_assistant_message("old msg 1")
|
||||
assert conv.next_seq == 2
|
||||
|
||||
await conv.clear()
|
||||
assert conv.message_count == 0
|
||||
assert conv.next_seq == 2 # seq counter preserved
|
||||
|
||||
# Continue with new messages — seqs should start at 2
|
||||
await conv.add_user_message("new msg")
|
||||
await conv.add_assistant_message("new response")
|
||||
assert conv.next_seq == 4
|
||||
assert conv.messages[0].seq == 2
|
||||
assert conv.messages[1].seq == 3
|
||||
|
||||
# Restore
|
||||
store2 = FileConversationStore(base)
|
||||
restored = await NodeConversation.restore(store2)
|
||||
assert restored is not None
|
||||
assert restored.message_count == 2
|
||||
assert restored.next_seq == 4
|
||||
assert restored.messages[0].content == "new msg"
|
||||
assert restored.messages[0].seq == 2
|
||||
|
||||
Reference in New Issue
Block a user