chore: update gitignore

This commit is contained in:
Timothy
2026-02-02 10:30:03 -08:00
parent 3240616808
commit 22e816bf86
5 changed files with 2432 additions and 1 deletions
+1 -1
View File
@@ -72,4 +72,4 @@ exports/*
.venv
docs/github-issues/*
core/tests/*
core/tests/*dumps/*
+746
View File
@@ -0,0 +1,746 @@
"""WP-8: Tests for EventLoopNode, OutputAccumulator, LoopConfig, JudgeProtocol.
Uses real FileConversationStore (no mocks for storage) and a MockStreamingLLM
that yields pre-programmed StreamEvents to control the loop deterministically.
"""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from framework.graph.conversation import NodeConversation
from framework.graph.event_loop_node import (
EventLoopNode,
JudgeProtocol,
JudgeVerdict,
LoopConfig,
OutputAccumulator,
)
from framework.graph.node import NodeContext, NodeProtocol, NodeSpec, SharedMemory
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
ToolCallEvent,
)
from framework.runtime.core import Runtime
from framework.runtime.event_bus import EventBus, EventType
from framework.storage.conversation_store import FileConversationStore
# ---------------------------------------------------------------------------
# Mock LLM that yields pre-programmed stream events
# ---------------------------------------------------------------------------
class MockStreamingLLM(LLMProvider):
    """Deterministic LLM double driven by pre-programmed StreamEvent lists.

    Every call to stream() replays the next scenario in order, wrapping
    back to the first scenario once the list is exhausted.
    """

    def __init__(self, scenarios: list[list] | None = None):
        # One scenario == the full event sequence for a single stream() call.
        self.scenarios = scenarios or []
        self._call_index = 0
        # Arguments captured from each stream() call, for test assertions.
        self.stream_calls: list[dict] = []

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator:
        self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
        if not self.scenarios:
            return
        # Wrap around with modulo so extra calls re-use earlier scenarios.
        scenario = self.scenarios[self._call_index % len(self.scenarios)]
        self._call_index += 1
        for ev in scenario:
            yield ev

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        # Canned summary used by summarization paths in the loop.
        return LLMResponse(content="Summary of conversation.", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")
# ---------------------------------------------------------------------------
# Helper: build a simple text-only scenario
# ---------------------------------------------------------------------------
def text_scenario(text: str, input_tokens: int = 10, output_tokens: int = 5) -> list:
    """Stream scenario: one text delta followed by a normal "stop" finish."""
    delta = TextDeltaEvent(content=text, snapshot=text)
    finish = FinishEvent(
        stop_reason="stop", input_tokens=input_tokens, output_tokens=output_tokens, model="mock"
    )
    return [delta, finish]
def tool_call_scenario(
    tool_name: str,
    tool_input: dict,
    tool_use_id: str = "call_1",
    text: str = "",
) -> list:
    """Stream scenario: optional leading text, a tool call, then a tool_calls finish."""
    prefix = [TextDeltaEvent(content=text, snapshot=text)] if text else []
    return [
        *prefix,
        ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input),
        FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
    ]
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def runtime():
    """Runtime double: bookkeeping methods are mocks; two return canned ids."""
    rt = MagicMock(spec=Runtime)
    rt.start_run = MagicMock(return_value="run_1")
    rt.decide = MagicMock(return_value="dec_1")
    # The remaining methods are fire-and-forget from the node's perspective.
    for method_name in ("record_outcome", "end_run", "report_problem", "set_node"):
        setattr(rt, method_name, MagicMock())
    return rt
@pytest.fixture
def node_spec():
    """Default event-loop NodeSpec that requires a single "result" output key."""
    return NodeSpec(
        id="test_loop",
        name="Test Loop",
        node_type="event_loop",
        description="A test event loop node",
        system_prompt="You are a test assistant.",
        output_keys=["result"],
    )
@pytest.fixture
def memory():
    """Fresh, empty SharedMemory instance per test."""
    return SharedMemory()
def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal_context=""):
    """Assemble a NodeContext wired to the given runtime/spec/memory/llm."""
    if tools is None:
        tools = []
    if input_data is None:
        input_data = {}
    return NodeContext(
        runtime=runtime,
        node_id=node_spec.id,
        node_spec=node_spec,
        memory=memory,
        input_data=input_data,
        llm=llm,
        available_tools=tools,
        goal_context=goal_context,
    )
# ===========================================================================
# NodeProtocol conformance
# ===========================================================================
class TestNodeProtocolConformance:
    """EventLoopNode must satisfy the NodeProtocol contract."""

    def test_subclasses_node_protocol(self):
        """EventLoopNode must be a subclass of NodeProtocol."""
        assert issubclass(EventLoopNode, NodeProtocol)

    def test_has_execute_method(self):
        """execute() exists and is declared async."""
        instance = EventLoopNode()
        assert hasattr(instance, "execute")
        assert asyncio.iscoroutinefunction(instance.execute)

    def test_has_validate_input(self):
        """validate_input() is exposed on instances."""
        instance = EventLoopNode()
        assert hasattr(instance, "validate_input")
# ===========================================================================
# Basic loop execution
# ===========================================================================
class TestBasicLoop:
    """Core happy-path and failure-path loop behavior without tools or judge."""
    @pytest.mark.asyncio
    async def test_basic_text_only_implicit_accept(self, runtime, node_spec, memory):
        """No tools, no judge. LLM produces text, implicit accept on stop."""
        # Override to no output_keys so implicit judge accepts immediately
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello world")])
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)
        assert result.success is True
        # text_scenario reports 10 input + 5 output tokens, so usage is non-zero.
        assert result.tokens_used > 0
    @pytest.mark.asyncio
    async def test_no_llm_returns_failure(self, runtime, node_spec, memory):
        """ctx.llm=None should return failure immediately."""
        ctx = build_ctx(runtime, node_spec, memory, llm=None)
        node = EventLoopNode()
        result = await node.execute(ctx)
        assert result.success is False
        assert "LLM" in result.error
    @pytest.mark.asyncio
    async def test_max_iterations_failure(self, runtime, node_spec, memory):
        """When max_iterations is reached without acceptance, should fail."""
        # LLM always produces text but never calls set_output, so implicit
        # judge retries asking for missing keys
        llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=2))
        result = await node.execute(ctx)
        assert result.success is False
        assert "Max iterations" in result.error
# ===========================================================================
# Judge integration
# ===========================================================================
class TestJudgeIntegration:
    """Judge verdicts (ACCEPT / ESCALATE / RETRY) drive loop termination."""
    @pytest.mark.asyncio
    async def test_judge_accept(self, runtime, node_spec, memory):
        """Mock judge ACCEPT -> success."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Done!")])
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)
        assert result.success is True
        # Exactly one iteration: the judge was consulted once and accepted.
        judge.evaluate.assert_called_once()
    @pytest.mark.asyncio
    async def test_judge_escalate(self, runtime, node_spec, memory):
        """Mock judge ESCALATE -> failure."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Attempt")])
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(
            return_value=JudgeVerdict(action="ESCALATE", feedback="Tone violation")
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)
        assert result.success is False
        assert "escalated" in result.error.lower()
        # The judge's feedback must surface in the reported error.
        assert "Tone violation" in result.error
    @pytest.mark.asyncio
    async def test_judge_retry_then_accept(self, runtime, node_spec, memory):
        """RETRY twice, then ACCEPT. Should run 3 iterations."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("attempt 1"),
                text_scenario("attempt 2"),
                text_scenario("attempt 3"),
            ]
        )
        call_count = 0
        async def evaluate_fn(context):
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                return JudgeVerdict(action="RETRY", feedback="Try harder")
            return JudgeVerdict(action="ACCEPT")
        judge = AsyncMock(spec=JudgeProtocol)
        # side_effect with an async callable: AsyncMock awaits evaluate_fn.
        judge.evaluate = AsyncMock(side_effect=evaluate_fn)
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=10))
        result = await node.execute(ctx)
        assert result.success is True
        assert call_count == 3
# ===========================================================================
# set_output tool
# ===========================================================================
class TestSetOutput:
    """The set_output tool accumulates validated output values into the result."""
    @pytest.mark.asyncio
    async def test_set_output_accumulates(self, runtime, node_spec, memory):
        """LLM calls set_output -> values appear in NodeResult.output."""
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call set_output
                tool_call_scenario("set_output", {"key": "result", "value": "42"}),
                # Turn 2: text response (triggers implicit judge)
                text_scenario("Done, result is 42"),
            ]
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)
        assert result.success is True
        assert result.output["result"] == "42"
    @pytest.mark.asyncio
    async def test_set_output_rejects_invalid_key(self, runtime, node_spec, memory):
        """set_output with key not in output_keys -> is_error=True."""
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call set_output with bad key
                tool_call_scenario("set_output", {"key": "bad_key", "value": "x"}),
                # Turn 2: call set_output with good key
                tool_call_scenario("set_output", {"key": "result", "value": "ok"}),
                # Turn 3: text done
                text_scenario("Done"),
            ]
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)
        assert result.success is True
        assert result.output["result"] == "ok"
        # The rejected key must not leak into the final output dict.
        assert "bad_key" not in result.output
    @pytest.mark.asyncio
    async def test_missing_keys_triggers_retry(self, runtime, node_spec, memory):
        """Judge accepts but output keys are missing -> retry with hint."""
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: text without set_output -> judge accepts but keys missing -> retry
                text_scenario("I'll get to it"),
                # Turn 2: set_output
                tool_call_scenario("set_output", {"key": "result", "value": "done"}),
                # Turn 3: text -> judge accepts, keys present -> success
                text_scenario("All done"),
            ]
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)
        assert result.success is True
        assert result.output["result"] == "done"
# ===========================================================================
# Stall detection
# ===========================================================================
class TestStallDetection:
    """The loop must abort once the LLM keeps repeating the same response."""

    @pytest.mark.asyncio
    async def test_stall_detection(self, runtime, node_spec, memory):
        """3 identical responses should trigger stall detection."""
        node_spec.output_keys = []  # implicit judge would otherwise accept
        # Force RETRY verdicts so the identical answer is produced repeatedly.
        retry_judge = AsyncMock(spec=JudgeProtocol)
        retry_judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY"))
        mock_llm = MockStreamingLLM(scenarios=[text_scenario("same answer")])
        ctx = build_ctx(runtime, node_spec, memory, mock_llm)
        loop_node = EventLoopNode(
            judge=retry_judge,
            config=LoopConfig(max_iterations=10, stall_detection_threshold=3),
        )
        outcome = await loop_node.execute(ctx)
        assert outcome.success is False
        assert "stalled" in outcome.error.lower()
# ===========================================================================
# EventBus lifecycle events
# ===========================================================================
class TestEventBusLifecycle:
    """Lifecycle and streaming events the node publishes on the EventBus."""
    @pytest.mark.asyncio
    async def test_lifecycle_events_published(self, runtime, node_spec, memory):
        """NODE_LOOP_STARTED, NODE_LOOP_ITERATION, NODE_LOOP_COMPLETED should be published."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("ok")])
        bus = EventBus()
        received_events = []
        bus.subscribe(
            event_types=[
                EventType.NODE_LOOP_STARTED,
                EventType.NODE_LOOP_ITERATION,
                EventType.NODE_LOOP_COMPLETED,
            ],
            handler=lambda e: received_events.append(e.type),
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)
        assert result.success is True
        assert EventType.NODE_LOOP_STARTED in received_events
        assert EventType.NODE_LOOP_ITERATION in received_events
        assert EventType.NODE_LOOP_COMPLETED in received_events
    @pytest.mark.asyncio
    async def test_client_facing_uses_client_output_delta(self, runtime, memory):
        """client_facing=True should emit CLIENT_OUTPUT_DELTA instead of LLM_TEXT_DELTA."""
        # Builds its own spec (instead of the fixture) to set client_facing=True.
        spec = NodeSpec(
            id="ui_node",
            name="UI Node",
            description="Streams to user",
            node_type="event_loop",
            output_keys=[],
            client_facing=True,
        )
        llm = MockStreamingLLM(scenarios=[text_scenario("visible to user")])
        bus = EventBus()
        received_types = []
        # Subscribe to both delta types so we can assert on the absence of one.
        bus.subscribe(
            event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.LLM_TEXT_DELTA],
            handler=lambda e: received_types.append(e.type),
        )
        ctx = build_ctx(runtime, spec, memory, llm)
        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        await node.execute(ctx)
        assert EventType.CLIENT_OUTPUT_DELTA in received_types
        assert EventType.LLM_TEXT_DELTA not in received_types
# ===========================================================================
# Tool execution
# ===========================================================================
class TestToolExecution:
    """Custom tool_executor results are fed back into the conversation loop."""
    @pytest.mark.asyncio
    async def test_tool_execution_feedback(self, runtime, node_spec, memory):
        """Tool call -> result fed back to conversation via stream loop."""
        node_spec.output_keys = []
        def my_tool_executor(tool_use: ToolUse) -> ToolResult:
            # Echo-style executor: succeeds for any tool call it receives.
            return ToolResult(
                tool_use_id=tool_use.id,
                content=f"Result for {tool_use.name}",
                is_error=False,
            )
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call a tool
                tool_call_scenario("search", {"query": "test"}, tool_use_id="call_search"),
                # Turn 2: text response after seeing tool result
                text_scenario("Found the answer"),
            ]
        )
        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            tools=[Tool(name="search", description="Search", parameters={})],
        )
        node = EventLoopNode(
            tool_executor=my_tool_executor,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)
        assert result.success is True
        # stream() should have been called twice (tool call turn + final text turn)
        assert llm._call_index >= 2
# ===========================================================================
# Write-through persistence with real FileConversationStore
# ===========================================================================
class TestWriteThroughPersistence:
    """Messages and outputs persist to a real FileConversationStore as they occur."""
    @pytest.mark.asyncio
    async def test_messages_written_to_store(self, tmp_path, runtime, node_spec, memory):
        """Messages should be persisted immediately via write-through."""
        store = FileConversationStore(tmp_path / "conv")
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello")])
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)
        assert result.success is True
        # Verify parts were written to disk
        parts = await store.read_parts()
        assert len(parts) >= 2  # at least initial user msg + assistant msg
    @pytest.mark.asyncio
    async def test_output_accumulator_write_through(self, tmp_path, runtime, node_spec, memory):
        """set_output values should be persisted in cursor immediately."""
        store = FileConversationStore(tmp_path / "conv")
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("set_output", {"key": "result", "value": "persisted_value"}),
                text_scenario("Done"),
            ]
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)
        assert result.success is True
        assert result.output["result"] == "persisted_value"
        # Verify output was written to cursor on disk
        cursor = await store.read_cursor()
        assert cursor is not None
        assert cursor["outputs"]["result"] == "persisted_value"
# ===========================================================================
# Crash recovery (restore from real FileConversationStore)
# ===========================================================================
class TestCrashRecovery:
    """A new node restores conversation + partial outputs from a prior store."""
    @pytest.mark.asyncio
    async def test_restore_from_checkpoint(self, tmp_path, runtime, node_spec, memory):
        """Populate a store with state, then verify EventLoopNode restores from it."""
        store = FileConversationStore(tmp_path / "conv")
        # Simulate a previous run that wrote conversation + cursor
        conv = NodeConversation(
            system_prompt="You are a test assistant.",
            output_keys=["result"],
            store=store,
        )
        await conv.add_user_message("Initial input")
        await conv.add_assistant_message("Working on it...")
        # Write cursor with iteration and outputs
        await store.write_cursor(
            {
                "iteration": 1,
                "next_seq": conv.next_seq,
                "outputs": {"result": "partial_value"},
            }
        )
        # Now create a new EventLoopNode and execute -- it should restore
        node_spec.output_keys = []  # no required keys so implicit accept works
        llm = MockStreamingLLM(scenarios=[text_scenario("Continuing...")])
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)
        assert result.success is True
        # Should have the restored output
        assert result.output.get("result") == "partial_value"
# ===========================================================================
# External event injection
# ===========================================================================
class TestEventInjection:
    """Externally injected events become user messages in later iterations."""
    @pytest.mark.asyncio
    async def test_inject_event(self, runtime, node_spec, memory):
        """inject_event() content should appear as user message in next iteration."""
        node_spec.output_keys = []
        judge_calls = []
        # RETRY once so the loop runs at least two iterations, then ACCEPT.
        async def evaluate_fn(context):
            judge_calls.append(context)
            if len(judge_calls) >= 2:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY")
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(side_effect=evaluate_fn)
        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("iteration 1"),
                text_scenario("iteration 2"),
            ]
        )
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            judge=judge,
            config=LoopConfig(max_iterations=5),
        )
        # Pre-inject an event before execute runs
        await node.inject_event("Priority: CEO wants meeting rescheduled")
        result = await node.execute(ctx)
        assert result.success is True
        # Verify the injected content made it into the LLM messages
        all_messages = []
        for call in llm.stream_calls:
            all_messages.extend(call["messages"])
        injected_found = any("[External event]" in str(m.get("content", "")) for m in all_messages)
        assert injected_found
# ===========================================================================
# Pause/resume
# ===========================================================================
class TestPauseResume:
    """Pausing before the first turn short-circuits execution."""

    @pytest.mark.asyncio
    async def test_pause_returns_early(self, runtime, node_spec, memory):
        """pause_requested in input_data should trigger early return."""
        node_spec.output_keys = []
        mock_llm = MockStreamingLLM(scenarios=[text_scenario("should not run")])
        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            mock_llm,
            input_data={"pause_requested": True},
        )
        loop_node = EventLoopNode(config=LoopConfig(max_iterations=10))
        outcome = await loop_node.execute(ctx)
        # Pausing is a successful early exit, not a failure.
        assert outcome.success is True
        # The stream was never consumed, so no scenario was dispatched.
        assert mock_llm._call_index == 0
# ===========================================================================
# Stream errors
# ===========================================================================
class TestStreamErrors:
    """Fatal stream errors must propagate out of execute() as exceptions."""

    @pytest.mark.asyncio
    async def test_non_recoverable_stream_error_raises(self, runtime, node_spec, memory):
        """Non-recoverable StreamErrorEvent should raise RuntimeError."""
        node_spec.output_keys = []
        failing_llm = MockStreamingLLM(
            scenarios=[[StreamErrorEvent(error="Connection lost", recoverable=False)]]
        )
        ctx = build_ctx(runtime, node_spec, memory, failing_llm)
        loop_node = EventLoopNode(config=LoopConfig(max_iterations=5))
        with pytest.raises(RuntimeError, match="Stream error"):
            await loop_node.execute(ctx)
# ===========================================================================
# OutputAccumulator unit tests
# ===========================================================================
class TestOutputAccumulator:
    """Unit tests for OutputAccumulator, with and without a backing store."""
    @pytest.mark.asyncio
    async def test_set_and_get(self):
        acc = OutputAccumulator()
        await acc.set("key1", "value1")
        assert acc.get("key1") == "value1"
        # Unknown keys read back as None rather than raising.
        assert acc.get("nonexistent") is None
    @pytest.mark.asyncio
    async def test_to_dict(self):
        acc = OutputAccumulator()
        await acc.set("a", 1)
        await acc.set("b", 2)
        assert acc.to_dict() == {"a": 1, "b": 2}
    @pytest.mark.asyncio
    async def test_has_all_keys(self):
        acc = OutputAccumulator()
        # Vacuously true for an empty key list.
        assert acc.has_all_keys([]) is True
        assert acc.has_all_keys(["x"]) is False
        await acc.set("x", "val")
        assert acc.has_all_keys(["x"]) is True
    @pytest.mark.asyncio
    async def test_write_through_to_real_store(self, tmp_path):
        """OutputAccumulator should write through to FileConversationStore cursor."""
        store = FileConversationStore(tmp_path / "acc_test")
        acc = OutputAccumulator(store=store)
        await acc.set("result", "hello")
        # A single set() must already be visible in the on-disk cursor.
        cursor = await store.read_cursor()
        assert cursor["outputs"]["result"] == "hello"
    @pytest.mark.asyncio
    async def test_restore_from_real_store(self, tmp_path):
        """OutputAccumulator.restore() should rebuild from FileConversationStore."""
        store = FileConversationStore(tmp_path / "acc_restore")
        await store.write_cursor({"outputs": {"key1": "val1", "key2": "val2"}})
        acc = await OutputAccumulator.restore(store)
        assert acc.get("key1") == "val1"
        assert acc.get("key2") == "val2"
        assert acc.has_all_keys(["key1", "key2"]) is True
+978
View File
@@ -0,0 +1,978 @@
"""Tests for extending the stream event type system.
Validates that the StreamEvent discriminated union pattern supports:
- Type-based dispatch (matching on event.type)
- Pattern matching / isinstance branching
- Custom event subclasses following the same frozen-dataclass convention
- Serialization of mixed event sequences
WP-2 tests validate EventType enum extension and node-level event routing:
- All 12 new EventType enum members with correct string values
- node_id routing on AgentEvent
- filter_node on Subscription
- Backward compatibility with existing enum members
"""
import asyncio
from dataclasses import FrozenInstanceError, asdict, dataclass, field
from typing import Any, Literal
import pytest
from framework.llm.stream_events import (
FinishEvent,
ReasoningDeltaEvent,
ReasoningStartEvent,
StreamErrorEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
ToolResultEvent,
)
from framework.runtime.event_bus import AgentEvent, EventBus, EventType, Subscription
# ---------------------------------------------------------------------------
# Helpers: type-based dispatch
# ---------------------------------------------------------------------------
def dispatch_event(event) -> str:
    """Dispatch an event by its type field, returning a label.

    Unrecognized types yield an "unknown:<type>" label instead of raising.
    """
    kind = event.type
    if kind == "text_delta":
        return f"text:{event.content}"
    if kind == "text_end":
        return f"end:{len(event.full_text)}chars"
    if kind == "tool_call":
        return f"call:{event.tool_name}"
    if kind == "tool_result":
        return f"result:{event.tool_use_id}"
    if kind == "reasoning_start":
        return "reasoning:start"
    if kind == "reasoning_delta":
        return f"reasoning:{event.content[:20]}"
    if kind == "finish":
        return f"finish:{event.stop_reason}"
    if kind == "error":
        return f"error:{event.error}"
    return f"unknown:{kind}"
def collect_text(events: list) -> str:
    """Return the final accumulated text from a stream of events.

    Scans from the end: a TextEndEvent wins with its full_text, otherwise
    the most recent TextDeltaEvent supplies its snapshot; "" if neither.
    """
    for ev in reversed(events):
        if isinstance(ev, (TextEndEvent, TextDeltaEvent)):
            return ev.full_text if isinstance(ev, TextEndEvent) else ev.snapshot
    return ""
def extract_tool_calls(events: list) -> list[dict[str, Any]]:
    """Extract {id, name, input} records for every ToolCallEvent in the stream."""
    def as_record(e: ToolCallEvent) -> dict[str, Any]:
        return {"id": e.tool_use_id, "name": e.tool_name, "input": e.tool_input}
    return [as_record(e) for e in events if isinstance(e, ToolCallEvent)]
# ---------------------------------------------------------------------------
# Type-based dispatch tests
# ---------------------------------------------------------------------------
class TestTypeDispatch:
    """Dispatch on event.type string for handler routing."""
    # One test per standard event type, asserting the exact label produced.
    def test_dispatch_text_delta(self):
        e = TextDeltaEvent(content="hello")
        assert dispatch_event(e) == "text:hello"
    def test_dispatch_text_end(self):
        e = TextEndEvent(full_text="hello world")
        assert dispatch_event(e) == "end:11chars"
    def test_dispatch_tool_call(self):
        e = ToolCallEvent(tool_name="web_search")
        assert dispatch_event(e) == "call:web_search"
    def test_dispatch_tool_result(self):
        e = ToolResultEvent(tool_use_id="abc")
        assert dispatch_event(e) == "result:abc"
    def test_dispatch_reasoning_start(self):
        e = ReasoningStartEvent()
        assert dispatch_event(e) == "reasoning:start"
    def test_dispatch_reasoning_delta(self):
        # Content longer than 20 chars is truncated by the handler.
        e = ReasoningDeltaEvent(content="Let me think step by step")
        assert dispatch_event(e) == "reasoning:Let me think step by"
    def test_dispatch_finish(self):
        e = FinishEvent(stop_reason="end_turn")
        assert dispatch_event(e) == "finish:end_turn"
    def test_dispatch_error(self):
        e = StreamErrorEvent(error="timeout")
        assert dispatch_event(e) == "error:timeout"
# ---------------------------------------------------------------------------
# isinstance-based filtering
# ---------------------------------------------------------------------------
class TestInstanceFiltering:
    """Filter event streams using isinstance for each event type."""
    @pytest.fixture
    def text_stream(self) -> list:
        """Simulate a text-only stream."""
        return [
            TextDeltaEvent(content="Hello", snapshot="Hello"),
            TextDeltaEvent(content=" world", snapshot="Hello world"),
            TextDeltaEvent(content="!", snapshot="Hello world!"),
            TextEndEvent(full_text="Hello world!"),
            FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=3, model="test"),
        ]
    @pytest.fixture
    def tool_stream(self) -> list:
        """Simulate a tool call stream."""
        return [
            ToolCallEvent(
                tool_use_id="call_1",
                tool_name="get_weather",
                tool_input={"city": "London"},
            ),
            ToolCallEvent(
                tool_use_id="call_2",
                tool_name="calculator",
                tool_input={"expression": "2+2"},
            ),
            FinishEvent(stop_reason="tool_calls"),
        ]
    @pytest.fixture
    def reasoning_stream(self) -> list:
        """Simulate a stream with reasoning blocks."""
        return [
            ReasoningStartEvent(),
            ReasoningDeltaEvent(content="Let me analyze this..."),
            ReasoningDeltaEvent(content="The answer is 42."),
            TextDeltaEvent(content="The answer is 42.", snapshot="The answer is 42."),
            TextEndEvent(full_text="The answer is 42."),
            FinishEvent(stop_reason="end_turn"),
        ]
    def test_collect_text(self, text_stream):
        assert collect_text(text_stream) == "Hello world!"
    def test_collect_text_from_tool_stream(self, tool_stream):
        # No text events at all -> empty string, not an error.
        assert collect_text(tool_stream) == ""
    def test_extract_tool_calls(self, tool_stream):
        calls = extract_tool_calls(tool_stream)
        assert len(calls) == 2
        assert calls[0]["name"] == "get_weather"
        assert calls[1]["name"] == "calculator"
    def test_extract_tool_calls_from_text_stream(self, text_stream):
        assert extract_tool_calls(text_stream) == []
    def test_filter_text_deltas(self, text_stream):
        deltas = [e for e in text_stream if isinstance(e, TextDeltaEvent)]
        assert len(deltas) == 3
    def test_filter_finish(self, text_stream):
        finishes = [e for e in text_stream if isinstance(e, FinishEvent)]
        assert len(finishes) == 1
        assert finishes[0].stop_reason == "stop"
    def test_reasoning_then_text(self, reasoning_stream):
        reasoning = [e for e in reasoning_stream if isinstance(e, ReasoningDeltaEvent)]
        text = collect_text(reasoning_stream)
        assert len(reasoning) == 2
        assert text == "The answer is 42."
    def test_mixed_stream_type_counts(self, reasoning_stream):
        # Tally events by their type discriminator string.
        type_counts = {}
        for e in reasoning_stream:
            type_counts[e.type] = type_counts.get(e.type, 0) + 1
        assert type_counts == {
            "reasoning_start": 1,
            "reasoning_delta": 2,
            "text_delta": 1,
            "text_end": 1,
            "finish": 1,
        }
# ---------------------------------------------------------------------------
# Custom event extension pattern
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class CustomMetricsEvent:
    """Example custom event following the same pattern."""
    # Literal discriminator field mirrors the stream_events convention.
    type: Literal["custom_metrics"] = "custom_metrics"
    latency_ms: float = 0.0
    tokens_per_second: float = 0.0
    # default_factory avoids a shared mutable default dict across instances.
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class CustomCitationEvent:
    """Example citation event extending the pattern."""
    # Literal discriminator field mirrors the stream_events convention.
    type: Literal["citation"] = "citation"
    source_url: str = ""
    quote: str = ""
    confidence: float = 0.0
class TestCustomEventExtension:
    """Custom events should follow the same frozen-dataclass convention."""
    def test_custom_event_construction(self):
        e = CustomMetricsEvent(latency_ms=150.5, tokens_per_second=42.3)
        assert e.type == "custom_metrics"
        assert e.latency_ms == 150.5
    def test_custom_event_frozen(self):
        # frozen=True makes attribute assignment raise FrozenInstanceError.
        e = CustomMetricsEvent()
        with pytest.raises(FrozenInstanceError):
            e.type = "modified"
    def test_custom_event_serialization(self):
        e = CustomMetricsEvent(
            latency_ms=100.0,
            tokens_per_second=50.0,
            metadata={"provider": "anthropic"},
        )
        d = asdict(e)
        assert d["type"] == "custom_metrics"
        assert d["metadata"] == {"provider": "anthropic"}
    def test_custom_event_dispatch(self):
        """Custom events can extend the dispatch map."""
        e = CustomMetricsEvent(latency_ms=200.0)
        # Falls through to "unknown" in our dispatch_event
        assert dispatch_event(e) == "unknown:custom_metrics"
    def test_custom_event_in_mixed_stream(self):
        """Custom events can coexist with standard events in a list."""
        stream = [
            TextDeltaEvent(content="hi", snapshot="hi"),
            CustomMetricsEvent(latency_ms=50.0),
            TextEndEvent(full_text="hi"),
            CustomCitationEvent(source_url="https://example.com", quote="hi"),
            FinishEvent(stop_reason="stop"),
        ]
        # Partition by membership of the type string in the standard set.
        standard = [
            e
            for e in stream
            if hasattr(e, "type")
            and e.type
            in {
                "text_delta",
                "text_end",
                "tool_call",
                "tool_result",
                "reasoning_start",
                "reasoning_delta",
                "finish",
                "error",
            }
        ]
        custom = [
            e
            for e in stream
            if e.type
            not in {
                "text_delta",
                "text_end",
                "tool_call",
                "tool_result",
                "reasoning_start",
                "reasoning_delta",
                "finish",
                "error",
            }
        ]
        assert len(standard) == 3
        assert len(custom) == 2
# ---------------------------------------------------------------------------
# Serialization of full event sequences
# ---------------------------------------------------------------------------
class TestSequenceSerialization:
    """Serialize entire event sequences, as done by the dump tests."""
    def test_serialize_text_sequence(self):
        events = [
            TextDeltaEvent(content="Hello", snapshot="Hello"),
            TextDeltaEvent(content=" world", snapshot="Hello world"),
            TextEndEvent(full_text="Hello world"),
            FinishEvent(stop_reason="stop", model="test-model"),
        ]
        # Each record carries its stream position plus the dataclass fields.
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert len(serialized) == 4
        assert serialized[0]["index"] == 0
        assert serialized[0]["type"] == "text_delta"
        assert serialized[-1]["type"] == "finish"
        assert serialized[-1]["model"] == "test-model"
    def test_serialize_tool_sequence(self):
        events = [
            ToolCallEvent(
                tool_use_id="call_1",
                tool_name="search",
                tool_input={"query": "test"},
            ),
            FinishEvent(stop_reason="tool_calls"),
        ]
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert serialized[0]["tool_input"] == {"query": "test"}
        assert serialized[1]["stop_reason"] == "tool_calls"
    def test_serialize_error_sequence(self):
        events = [
            TextDeltaEvent(content="partial"),
            StreamErrorEvent(error="connection reset", recoverable=True),
            FinishEvent(stop_reason="error"),
        ]
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert serialized[1]["type"] == "error"
        assert serialized[1]["recoverable"] is True
    def test_roundtrip_snapshot_accumulation(self):
        """Verify snapshot grows monotonically through serialization."""
        chunks = ["Hello", " beautiful", " world", "!"]
        events = []
        snapshot = ""
        for chunk in chunks:
            snapshot += chunk
            events.append(TextDeltaEvent(content=chunk, snapshot=snapshot))
        serialized = [asdict(e) for e in events]
        # Each successive snapshot must be strictly longer than the last.
        for i in range(1, len(serialized)):
            assert len(serialized[i]["snapshot"]) > len(serialized[i - 1]["snapshot"])
        assert serialized[-1]["snapshot"] == "Hello beautiful world!"
# ===========================================================================
# WP-2: EventType Enum Extension + Node-Level Event Routing
# ===========================================================================
# The 12 new EventType members added by WP-2
# Maps each new enum member to the exact string value it must serialize to;
# consumed by the parametrized tests in TestWP2EventTypeEnumMembers below.
WP2_EVENT_TYPES = {
    # Node event-loop lifecycle
    EventType.NODE_LOOP_STARTED: "node_loop_started",
    EventType.NODE_LOOP_ITERATION: "node_loop_iteration",
    EventType.NODE_LOOP_COMPLETED: "node_loop_completed",
    # LLM streaming observability
    EventType.LLM_TEXT_DELTA: "llm_text_delta",
    EventType.LLM_REASONING_DELTA: "llm_reasoning_delta",
    # Tool lifecycle
    EventType.TOOL_CALL_STARTED: "tool_call_started",
    EventType.TOOL_CALL_COMPLETED: "tool_call_completed",
    # Client I/O
    EventType.CLIENT_OUTPUT_DELTA: "client_output_delta",
    EventType.CLIENT_INPUT_REQUESTED: "client_input_requested",
    # Internal node observability
    EventType.NODE_INTERNAL_OUTPUT: "node_internal_output",
    EventType.NODE_INPUT_BLOCKED: "node_input_blocked",
    EventType.NODE_STALLED: "node_stalled",
}
# Pre-existing enum members that must remain unchanged
# Guards against WP-2 accidentally renaming or re-valuing older members.
ORIGINAL_EVENT_TYPES = {
    EventType.EXECUTION_STARTED: "execution_started",
    EventType.EXECUTION_COMPLETED: "execution_completed",
    EventType.EXECUTION_FAILED: "execution_failed",
    EventType.EXECUTION_PAUSED: "execution_paused",
    EventType.EXECUTION_RESUMED: "execution_resumed",
    EventType.STATE_CHANGED: "state_changed",
    EventType.STATE_CONFLICT: "state_conflict",
    EventType.GOAL_PROGRESS: "goal_progress",
    EventType.GOAL_ACHIEVED: "goal_achieved",
    EventType.CONSTRAINT_VIOLATION: "constraint_violation",
    EventType.STREAM_STARTED: "stream_started",
    EventType.STREAM_STOPPED: "stream_stopped",
    EventType.CUSTOM: "custom",
}
# ---------------------------------------------------------------------------
# WP-2 Part A: EventType enum members
# ---------------------------------------------------------------------------
class TestWP2EventTypeEnumMembers:
    """All 12 new EventType members exist with correct string values."""

    @pytest.mark.parametrize(
        "member,expected_value",
        WP2_EVENT_TYPES.items(),
        ids=lambda x: x.name if isinstance(x, EventType) else x,
    )
    def test_new_member_value(self, member, expected_value):
        """Each new enum member serializes to its expected string."""
        assert member.value == expected_value

    def test_all_12_new_members_exist(self):
        assert len(WP2_EVENT_TYPES) == 12

    def test_new_member_string_values_are_unique(self):
        string_values = list(WP2_EVENT_TYPES.values())
        assert len(set(string_values)) == len(string_values)

    def test_no_collision_with_original_members(self):
        overlap = set(WP2_EVENT_TYPES.values()) & set(ORIGINAL_EVENT_TYPES.values())
        assert overlap == set(), f"Colliding values: {overlap}"

    @pytest.mark.parametrize(
        "member,expected_value",
        ORIGINAL_EVENT_TYPES.items(),
        ids=lambda x: x.name if isinstance(x, EventType) else x,
    )
    def test_original_members_unchanged(self, member, expected_value):
        """Pre-WP-2 members keep their original string values."""
        assert member.value == expected_value

    def test_event_type_is_str_enum(self):
        """EventType members compare equal to their string values."""
        assert EventType.NODE_LOOP_STARTED == "node_loop_started"
        assert EventType.LLM_TEXT_DELTA == "llm_text_delta"
        assert EventType.LLM_TEXT_DELTA.value == "llm_text_delta"

    def test_event_type_accessible_by_name(self):
        assert EventType["NODE_LOOP_STARTED"] is EventType.NODE_LOOP_STARTED
        assert EventType["TOOL_CALL_COMPLETED"] is EventType.TOOL_CALL_COMPLETED

    def test_event_type_accessible_by_value(self):
        assert EventType("node_loop_started") is EventType.NODE_LOOP_STARTED
        assert EventType("tool_call_completed") is EventType.TOOL_CALL_COMPLETED
# ---------------------------------------------------------------------------
# WP-2 Part B: AgentEvent.node_id and Subscription.filter_node
# ---------------------------------------------------------------------------
class TestWP2AgentEventNodeId:
    """AgentEvent supports node_id as a first-class field."""

    def test_node_id_defaults_to_none(self):
        """Omitting node_id at construction leaves it as None."""
        evt = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="stream-1",
        )
        assert evt.node_id is None

    def test_node_id_can_be_set(self):
        """A node_id passed at construction is stored verbatim."""
        evt = AgentEvent(
            type=EventType.LLM_TEXT_DELTA,
            stream_id="stream-1",
            node_id="email_composer",
        )
        assert evt.node_id == "email_composer"

    def test_node_id_in_to_dict(self):
        """to_dict() carries the node_id value through."""
        payload = AgentEvent(
            type=EventType.TOOL_CALL_STARTED,
            stream_id="stream-1",
            node_id="search_node",
        ).to_dict()
        assert payload["node_id"] == "search_node"

    def test_node_id_none_in_to_dict(self):
        """to_dict() includes the node_id key even when it is None."""
        payload = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="stream-1",
        ).to_dict()
        assert "node_id" in payload
        assert payload["node_id"] is None
class TestWP2SubscriptionFilterNode:
    """Subscription supports filter_node for node-level routing."""

    @staticmethod
    async def _noop_handler(event: AgentEvent) -> None:
        # Handler body is irrelevant here; only the subscription fields matter.
        pass

    def test_filter_node_defaults_to_none(self):
        """filter_node is optional and defaults to None."""
        subscription = Subscription(
            id="sub_1",
            event_types={EventType.LLM_TEXT_DELTA},
            handler=self._noop_handler,
        )
        assert subscription.filter_node is None

    def test_filter_node_can_be_set(self):
        """An explicit filter_node is stored on the subscription."""
        subscription = Subscription(
            id="sub_1",
            event_types={EventType.LLM_TEXT_DELTA},
            handler=self._noop_handler,
            filter_node="email_composer",
        )
        assert subscription.filter_node == "email_composer"
# ---------------------------------------------------------------------------
# WP-2 Part B: Node-level event routing integration tests
# ---------------------------------------------------------------------------
class TestWP2NodeLevelRouting:
    """EventBus routes events by node_id using filter_node."""

    @pytest.fixture
    def bus(self):
        # Fresh bus per test so subscriptions do not leak between cases.
        return EventBus()

    @pytest.mark.asyncio
    async def test_filter_node_receives_matching_events(self, bus):
        """Subscriber with filter_node='node-A' receives events from node-A."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
            filter_node="node-A",
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
                data={"content": "hello"},
            )
        )
        assert len(received) == 1
        assert received[0].node_id == "node-A"
        assert received[0].data["content"] == "hello"

    @pytest.mark.asyncio
    async def test_filter_node_rejects_non_matching_events(self, bus):
        """Subscriber with filter_node='node-B' does NOT receive node-A events."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
            filter_node="node-B",
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
                data={"content": "hello"},
            )
        )
        assert len(received) == 0

    @pytest.mark.asyncio
    async def test_no_filter_node_receives_all_events(self, bus):
        """Subscriber with no filter_node receives events from all nodes."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
        )
        # Three publishes: two distinct nodes plus a node-less event —
        # an unfiltered subscriber must see all of them.
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-B",
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id=None,
            )
        )
        assert len(received) == 3

    @pytest.mark.asyncio
    async def test_interleaved_nodes_separated_by_filter(self, bus):
        """Two subscribers on different nodes get only their node's events."""
        node_a_events = []
        node_b_events = []

        async def handler_a(event):
            node_a_events.append(event)

        async def handler_b(event):
            node_b_events.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler_a,
            filter_node="email_sender",
        )
        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler_b,
            filter_node="inbox_scanner",
        )
        # Interleaved events from both nodes
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="email_sender",
                data={"content": "Dear Jo"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="inbox_scanner",
                data={"content": "RE: Meeting conf"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="email_sender",
                data={"content": "hn, Thank you for"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="inbox_scanner",
                data={"content": "irmed for Thursday"},
            )
        )
        # Each subscriber sees exactly its own node's deltas, in publish order.
        assert len(node_a_events) == 2
        assert len(node_b_events) == 2
        assert node_a_events[0].data["content"] == "Dear Jo"
        assert node_a_events[1].data["content"] == "hn, Thank you for"
        assert node_b_events[0].data["content"] == "RE: Meeting conf"
        assert node_b_events[1].data["content"] == "irmed for Thursday"

    @pytest.mark.asyncio
    async def test_filter_node_combined_with_filter_stream(self, bus):
        """filter_node and filter_stream work together."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.TOOL_CALL_STARTED],
            handler=handler,
            filter_stream="webhook",
            filter_node="search_node",
        )
        # Matching both filters
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="webhook",
                node_id="search_node",
            )
        )
        # Wrong stream
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="api",
                node_id="search_node",
            )
        )
        # Wrong node
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="webhook",
                node_id="other_node",
            )
        )
        assert len(received) == 1
        assert received[0].stream_id == "webhook"
        assert received[0].node_id == "search_node"

    @pytest.mark.asyncio
    async def test_wait_for_with_node_id(self, bus):
        """wait_for() accepts node_id parameter for filtering."""

        async def publish_later():
            # Small delay so wait_for() is already pending when we publish.
            await asyncio.sleep(0.01)
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="target_node",
                    data={"iterations": 3},
                )
            )

        task = asyncio.create_task(publish_later())
        event = await bus.wait_for(
            event_type=EventType.NODE_LOOP_COMPLETED,
            node_id="target_node",
            timeout=2.0,
        )
        await task
        assert event is not None
        assert event.node_id == "target_node"
        assert event.data["iterations"] == 3

    @pytest.mark.asyncio
    async def test_wait_for_ignores_wrong_node(self, bus):
        """wait_for() with node_id ignores events from other nodes."""

        async def publish_wrong_then_right():
            await asyncio.sleep(0.01)
            # Wrong node — should be ignored
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="wrong_node",
                )
            )
            await asyncio.sleep(0.01)
            # Right node
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="target_node",
                    data={"iterations": 5},
                )
            )

        task = asyncio.create_task(publish_wrong_then_right())
        event = await bus.wait_for(
            event_type=EventType.NODE_LOOP_COMPLETED,
            node_id="target_node",
            timeout=2.0,
        )
        await task
        assert event is not None
        assert event.node_id == "target_node"
        assert event.data["iterations"] == 5
# ---------------------------------------------------------------------------
# WP-2: Convenience publisher methods
# ---------------------------------------------------------------------------
class TestWP2ConveniencePublishers:
    """EventBus convenience methods for new WP-2 event types.

    Each test subscribes to the matching EventType, invokes the emit_*
    helper, and checks the published AgentEvent's node_id/data payload.
    """

    @pytest.fixture
    def bus(self):
        # Fresh bus per test so subscriptions do not leak between cases.
        return EventBus()

    @pytest.mark.asyncio
    async def test_emit_node_loop_started(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_STARTED], handler=handler)
        await bus.emit_node_loop_started(
            stream_id="s1",
            node_id="n1",
            max_iterations=10,
        )
        assert len(received) == 1
        assert received[0].node_id == "n1"
        assert received[0].data["max_iterations"] == 10

    @pytest.mark.asyncio
    async def test_emit_node_loop_iteration(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_ITERATION], handler=handler)
        await bus.emit_node_loop_iteration(
            stream_id="s1",
            node_id="n1",
            iteration=3,
        )
        assert len(received) == 1
        assert received[0].data["iteration"] == 3

    @pytest.mark.asyncio
    async def test_emit_node_loop_completed(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_COMPLETED], handler=handler)
        await bus.emit_node_loop_completed(
            stream_id="s1",
            node_id="n1",
            iterations=5,
        )
        assert len(received) == 1
        assert received[0].data["iterations"] == 5

    @pytest.mark.asyncio
    async def test_emit_llm_text_delta(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.LLM_TEXT_DELTA], handler=handler)
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="n1",
            content="hello",
            snapshot="hello world",
        )
        assert len(received) == 1
        assert received[0].data["content"] == "hello"
        assert received[0].data["snapshot"] == "hello world"

    @pytest.mark.asyncio
    async def test_emit_tool_call_started(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.TOOL_CALL_STARTED], handler=handler)
        await bus.emit_tool_call_started(
            stream_id="s1",
            node_id="n1",
            tool_use_id="call_1",
            tool_name="web_search",
            tool_input={"query": "test"},
        )
        assert len(received) == 1
        assert received[0].data["tool_name"] == "web_search"
        assert received[0].data["tool_input"] == {"query": "test"}

    @pytest.mark.asyncio
    async def test_emit_tool_call_completed(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.TOOL_CALL_COMPLETED], handler=handler)
        await bus.emit_tool_call_completed(
            stream_id="s1",
            node_id="n1",
            tool_use_id="call_1",
            tool_name="web_search",
            result="3 results found",
        )
        assert len(received) == 1
        assert received[0].data["result"] == "3 results found"
        # is_error was not passed above, so the publisher must default it.
        assert received[0].data["is_error"] is False

    @pytest.mark.asyncio
    async def test_emit_client_output_delta(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.CLIENT_OUTPUT_DELTA], handler=handler)
        await bus.emit_client_output_delta(
            stream_id="s1",
            node_id="n1",
            content="chunk",
            snapshot="full chunk",
        )
        assert len(received) == 1
        assert received[0].data["content"] == "chunk"

    @pytest.mark.asyncio
    async def test_emit_node_stalled(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_STALLED], handler=handler)
        await bus.emit_node_stalled(
            stream_id="s1",
            node_id="n1",
            reason="no progress after 10 iterations",
        )
        assert len(received) == 1
        assert received[0].data["reason"] == "no progress after 10 iterations"

    @pytest.mark.asyncio
    async def test_convenience_publishers_set_node_id(self, bus):
        """All WP-2 convenience publishers set node_id on the emitted event."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA, EventType.TOOL_CALL_STARTED],
            handler=handler,
            filter_node="my_node",
        )
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="my_node",
            content="hi",
            snapshot="hi",
        )
        await bus.emit_tool_call_started(
            stream_id="s1",
            node_id="my_node",
            tool_use_id="c1",
            tool_name="calc",
        )
        # Wrong node — should not be received
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="other_node",
            content="bye",
            snapshot="bye",
        )
        assert len(received) == 2
        assert all(e.node_id == "my_node" for e in received)
+389
View File
@@ -0,0 +1,389 @@
"""Real-API streaming tests for LiteLLM provider.
Calls live LLM APIs and dumps stream events to JSON files for review.
Results are saved to core/tests/stream_event_dumps/{provider}_{model}_{scenario}.json
Run with:
cd core && python -m pytest tests/test_litellm_streaming.py -v -s -k "RealAPI"
Requires API keys set in environment:
ANTHROPIC_API_KEY, OPENAI_API_KEY, GEMINI_API_KEY (or via credential store)
"""
import asyncio
import json
import logging
import os
from dataclasses import asdict
from pathlib import Path
import pytest
from framework.llm.litellm import LiteLLMProvider
from framework.llm.provider import Tool
from framework.llm.stream_events import (
FinishEvent,
StreamEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
)
# Module-level logger used by the collect/dump helpers below.
logger = logging.getLogger(__name__)

# Directory where serialized stream-event dumps are written for review.
DUMP_DIR = Path(__file__).parent / "stream_event_dumps"
def _serialize_event(index: int, event: StreamEvent) -> dict:
    """Serialize a StreamEvent to a JSON-safe dict with ``index`` first.

    Args:
        index: Position of the event within its stream.
        event: Stream event dataclass to serialize.

    Returns:
        The event's fields as a dict, with an ``index`` key inserted first
        for readability in the dumped JSON.
    """
    # asdict() yields only the dataclass's own fields, so "index" cannot
    # collide. Building the dict in one expression replaces the original's
    # dead `d["index"] = index` store, which was immediately filtered back
    # out while moving the key to the front.
    return {"index": index, **asdict(event)}  # type: ignore[arg-type]
def _dump_events(events: list[StreamEvent], filename: str) -> Path:
    """Write stream events to a JSON file in the dump directory."""
    # Create the dump directory lazily; harmless if it already exists.
    DUMP_DIR.mkdir(parents=True, exist_ok=True)
    filepath = DUMP_DIR / filename
    payload = [_serialize_event(pos, ev) for pos, ev in enumerate(events)]
    rendered = json.dumps(payload, indent=2) + "\n"
    filepath.write_text(rendered)
    logger.info(f"Dumped {len(events)} events to {filepath}")
    return filepath
async def _collect_stream(provider: LiteLLMProvider, **kwargs) -> list[StreamEvent]:
    """Collect all stream events from a provider.stream() call."""
    collected: list[StreamEvent] = []
    async for item in provider.stream(**kwargs):
        collected.append(item)
        # Log each event type as it arrives
        logger.debug(f" [{len(collected) - 1}] {item.type}: {item}")
    return collected
# ---------------------------------------------------------------------------
# Test matrix: (model_id, dump_prefix, env_var_for_skip)
# ---------------------------------------------------------------------------
# dump_prefix doubles as the pytest parametrize id and the dump filename stem;
# env_var_for_skip names the key checked by _has_api_key() before each test.
MODELS = [
    (
        "anthropic/claude-haiku-4-5-20251001",
        "anthropic_claude-haiku-4-5-20251001",
        "ANTHROPIC_API_KEY",
    ),
    ("gpt-4.1-nano", "gpt-4.1-nano", "OPENAI_API_KEY"),
    ("gemini/gemini-2.0-flash", "gemini_gemini-2.0-flash", "GEMINI_API_KEY"),
]
# Tool schema with a single required string parameter ("city").
WEATHER_TOOL = Tool(
    name="get_weather",
    description="Get the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {
            "city": {
                "type": "string",
                "description": "City name, e.g. 'Tokyo'",
            }
        },
        "required": ["city"],
    },
)

# Tool schema with one required ("query") and one optional ("num_results")
# parameter of different types.
SEARCH_TOOL = Tool(
    name="web_search",
    description="Search the web for information.",
    parameters={
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query",
            },
            "num_results": {
                "type": "integer",
                "description": "Number of results to return (1-10)",
            },
        },
        "required": ["query"],
    },
)

# Tool schema with a single required string parameter ("expression").
CALCULATOR_TOOL = Tool(
    name="calculator",
    description="Perform arithmetic calculations.",
    parameters={
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "Math expression to evaluate, e.g. '2 + 2'",
            }
        },
        "required": ["expression"],
    },
)
def _has_api_key(env_var: str) -> bool:
"""Check if an API key is available (env var or credential store)."""
if os.environ.get(env_var):
return True
# Try credential store
try:
from aden_tools.credentials import CredentialStoreAdapter
creds = CredentialStoreAdapter.with_env_storage()
provider_name = env_var.replace("_API_KEY", "").lower()
return creds.is_available(provider_name)
except (ImportError, Exception):
return False
# ---------------------------------------------------------------------------
# Real API tests — text streaming
# ---------------------------------------------------------------------------
class TestRealAPITextStreaming:
    """Stream a simple text response from each provider and dump events.

    Makes live API calls; each case is skipped when no credential for the
    provider is available (see _has_api_key).
    """

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_text_stream(self, model: str, prefix: str, env_var: str):
        """Stream a multi-paragraph response to exercise chunked delivery."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")
        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "Explain in 3 numbered paragraphs how a CPU executes an instruction. "
                        "Cover fetch, decode, and execute stages. Be concise but thorough."
                    ),
                }
            ],
            system="You are a computer science teacher. Give clear, structured explanations.",
            max_tokens=512,
        )
        # Dump to file
        _dump_events(events, f"{prefix}_text.json")
        # Basic structural assertions
        assert len(events) >= 4, f"Expected at least 4 events, got {len(events)}"
        # Must have multiple text deltas for a longer response
        text_deltas = [e for e in events if isinstance(e, TextDeltaEvent)]
        assert len(text_deltas) >= 3, f"Expected 3+ TextDeltaEvents, got {len(text_deltas)}"
        # Snapshot must accumulate monotonically
        for i in range(1, len(text_deltas)):
            assert len(text_deltas[i].snapshot) > len(
                text_deltas[i - 1].snapshot
            ), f"Snapshot did not grow at index {i}"
        # Must end with TextEndEvent then FinishEvent
        text_ends = [e for e in events if isinstance(e, TextEndEvent)]
        assert len(text_ends) == 1, f"Expected 1 TextEndEvent, got {len(text_ends)}"
        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1, f"Expected 1 FinishEvent, got {len(finish_events)}"
        # Providers differ in stop-reason naming; both spellings accepted.
        assert finish_events[0].stop_reason in ("stop", "end_turn")
        # TextEndEvent.full_text should match last snapshot
        assert text_ends[0].full_text == text_deltas[-1].snapshot
        # Response should actually contain multi-paragraph content
        full_text = text_ends[0].full_text
        assert len(full_text) > 200, f"Response too short ({len(full_text)} chars)"
# ---------------------------------------------------------------------------
# Real API tests — tool call streaming
# ---------------------------------------------------------------------------
class TestRealAPIToolCallStreaming:
    """Stream a tool call response from each provider and dump events.

    Live API tests; each case skips when the provider's credential is absent.
    """

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_tool_call_stream(self, model: str, prefix: str, env_var: str):
        """Stream a single tool call with complex arguments."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")
        provider = LiteLLMProvider(model=model)
        # Three tools offered; the prompt should steer the model to web_search.
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": "Search the web for 'Python 3.13 release notes'.",
                }
            ],
            system="You have access to tools. Use the appropriate tool.",
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            max_tokens=512,
        )
        # Dump to file
        _dump_events(events, f"{prefix}_tool_call.json")
        # Basic structural assertions
        assert len(events) >= 2, f"Expected at least 2 events, got {len(events)}"
        # Must have a tool call event
        tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
        assert len(tool_calls) >= 1, "No ToolCallEvent received"
        tc = tool_calls[0]
        assert tc.tool_name == "web_search"
        assert "query" in tc.tool_input
        assert tc.tool_use_id != ""
        # Must end with FinishEvent
        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1
        # Stop-reason naming varies by provider; all three spellings accepted.
        assert finish_events[0].stop_reason in ("tool_calls", "tool_use", "stop")

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_multi_tool_call_stream(self, model: str, prefix: str, env_var: str):
        """Stream a response that should invoke multiple tool calls."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")
        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "I need three things done in parallel: "
                        "1) Get the weather in London, "
                        "2) Get the weather in New York, "
                        "3) Calculate 1337 * 42. "
                        "Use the tools for all three."
                    ),
                }
            ],
            system=(
                "You have access to tools. When the user asks for multiple things, "
                "call all the needed tools. Always use tools, never guess results."
            ),
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            max_tokens=512,
        )
        # Dump to file
        _dump_events(events, f"{prefix}_multi_tool.json")
        # Must have multiple tool call events
        tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
        assert (
            len(tool_calls) >= 2
        ), f"Expected 2+ ToolCallEvents for parallel requests, got {len(tool_calls)}"
        # Verify tool names used
        tool_names = {tc.tool_name for tc in tool_calls}
        assert "get_weather" in tool_names, "Expected get_weather tool call"
        # All tool calls should have non-empty IDs
        for tc in tool_calls:
            assert tc.tool_use_id != "", f"Empty tool_use_id on {tc.tool_name}"
            assert tc.tool_input, f"Empty tool_input on {tc.tool_name}"
        # Must end with FinishEvent
        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1
# ---------------------------------------------------------------------------
# Convenience runner for manual invocation
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    """Run all streaming tests and dump results. Usage: python tests/test_litellm_streaming.py"""
    ALL_TOOLS = [WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL]

    async def _run_all():
        # Runs the same three scenarios as the pytest classes above for each
        # configured model, printing every event for quick manual inspection.
        for model, prefix, env_var in MODELS:
            if not _has_api_key(env_var):
                print(f"SKIP {prefix}: {env_var} not set")
                continue
            provider = LiteLLMProvider(model=model)
            # Text streaming (multi-paragraph)
            print(f"\n--- {prefix} text ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": (
                            "Explain in 3 numbered paragraphs how a CPU executes an instruction. "
                            "Cover fetch, decode, and execute stages. Be concise but thorough."
                        ),
                    }
                ],
                system="You are a computer science teacher. Give clear, structured explanations.",
                max_tokens=512,
            )
            path = _dump_events(events, f"{prefix}_text.json")
            print(f" {len(events)} events -> {path}")
            for i, e in enumerate(events):
                print(f" [{i}] {e.type}: {e}")
            # Tool call streaming
            print(f"\n--- {prefix} tool_call ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": "Search the web for 'Python 3.13 release notes'.",
                    }
                ],
                system="You have access to tools. Use the appropriate tool.",
                tools=ALL_TOOLS,
                max_tokens=512,
            )
            path = _dump_events(events, f"{prefix}_tool_call.json")
            print(f" {len(events)} events -> {path}")
            for i, e in enumerate(events):
                print(f" [{i}] {e.type}: {e}")
            # Multi-tool call streaming
            print(f"\n--- {prefix} multi_tool ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": (
                            "I need three things done in parallel: "
                            "1) Get the weather in London, "
                            "2) Get the weather in New York, "
                            "3) Calculate 1337 * 42. "
                            "Use the tools for all three."
                        ),
                    }
                ],
                system=(
                    "You have access to tools. When the user asks for multiple things, "
                    "call all the needed tools. Always use tools, never guess results."
                ),
                tools=ALL_TOOLS,
                max_tokens=512,
            )
            path = _dump_events(events, f"{prefix}_multi_tool.json")
            print(f" {len(events)} events -> {path}")
            for i, e in enumerate(events):
                print(f" [{i}] {e.type}: {e}")

    # DEBUG level so _collect_stream's per-event logger.debug lines are shown.
    logging.basicConfig(level=logging.DEBUG)
    asyncio.run(_run_all())
+318
View File
@@ -0,0 +1,318 @@
"""Tests for stream event dataclasses.
Validates construction, defaults, immutability, serialization, and the
StreamEvent discriminated union type.
"""
from dataclasses import FrozenInstanceError, asdict, fields
import pytest
from framework.llm.stream_events import (
FinishEvent,
ReasoningDeltaEvent,
ReasoningStartEvent,
StreamErrorEvent,
StreamEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
ToolResultEvent,
)
# All concrete event classes in the union
# Swept by the parametrized tests below (defaults, immutability, type tags).
ALL_EVENT_CLASSES = [
    TextDeltaEvent,
    TextEndEvent,
    ToolCallEvent,
    ToolResultEvent,
    ReasoningStartEvent,
    ReasoningDeltaEvent,
    FinishEvent,
    StreamErrorEvent,
]
# ---------------------------------------------------------------------------
# Construction & defaults
# ---------------------------------------------------------------------------
class TestEventDefaults:
    """Each event class should be constructible with zero arguments."""

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_default_construction(self, cls):
        """A bare construction succeeds and carries a non-empty type tag."""
        assert cls().type != ""

    def test_text_delta_defaults(self):
        delta = TextDeltaEvent()
        assert delta.type == "text_delta"
        assert delta.content == ""
        assert delta.snapshot == ""

    def test_text_end_defaults(self):
        end = TextEndEvent()
        assert end.type == "text_end"
        assert end.full_text == ""

    def test_tool_call_defaults(self):
        call = ToolCallEvent()
        assert call.type == "tool_call"
        assert call.tool_use_id == ""
        assert call.tool_name == ""
        assert call.tool_input == {}

    def test_tool_result_defaults(self):
        result = ToolResultEvent()
        assert result.type == "tool_result"
        assert result.tool_use_id == ""
        assert result.content == ""
        assert result.is_error is False

    def test_reasoning_start_defaults(self):
        assert ReasoningStartEvent().type == "reasoning_start"

    def test_reasoning_delta_defaults(self):
        delta = ReasoningDeltaEvent()
        assert delta.type == "reasoning_delta"
        assert delta.content == ""

    def test_finish_defaults(self):
        finish = FinishEvent()
        assert finish.type == "finish"
        assert finish.stop_reason == ""
        assert finish.input_tokens == 0
        assert finish.output_tokens == 0
        assert finish.model == ""

    def test_stream_error_defaults(self):
        err = StreamErrorEvent()
        assert err.type == "error"
        assert err.error == ""
        assert err.recoverable is False
# ---------------------------------------------------------------------------
# Construction with values
# ---------------------------------------------------------------------------
class TestEventConstruction:
    """Events should store provided field values correctly."""

    def test_text_delta_with_values(self):
        delta = TextDeltaEvent(content="hello", snapshot="hello world")
        assert delta.content == "hello"
        assert delta.snapshot == "hello world"

    def test_text_end_with_values(self):
        end = TextEndEvent(full_text="the complete response")
        assert end.full_text == "the complete response"

    def test_tool_call_with_values(self):
        call = ToolCallEvent(
            tool_use_id="call_abc123",
            tool_name="web_search",
            tool_input={"query": "python", "num_results": 5},
        )
        assert call.tool_use_id == "call_abc123"
        assert call.tool_name == "web_search"
        assert call.tool_input == {"query": "python", "num_results": 5}

    def test_tool_result_with_values(self):
        result = ToolResultEvent(
            tool_use_id="call_abc123",
            content="search results here",
            is_error=False,
        )
        assert result.tool_use_id == "call_abc123"
        assert result.content == "search results here"
        assert result.is_error is False

    def test_tool_result_error(self):
        failure = ToolResultEvent(
            tool_use_id="call_fail",
            content="timeout",
            is_error=True,
        )
        assert failure.is_error is True

    def test_reasoning_delta_with_content(self):
        thought = ReasoningDeltaEvent(content="Let me think about this...")
        assert thought.content == "Let me think about this..."

    def test_finish_with_values(self):
        finish = FinishEvent(
            stop_reason="end_turn",
            input_tokens=150,
            output_tokens=300,
            model="claude-haiku-4-5",
        )
        assert finish.stop_reason == "end_turn"
        assert finish.input_tokens == 150
        assert finish.output_tokens == 300
        assert finish.model == "claude-haiku-4-5"

    def test_stream_error_with_values(self):
        err = StreamErrorEvent(error="rate limit exceeded", recoverable=True)
        assert err.error == "rate limit exceeded"
        assert err.recoverable is True
# ---------------------------------------------------------------------------
# Frozen immutability
# ---------------------------------------------------------------------------
class TestEventImmutability:
    """All events are frozen dataclasses — fields cannot be reassigned."""

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_frozen(self, cls):
        """Assigning any field of a frozen event raises FrozenInstanceError."""
        instance = cls()
        with pytest.raises(FrozenInstanceError):
            instance.type = "modified"

    def test_text_delta_frozen_content(self):
        delta = TextDeltaEvent(content="hello")
        with pytest.raises(FrozenInstanceError):
            delta.content = "modified"

    def test_tool_call_frozen_input(self):
        call = ToolCallEvent(tool_input={"key": "value"})
        with pytest.raises(FrozenInstanceError):
            call.tool_input = {}
# ---------------------------------------------------------------------------
# Type literal values
# ---------------------------------------------------------------------------
class TestTypeLiterals:
    """Each event's `type` field should match its Literal annotation."""

    # Canonical wire-format tag for every event class.
    EXPECTED_TYPES = {
        TextDeltaEvent: "text_delta",
        TextEndEvent: "text_end",
        ToolCallEvent: "tool_call",
        ToolResultEvent: "tool_result",
        ReasoningStartEvent: "reasoning_start",
        ReasoningDeltaEvent: "reasoning_delta",
        FinishEvent: "finish",
        StreamErrorEvent: "error",
    }

    @pytest.mark.parametrize(
        "cls,expected_type",
        EXPECTED_TYPES.items(),
        # ids callable sees each param individually: classes and strings alike.
        ids=lambda x: x.__name__ if isinstance(x, type) else x,
    )
    def test_type_value(self, cls, expected_type):
        assert cls().type == expected_type

    def test_all_types_unique(self):
        # No two event classes may share a wire tag.
        types = [cls().type for cls in ALL_EVENT_CLASSES]
        assert len(types) == len(set(types)), f"Duplicate type values: {types}"
# ---------------------------------------------------------------------------
# Serialization via dataclasses.asdict
# ---------------------------------------------------------------------------
class TestEventSerialization:
    """Events should round-trip through asdict for JSON serialization."""

    def test_text_delta_asdict(self):
        serialized = asdict(TextDeltaEvent(content="chunk", snapshot="full chunk"))
        expected = {"type": "text_delta", "content": "chunk", "snapshot": "full chunk"}
        assert serialized == expected

    def test_tool_call_asdict(self):
        event = ToolCallEvent(
            tool_use_id="id_1",
            tool_name="calc",
            tool_input={"expression": "2+2"},
        )
        serialized = asdict(event)
        assert serialized["tool_name"] == "calc"
        assert serialized["tool_input"] == {"expression": "2+2"}

    def test_finish_asdict(self):
        event = FinishEvent(
            stop_reason="stop", input_tokens=10, output_tokens=20, model="gpt-4"
        )
        assert asdict(event) == {
            "type": "finish",
            "stop_reason": "stop",
            "input_tokens": 10,
            "output_tokens": 20,
            "model": "gpt-4",
        }

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_asdict_contains_type(self, cls):
        # The discriminator tag must survive serialization on every class.
        assert "type" in asdict(cls())

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_asdict_keys_match_fields(self, cls):
        # asdict output is exactly the dataclass fields — nothing extra, nothing lost.
        expected_keys = {f.name for f in fields(cls)}
        assert set(asdict(cls()).keys()) == expected_keys
# ---------------------------------------------------------------------------
# StreamEvent union type
# ---------------------------------------------------------------------------
class TestStreamEventUnion:
    """The StreamEvent union should include all event classes."""

    def test_union_contains_all_classes(self):
        # StreamEvent is a PEP 604 union (X | Y | Z); members live in __args__.
        members = StreamEvent.__args__  # type: ignore[attr-defined]
        for cls in ALL_EVENT_CLASSES:
            assert cls in members, f"{cls.__name__} not in StreamEvent union"

    def test_union_has_exactly_expected_members(self):
        # Set equality catches both missing and stray union members.
        members = set(StreamEvent.__args__)  # type: ignore[attr-defined]
        assert members == set(ALL_EVENT_CLASSES)

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_isinstance_check(self, cls):
        """Each event instance should be an instance of its class (basic sanity)."""
        assert isinstance(cls(), cls)
# ---------------------------------------------------------------------------
# Equality & hashing (frozen dataclasses support both)
# ---------------------------------------------------------------------------
class TestEventEquality:
    """Frozen dataclasses support equality and hashing."""
    def test_equal_events(self):
        # Identical field values on the same class compare equal (dataclass __eq__).
        a = TextDeltaEvent(content="hi", snapshot="hi")
        b = TextDeltaEvent(content="hi", snapshot="hi")
        assert a == b
    def test_unequal_events(self):
        # Any differing field breaks equality.
        a = TextDeltaEvent(content="hi")
        b = TextDeltaEvent(content="bye")
        assert a != b
    def test_different_types_not_equal(self):
        # Same field values but different event classes must not compare equal.
        a = TextDeltaEvent(content="hi")
        b = ReasoningDeltaEvent(content="hi")
        assert a != b
    def test_hashable(self):
        # Frozen dataclasses with only hashable fields work as set members.
        e = FinishEvent(stop_reason="stop", model="gpt-4")
        s = {e}  # should be hashable since frozen
        assert e in s
    def test_equal_events_same_hash(self):
        # Hash/equality contract: equal objects must produce equal hashes.
        a = FinishEvent(stop_reason="stop", model="gpt-4")
        b = FinishEvent(stop_reason="stop", model="gpt-4")
        assert hash(a) == hash(b)
    def test_events_with_dict_not_hashable(self):
        """Events containing dict fields (e.g. tool_input) are not hashable."""
        # The generated __hash__ hashes each field; the dict field raises TypeError.
        e = ToolCallEvent(tool_use_id="x", tool_name="y", tool_input={"key": "val"})
        with pytest.raises(TypeError, match="unhashable type"):
            hash(e)