chore: update gitignore
This commit is contained in:
+1
-1
@@ -72,4 +72,4 @@ exports/*
|
||||
.venv
|
||||
|
||||
docs/github-issues/*
|
||||
core/tests/*
|
||||
core/tests/*dumps/*
|
||||
|
||||
@@ -0,0 +1,746 @@
|
||||
"""WP-8: Tests for EventLoopNode, OutputAccumulator, LoopConfig, JudgeProtocol.
|
||||
|
||||
Uses real FileConversationStore (no mocks for storage) and a MockStreamingLLM
|
||||
that yields pre-programmed StreamEvents to control the loop deterministically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.event_loop_node import (
|
||||
EventLoopNode,
|
||||
JudgeProtocol,
|
||||
JudgeVerdict,
|
||||
LoopConfig,
|
||||
OutputAccumulator,
|
||||
)
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeSpec, SharedMemory
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.runtime.event_bus import EventBus, EventType
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock LLM that yields pre-programmed stream events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockStreamingLLM(LLMProvider):
    """Mock LLM that yields pre-programmed StreamEvent sequences.

    Each call to stream() consumes the next scenario from the list.
    Cycles back to the beginning if more calls are made than scenarios.
    """

    def __init__(self, scenarios: list[list] | None = None):
        # Scenario queue; an empty list makes stream() yield nothing.
        self.scenarios = scenarios if scenarios is not None else []
        self._call_index = 0
        # Record of every stream() invocation, for later inspection by tests.
        self.stream_calls: list[dict] = []

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator:
        self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
        if not self.scenarios:
            return
        # Select the current scenario (wrapping around), then advance the cursor.
        scenario = self.scenarios[self._call_index % len(self.scenarios)]
        self._call_index += 1
        for stream_event in scenario:
            yield stream_event

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        # Canned non-streaming completion (used e.g. for summarization paths).
        return LLMResponse(content="Summary of conversation.", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        # Canned tool-capable completion; never executes tools.
        return LLMResponse(content="", model="mock", stop_reason="stop")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: build a simple text-only scenario
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def text_scenario(text: str, input_tokens: int = 10, output_tokens: int = 5) -> list:
    """Build a stream scenario that produces text and finishes."""
    delta = TextDeltaEvent(content=text, snapshot=text)
    finish = FinishEvent(
        stop_reason="stop", input_tokens=input_tokens, output_tokens=output_tokens, model="mock"
    )
    return [delta, finish]
|
||||
|
||||
|
||||
def tool_call_scenario(
    tool_name: str,
    tool_input: dict,
    tool_use_id: str = "call_1",
    text: str = "",
) -> list:
    """Build a stream scenario that produces a tool call."""
    # Optional leading text, then the tool call, then a "tool_calls" finish.
    events: list = [TextDeltaEvent(content=text, snapshot=text)] if text else []
    events += [
        ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input),
        FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
    ]
    return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def runtime():
    """A MagicMock Runtime with canned ids from start_run/decide."""
    rt = MagicMock(spec=Runtime)
    rt.start_run = MagicMock(return_value="run_1")
    rt.decide = MagicMock(return_value="dec_1")
    # Remaining lifecycle hooks are plain no-op mocks.
    for hook in ("record_outcome", "end_run", "report_problem", "set_node"):
        setattr(rt, hook, MagicMock())
    return rt
|
||||
|
||||
|
||||
@pytest.fixture
def node_spec():
    """Spec for an event-loop node that must produce a 'result' output."""
    spec_kwargs = {
        "id": "test_loop",
        "name": "Test Loop",
        "description": "A test event loop node",
        "node_type": "event_loop",
        "output_keys": ["result"],
        "system_prompt": "You are a test assistant.",
    }
    return NodeSpec(**spec_kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
def memory():
    """A fresh, empty SharedMemory for each test."""
    return SharedMemory()
|
||||
|
||||
|
||||
def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal_context=""):
    """Build a NodeContext for testing."""
    # Assemble kwargs first so defaults for tools/input_data are explicit.
    ctx_kwargs = dict(
        runtime=runtime,
        node_id=node_spec.id,
        node_spec=node_spec,
        memory=memory,
        input_data=input_data or {},
        llm=llm,
        available_tools=tools or [],
        goal_context=goal_context,
    )
    return NodeContext(**ctx_kwargs)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# NodeProtocol conformance
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestNodeProtocolConformance:
    def test_subclasses_node_protocol(self):
        """EventLoopNode must be a subclass of NodeProtocol."""
        assert issubclass(EventLoopNode, NodeProtocol)

    def test_has_execute_method(self):
        """execute must exist and be an async coroutine function."""
        instance = EventLoopNode()
        assert hasattr(instance, "execute")
        assert asyncio.iscoroutinefunction(instance.execute)

    def test_has_validate_input(self):
        """validate_input must be present on instances."""
        instance = EventLoopNode()
        assert hasattr(instance, "validate_input")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Basic loop execution
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestBasicLoop:
    """Happy path, missing-LLM failure, and iteration-cap failure."""

    @pytest.mark.asyncio
    async def test_basic_text_only_implicit_accept(self, runtime, node_spec, memory):
        """No tools, no judge. LLM produces text, implicit accept on stop."""
        # Override to no output_keys so implicit judge accepts immediately
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello world")])
        ctx = build_ctx(runtime, node_spec, memory, llm)

        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        # text_scenario reports 10 input + 5 output tokens, so usage is non-zero.
        assert result.tokens_used > 0

    @pytest.mark.asyncio
    async def test_no_llm_returns_failure(self, runtime, node_spec, memory):
        """ctx.llm=None should return failure immediately."""
        ctx = build_ctx(runtime, node_spec, memory, llm=None)

        node = EventLoopNode()
        result = await node.execute(ctx)

        assert result.success is False
        assert "LLM" in result.error

    @pytest.mark.asyncio
    async def test_max_iterations_failure(self, runtime, node_spec, memory):
        """When max_iterations is reached without acceptance, should fail."""
        # LLM always produces text but never calls set_output, so implicit
        # judge retries asking for missing keys
        llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
        ctx = build_ctx(runtime, node_spec, memory, llm)

        node = EventLoopNode(config=LoopConfig(max_iterations=2))
        result = await node.execute(ctx)

        assert result.success is False
        assert "Max iterations" in result.error
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Judge integration
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestJudgeIntegration:
    """ACCEPT / ESCALATE / RETRY verdicts from an injected judge."""

    @pytest.mark.asyncio
    async def test_judge_accept(self, runtime, node_spec, memory):
        """Mock judge ACCEPT -> success."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Done!")])

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        # Judge consulted exactly once: first verdict already accepts.
        judge.evaluate.assert_called_once()

    @pytest.mark.asyncio
    async def test_judge_escalate(self, runtime, node_spec, memory):
        """Mock judge ESCALATE -> failure."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Attempt")])

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(
            return_value=JudgeVerdict(action="ESCALATE", feedback="Tone violation")
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is False
        assert "escalated" in result.error.lower()
        # Judge feedback must be surfaced verbatim in the error message.
        assert "Tone violation" in result.error

    @pytest.mark.asyncio
    async def test_judge_retry_then_accept(self, runtime, node_spec, memory):
        """RETRY twice, then ACCEPT. Should run 3 iterations."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("attempt 1"),
                text_scenario("attempt 2"),
                text_scenario("attempt 3"),
            ]
        )

        # Counter shared with the judge closure to track evaluation calls.
        call_count = 0

        async def evaluate_fn(context):
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                return JudgeVerdict(action="RETRY", feedback="Try harder")
            return JudgeVerdict(action="ACCEPT")

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(side_effect=evaluate_fn)

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=10))
        result = await node.execute(ctx)

        assert result.success is True
        assert call_count == 3
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# set_output tool
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSetOutput:
    """Behavior of the built-in set_output tool and output-key enforcement."""

    @pytest.mark.asyncio
    async def test_set_output_accumulates(self, runtime, node_spec, memory):
        """LLM calls set_output -> values appear in NodeResult.output."""
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call set_output
                tool_call_scenario("set_output", {"key": "result", "value": "42"}),
                # Turn 2: text response (triggers implicit judge)
                text_scenario("Done, result is 42"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "42"

    @pytest.mark.asyncio
    async def test_set_output_rejects_invalid_key(self, runtime, node_spec, memory):
        """set_output with key not in output_keys -> is_error=True."""
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call set_output with bad key
                tool_call_scenario("set_output", {"key": "bad_key", "value": "x"}),
                # Turn 2: call set_output with good key
                tool_call_scenario("set_output", {"key": "result", "value": "ok"}),
                # Turn 3: text done
                text_scenario("Done"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "ok"
        # The rejected key must not leak into the final output mapping.
        assert "bad_key" not in result.output

    @pytest.mark.asyncio
    async def test_missing_keys_triggers_retry(self, runtime, node_spec, memory):
        """Judge accepts but output keys are missing -> retry with hint."""
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))

        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: text without set_output -> judge accepts but keys missing -> retry
                text_scenario("I'll get to it"),
                # Turn 2: set_output
                tool_call_scenario("set_output", {"key": "result", "value": "done"}),
                # Turn 3: text -> judge accepts, keys present -> success
                text_scenario("All done"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "done"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Stall detection
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStallDetection:
    @pytest.mark.asyncio
    async def test_stall_detection(self, runtime, node_spec, memory):
        """3 identical responses should trigger stall detection."""
        node_spec.output_keys = []  # so implicit judge would accept
        # Force RETRY every time so the loop keeps producing the identical reply.
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY"))

        mock_llm = MockStreamingLLM(scenarios=[text_scenario("same answer")])

        loop_node = EventLoopNode(
            judge=judge,
            config=LoopConfig(max_iterations=10, stall_detection_threshold=3),
        )
        outcome = await loop_node.execute(build_ctx(runtime, node_spec, memory, mock_llm))

        assert outcome.success is False
        assert "stalled" in outcome.error.lower()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# EventBus lifecycle events
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestEventBusLifecycle:
    """Lifecycle and delta events published on the EventBus."""

    @pytest.mark.asyncio
    async def test_lifecycle_events_published(self, runtime, node_spec, memory):
        """NODE_LOOP_STARTED, NODE_LOOP_ITERATION, NODE_LOOP_COMPLETED should be published."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("ok")])
        bus = EventBus()

        # Collect only the lifecycle event types we care about.
        received_events = []
        bus.subscribe(
            event_types=[
                EventType.NODE_LOOP_STARTED,
                EventType.NODE_LOOP_ITERATION,
                EventType.NODE_LOOP_COMPLETED,
            ],
            handler=lambda e: received_events.append(e.type),
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert EventType.NODE_LOOP_STARTED in received_events
        assert EventType.NODE_LOOP_ITERATION in received_events
        assert EventType.NODE_LOOP_COMPLETED in received_events

    @pytest.mark.asyncio
    async def test_client_facing_uses_client_output_delta(self, runtime, memory):
        """client_facing=True should emit CLIENT_OUTPUT_DELTA instead of LLM_TEXT_DELTA."""
        # Local spec rather than the fixture: client_facing must be True here.
        spec = NodeSpec(
            id="ui_node",
            name="UI Node",
            description="Streams to user",
            node_type="event_loop",
            output_keys=[],
            client_facing=True,
        )
        llm = MockStreamingLLM(scenarios=[text_scenario("visible to user")])
        bus = EventBus()

        # Subscribe to both delta types so we can assert mutual exclusion.
        received_types = []
        bus.subscribe(
            event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.LLM_TEXT_DELTA],
            handler=lambda e: received_types.append(e.type),
        )

        ctx = build_ctx(runtime, spec, memory, llm)
        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        await node.execute(ctx)

        assert EventType.CLIENT_OUTPUT_DELTA in received_types
        assert EventType.LLM_TEXT_DELTA not in received_types
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tool execution
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestToolExecution:
    """Tool calls are executed and their results fed back into the loop."""

    @pytest.mark.asyncio
    async def test_tool_execution_feedback(self, runtime, node_spec, memory):
        """Tool call -> result fed back to conversation via stream loop."""
        node_spec.output_keys = []

        # Deterministic executor: echoes the tool name back as the result.
        def my_tool_executor(tool_use: ToolUse) -> ToolResult:
            return ToolResult(
                tool_use_id=tool_use.id,
                content=f"Result for {tool_use.name}",
                is_error=False,
            )

        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call a tool
                tool_call_scenario("search", {"query": "test"}, tool_use_id="call_search"),
                # Turn 2: text response after seeing tool result
                text_scenario("Found the answer"),
            ]
        )

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            tools=[Tool(name="search", description="Search", parameters={})],
        )
        node = EventLoopNode(
            tool_executor=my_tool_executor,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True
        # stream() should have been called twice (tool call turn + final text turn)
        assert llm._call_index >= 2
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Write-through persistence with real FileConversationStore
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestWriteThroughPersistence:
    """Write-through persistence against a real FileConversationStore."""

    @pytest.mark.asyncio
    async def test_messages_written_to_store(self, tmp_path, runtime, node_spec, memory):
        """Messages should be persisted immediately via write-through."""
        store = FileConversationStore(tmp_path / "conv")
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello")])

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True

        # Verify parts were written to disk
        parts = await store.read_parts()
        assert len(parts) >= 2  # at least initial user msg + assistant msg

    @pytest.mark.asyncio
    async def test_output_accumulator_write_through(self, tmp_path, runtime, node_spec, memory):
        """set_output values should be persisted in cursor immediately."""
        store = FileConversationStore(tmp_path / "conv")
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("set_output", {"key": "result", "value": "persisted_value"}),
                text_scenario("Done"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "persisted_value"

        # Verify output was written to cursor on disk
        cursor = await store.read_cursor()
        assert cursor is not None
        assert cursor["outputs"]["result"] == "persisted_value"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Crash recovery (restore from real FileConversationStore)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestCrashRecovery:
    """Restoring loop state from a pre-populated FileConversationStore."""

    @pytest.mark.asyncio
    async def test_restore_from_checkpoint(self, tmp_path, runtime, node_spec, memory):
        """Populate a store with state, then verify EventLoopNode restores from it."""
        store = FileConversationStore(tmp_path / "conv")

        # Simulate a previous run that wrote conversation + cursor
        conv = NodeConversation(
            system_prompt="You are a test assistant.",
            output_keys=["result"],
            store=store,
        )
        await conv.add_user_message("Initial input")
        await conv.add_assistant_message("Working on it...")

        # Write cursor with iteration and outputs
        await store.write_cursor(
            {
                "iteration": 1,
                "next_seq": conv.next_seq,
                "outputs": {"result": "partial_value"},
            }
        )

        # Now create a new EventLoopNode and execute -- it should restore
        node_spec.output_keys = []  # no required keys so implicit accept works
        llm = MockStreamingLLM(scenarios=[text_scenario("Continuing...")])

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True
        # Should have the restored output
        assert result.output.get("result") == "partial_value"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# External event injection
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestEventInjection:
    """External events injected into the loop surface as user messages."""

    @pytest.mark.asyncio
    async def test_inject_event(self, runtime, node_spec, memory):
        """inject_event() content should appear as user message in next iteration."""
        node_spec.output_keys = []

        # Record judge calls; RETRY once so the loop runs a second iteration
        # that can pick up the injected event.
        judge_calls = []

        async def evaluate_fn(context):
            judge_calls.append(context)
            if len(judge_calls) >= 2:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY")

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(side_effect=evaluate_fn)

        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("iteration 1"),
                text_scenario("iteration 2"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            judge=judge,
            config=LoopConfig(max_iterations=5),
        )

        # Pre-inject an event before execute runs
        await node.inject_event("Priority: CEO wants meeting rescheduled")

        result = await node.execute(ctx)
        assert result.success is True

        # Verify the injected content made it into the LLM messages
        all_messages = []
        for call in llm.stream_calls:
            all_messages.extend(call["messages"])
        injected_found = any("[External event]" in str(m.get("content", "")) for m in all_messages)
        assert injected_found
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Pause/resume
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestPauseResume:
    """Pausing the loop before any work happens."""

    @pytest.mark.asyncio
    async def test_pause_returns_early(self, runtime, node_spec, memory):
        """pause_requested in input_data should trigger early return."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("should not run")])

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            input_data={"pause_requested": True},
        )
        node = EventLoopNode(config=LoopConfig(max_iterations=10))
        result = await node.execute(ctx)

        # Should return success (paused, not failed)
        assert result.success is True
        # LLM should not have been called (paused before first turn)
        assert llm._call_index == 0
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Stream errors
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStreamErrors:
    @pytest.mark.asyncio
    async def test_non_recoverable_stream_error_raises(self, runtime, node_spec, memory):
        """Non-recoverable StreamErrorEvent should raise RuntimeError."""
        node_spec.output_keys = []
        # Single scenario whose only event is a fatal stream error.
        fatal_scenario = [StreamErrorEvent(error="Connection lost", recoverable=False)]
        mock_llm = MockStreamingLLM(scenarios=[fatal_scenario])

        loop_node = EventLoopNode(config=LoopConfig(max_iterations=5))

        with pytest.raises(RuntimeError, match="Stream error"):
            await loop_node.execute(build_ctx(runtime, node_spec, memory, mock_llm))
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# OutputAccumulator unit tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOutputAccumulator:
    """Unit tests for OutputAccumulator, with and without a backing store."""

    @pytest.mark.asyncio
    async def test_set_and_get(self):
        acc = OutputAccumulator()
        await acc.set("key1", "value1")
        assert acc.get("key1") == "value1"
        # Unknown keys read back as None rather than raising.
        assert acc.get("nonexistent") is None

    @pytest.mark.asyncio
    async def test_to_dict(self):
        acc = OutputAccumulator()
        await acc.set("a", 1)
        await acc.set("b", 2)
        assert acc.to_dict() == {"a": 1, "b": 2}

    @pytest.mark.asyncio
    async def test_has_all_keys(self):
        acc = OutputAccumulator()
        # An empty requirement list is vacuously satisfied.
        assert acc.has_all_keys([]) is True
        assert acc.has_all_keys(["x"]) is False
        await acc.set("x", "val")
        assert acc.has_all_keys(["x"]) is True

    @pytest.mark.asyncio
    async def test_write_through_to_real_store(self, tmp_path):
        """OutputAccumulator should write through to FileConversationStore cursor."""
        store = FileConversationStore(tmp_path / "acc_test")
        acc = OutputAccumulator(store=store)

        await acc.set("result", "hello")

        cursor = await store.read_cursor()
        assert cursor["outputs"]["result"] == "hello"

    @pytest.mark.asyncio
    async def test_restore_from_real_store(self, tmp_path):
        """OutputAccumulator.restore() should rebuild from FileConversationStore."""
        store = FileConversationStore(tmp_path / "acc_restore")
        await store.write_cursor({"outputs": {"key1": "val1", "key2": "val2"}})

        acc = await OutputAccumulator.restore(store)
        assert acc.get("key1") == "val1"
        assert acc.get("key2") == "val2"
        assert acc.has_all_keys(["key1", "key2"]) is True
|
||||
@@ -0,0 +1,978 @@
|
||||
"""Tests for extending the stream event type system.
|
||||
|
||||
Validates that the StreamEvent discriminated union pattern supports:
|
||||
- Type-based dispatch (matching on event.type)
|
||||
- Pattern matching / isinstance branching
|
||||
- Custom event subclasses following the same frozen-dataclass convention
|
||||
- Serialization of mixed event sequences
|
||||
|
||||
WP-2 tests validate EventType enum extension and node-level event routing:
|
||||
- All 12 new EventType enum members with correct string values
|
||||
- node_id routing on AgentEvent
|
||||
- filter_node on Subscription
|
||||
- Backward compatibility with existing enum members
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import FrozenInstanceError, asdict, dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
ReasoningDeltaEvent,
|
||||
ReasoningStartEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
from framework.runtime.event_bus import AgentEvent, EventBus, EventType, Subscription
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers: type-based dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
def dispatch_event(event) -> str:
    """Dispatch an event by its type field, returning a label."""
    # Table of per-type formatters keyed by the event's discriminator string.
    handlers = {
        "text_delta": lambda e: f"text:{e.content}",
        "text_end": lambda e: f"end:{len(e.full_text)}chars",
        "tool_call": lambda e: f"call:{e.tool_name}",
        "tool_result": lambda e: f"result:{e.tool_use_id}",
        "reasoning_start": lambda _: "reasoning:start",
        "reasoning_delta": lambda e: f"reasoning:{e.content[:20]}",
        "finish": lambda e: f"finish:{e.stop_reason}",
        "error": lambda e: f"error:{e.error}",
    }
    try:
        formatter = handlers[event.type]
    except KeyError:
        return f"unknown:{event.type}"
    return formatter(event)
|
||||
|
||||
|
||||
def collect_text(events: list) -> str:
    """Accumulate full text from a stream of events."""
    # Walk backwards: the most recent text-bearing event carries the complete
    # text so far (full_text on TextEndEvent, snapshot on TextDeltaEvent).
    for candidate in reversed(events):
        if isinstance(candidate, TextEndEvent):
            return candidate.full_text
        if isinstance(candidate, TextDeltaEvent):
            return candidate.snapshot
    return ""
|
||||
|
||||
|
||||
def extract_tool_calls(events: list) -> "list[dict[str, Any]]":
    """Extract tool call info from a stream of events."""
    calls: list[dict[str, Any]] = []
    for candidate in events:
        if isinstance(candidate, ToolCallEvent):
            calls.append(
                {
                    "id": candidate.tool_use_id,
                    "name": candidate.tool_name,
                    "input": candidate.tool_input,
                }
            )
    return calls
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type-based dispatch tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTypeDispatch:
    """Dispatch on event.type string for handler routing."""

    def test_dispatch_text_delta(self):
        event = TextDeltaEvent(content="hello")
        assert dispatch_event(event) == "text:hello"

    def test_dispatch_text_end(self):
        event = TextEndEvent(full_text="hello world")
        assert dispatch_event(event) == "end:11chars"

    def test_dispatch_tool_call(self):
        event = ToolCallEvent(tool_name="web_search")
        assert dispatch_event(event) == "call:web_search"

    def test_dispatch_tool_result(self):
        event = ToolResultEvent(tool_use_id="abc")
        assert dispatch_event(event) == "result:abc"

    def test_dispatch_reasoning_start(self):
        event = ReasoningStartEvent()
        assert dispatch_event(event) == "reasoning:start"

    def test_dispatch_reasoning_delta(self):
        # Label truncates the reasoning content to its first 20 characters.
        event = ReasoningDeltaEvent(content="Let me think step by step")
        assert dispatch_event(event) == "reasoning:Let me think step by"

    def test_dispatch_finish(self):
        event = FinishEvent(stop_reason="end_turn")
        assert dispatch_event(event) == "finish:end_turn"

    def test_dispatch_error(self):
        event = StreamErrorEvent(error="timeout")
        assert dispatch_event(event) == "error:timeout"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# isinstance-based filtering
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestInstanceFiltering:
    """Filter event streams with isinstance checks, one per event type."""

    @pytest.fixture
    def text_stream(self) -> list:
        """Simulate a text-only stream."""
        return [
            TextDeltaEvent(content="Hello", snapshot="Hello"),
            TextDeltaEvent(content=" world", snapshot="Hello world"),
            TextDeltaEvent(content="!", snapshot="Hello world!"),
            TextEndEvent(full_text="Hello world!"),
            FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=3, model="test"),
        ]

    @pytest.fixture
    def tool_stream(self) -> list:
        """Simulate a tool call stream."""
        return [
            ToolCallEvent(
                tool_use_id="call_1",
                tool_name="get_weather",
                tool_input={"city": "London"},
            ),
            ToolCallEvent(
                tool_use_id="call_2",
                tool_name="calculator",
                tool_input={"expression": "2+2"},
            ),
            FinishEvent(stop_reason="tool_calls"),
        ]

    @pytest.fixture
    def reasoning_stream(self) -> list:
        """Simulate a stream with reasoning blocks."""
        return [
            ReasoningStartEvent(),
            ReasoningDeltaEvent(content="Let me analyze this..."),
            ReasoningDeltaEvent(content="The answer is 42."),
            TextDeltaEvent(content="The answer is 42.", snapshot="The answer is 42."),
            TextEndEvent(full_text="The answer is 42."),
            FinishEvent(stop_reason="end_turn"),
        ]

    def test_collect_text(self, text_stream):
        assert collect_text(text_stream) == "Hello world!"

    def test_collect_text_from_tool_stream(self, tool_stream):
        assert collect_text(tool_stream) == ""

    def test_extract_tool_calls(self, tool_stream):
        calls = extract_tool_calls(tool_stream)
        assert len(calls) == 2
        assert [entry["name"] for entry in calls] == ["get_weather", "calculator"]

    def test_extract_tool_calls_from_text_stream(self, text_stream):
        assert extract_tool_calls(text_stream) == []

    def test_filter_text_deltas(self, text_stream):
        delta_events = [evt for evt in text_stream if isinstance(evt, TextDeltaEvent)]
        assert len(delta_events) == 3

    def test_filter_finish(self, text_stream):
        finish_events = [evt for evt in text_stream if isinstance(evt, FinishEvent)]
        assert len(finish_events) == 1
        assert finish_events[0].stop_reason == "stop"

    def test_reasoning_then_text(self, reasoning_stream):
        thoughts = [
            evt for evt in reasoning_stream if isinstance(evt, ReasoningDeltaEvent)
        ]
        assert len(thoughts) == 2
        assert collect_text(reasoning_stream) == "The answer is 42."

    def test_mixed_stream_type_counts(self, reasoning_stream):
        counts: dict[str, int] = {}
        for evt in reasoning_stream:
            counts[evt.type] = counts.get(evt.type, 0) + 1
        assert counts == {
            "reasoning_start": 1,
            "reasoning_delta": 2,
            "text_delta": 1,
            "text_end": 1,
            "finish": 1,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom event extension pattern
|
||||
# ---------------------------------------------------------------------------
|
||||
@dataclass(frozen=True)
class CustomMetricsEvent:
    """Example custom event following the same pattern.

    Demonstrates the extension convention used by the standard stream
    events: a frozen dataclass with a string-literal ``type`` discriminator
    so it can flow through the same type-based dispatch/filtering code.
    """

    # Discriminator tag; distinguishes this event in mixed streams.
    type: Literal["custom_metrics"] = "custom_metrics"
    # Latency in milliseconds (per field name) — units not enforced here.
    latency_ms: float = 0.0
    # Observed streaming throughput.
    tokens_per_second: float = 0.0
    # Free-form extra data; default_factory avoids a shared mutable default.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CustomCitationEvent:
    """Example citation event extending the pattern.

    A second custom-event sample, showing that any frozen dataclass with a
    unique ``type`` literal can coexist with the standard stream events.
    """

    # Discriminator tag for dispatch/filtering.
    type: Literal["citation"] = "citation"
    # URL of the cited source.
    source_url: str = ""
    # The quoted text from the source.
    quote: str = ""
    # Confidence score; range not enforced here — presumably 0.0-1.0.
    confidence: float = 0.0
|
||||
|
||||
|
||||
class TestCustomEventExtension:
    """Custom events should follow the same frozen-dataclass convention."""

    # Single source of truth for the framework's standard event-type strings.
    # This literal set was previously duplicated verbatim inside
    # test_custom_event_in_mixed_stream, risking silent drift between copies.
    STANDARD_EVENT_TYPES = frozenset(
        {
            "text_delta",
            "text_end",
            "tool_call",
            "tool_result",
            "reasoning_start",
            "reasoning_delta",
            "finish",
            "error",
        }
    )

    def test_custom_event_construction(self):
        e = CustomMetricsEvent(latency_ms=150.5, tokens_per_second=42.3)
        assert e.type == "custom_metrics"
        assert e.latency_ms == 150.5

    def test_custom_event_frozen(self):
        # frozen=True dataclasses raise FrozenInstanceError on attribute set.
        e = CustomMetricsEvent()
        with pytest.raises(FrozenInstanceError):
            e.type = "modified"

    def test_custom_event_serialization(self):
        e = CustomMetricsEvent(
            latency_ms=100.0,
            tokens_per_second=50.0,
            metadata={"provider": "anthropic"},
        )
        d = asdict(e)
        assert d["type"] == "custom_metrics"
        assert d["metadata"] == {"provider": "anthropic"}

    def test_custom_event_dispatch(self):
        """Custom events can extend the dispatch map."""
        e = CustomMetricsEvent(latency_ms=200.0)
        # Falls through to "unknown" in our dispatch_event
        assert dispatch_event(e) == "unknown:custom_metrics"

    def test_custom_event_in_mixed_stream(self):
        """Custom events can coexist with standard events in a list."""
        stream = [
            TextDeltaEvent(content="hi", snapshot="hi"),
            CustomMetricsEvent(latency_ms=50.0),
            TextEndEvent(full_text="hi"),
            CustomCitationEvent(source_url="https://example.com", quote="hi"),
            FinishEvent(stop_reason="stop"),
        ]
        standard = [
            e
            for e in stream
            if hasattr(e, "type") and e.type in self.STANDARD_EVENT_TYPES
        ]
        custom = [e for e in stream if e.type not in self.STANDARD_EVENT_TYPES]
        assert len(standard) == 3
        assert len(custom) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization of full event sequences
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSequenceSerialization:
    """Serialize entire event sequences, as done by the dump tests."""

    @staticmethod
    def _dump(sequence: list) -> list[dict]:
        """Serialize each event to a dict with its stream index prepended."""
        return [{"index": pos, **asdict(evt)} for pos, evt in enumerate(sequence)]

    def test_serialize_text_sequence(self):
        dumped = self._dump(
            [
                TextDeltaEvent(content="Hello", snapshot="Hello"),
                TextDeltaEvent(content=" world", snapshot="Hello world"),
                TextEndEvent(full_text="Hello world"),
                FinishEvent(stop_reason="stop", model="test-model"),
            ]
        )
        assert len(dumped) == 4
        assert dumped[0]["index"] == 0
        assert dumped[0]["type"] == "text_delta"
        assert dumped[-1]["type"] == "finish"
        assert dumped[-1]["model"] == "test-model"

    def test_serialize_tool_sequence(self):
        dumped = self._dump(
            [
                ToolCallEvent(
                    tool_use_id="call_1",
                    tool_name="search",
                    tool_input={"query": "test"},
                ),
                FinishEvent(stop_reason="tool_calls"),
            ]
        )
        assert dumped[0]["tool_input"] == {"query": "test"}
        assert dumped[1]["stop_reason"] == "tool_calls"

    def test_serialize_error_sequence(self):
        dumped = self._dump(
            [
                TextDeltaEvent(content="partial"),
                StreamErrorEvent(error="connection reset", recoverable=True),
                FinishEvent(stop_reason="error"),
            ]
        )
        assert dumped[1]["type"] == "error"
        assert dumped[1]["recoverable"] is True

    def test_roundtrip_snapshot_accumulation(self):
        """Verify snapshot grows monotonically through serialization."""
        running = ""
        sequence = []
        for piece in ["Hello", " beautiful", " world", "!"]:
            running += piece
            sequence.append(TextDeltaEvent(content=piece, snapshot=running))

        dumped = [asdict(evt) for evt in sequence]
        for prev, cur in zip(dumped, dumped[1:]):
            assert len(cur["snapshot"]) > len(prev["snapshot"])
        assert dumped[-1]["snapshot"] == "Hello beautiful world!"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# WP-2: EventType Enum Extension + Node-Level Event Routing
|
||||
# ===========================================================================
|
||||
|
||||
# The 12 new EventType members added by WP-2
|
||||
# Maps each new WP-2 EventType member to the string value the enum contract
# promises; consumed by the parametrized tests in TestWP2EventTypeEnumMembers.
WP2_EVENT_TYPES = {
    # Node event-loop lifecycle
    EventType.NODE_LOOP_STARTED: "node_loop_started",
    EventType.NODE_LOOP_ITERATION: "node_loop_iteration",
    EventType.NODE_LOOP_COMPLETED: "node_loop_completed",
    # LLM streaming observability
    EventType.LLM_TEXT_DELTA: "llm_text_delta",
    EventType.LLM_REASONING_DELTA: "llm_reasoning_delta",
    # Tool lifecycle
    EventType.TOOL_CALL_STARTED: "tool_call_started",
    EventType.TOOL_CALL_COMPLETED: "tool_call_completed",
    # Client I/O
    EventType.CLIENT_OUTPUT_DELTA: "client_output_delta",
    EventType.CLIENT_INPUT_REQUESTED: "client_input_requested",
    # Internal node observability
    EventType.NODE_INTERNAL_OUTPUT: "node_internal_output",
    EventType.NODE_INPUT_BLOCKED: "node_input_blocked",
    EventType.NODE_STALLED: "node_stalled",
}
|
||||
|
||||
# Pre-existing enum members that must remain unchanged
|
||||
# Pre-existing enum members that must remain unchanged by WP-2; used both
# to re-verify their values and to detect value collisions with the new
# WP-2 members above.
ORIGINAL_EVENT_TYPES = {
    EventType.EXECUTION_STARTED: "execution_started",
    EventType.EXECUTION_COMPLETED: "execution_completed",
    EventType.EXECUTION_FAILED: "execution_failed",
    EventType.EXECUTION_PAUSED: "execution_paused",
    EventType.EXECUTION_RESUMED: "execution_resumed",
    EventType.STATE_CHANGED: "state_changed",
    EventType.STATE_CONFLICT: "state_conflict",
    EventType.GOAL_PROGRESS: "goal_progress",
    EventType.GOAL_ACHIEVED: "goal_achieved",
    EventType.CONSTRAINT_VIOLATION: "constraint_violation",
    EventType.STREAM_STARTED: "stream_started",
    EventType.STREAM_STOPPED: "stream_stopped",
    EventType.CUSTOM: "custom",
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2 Part A: EventType enum members
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2EventTypeEnumMembers:
    """All 12 new EventType members exist with correct string values."""

    @pytest.mark.parametrize(
        "member,expected_value",
        WP2_EVENT_TYPES.items(),
        ids=lambda item: item.name if isinstance(item, EventType) else item,
    )
    def test_new_member_value(self, member, expected_value):
        assert member.value == expected_value

    def test_all_12_new_members_exist(self):
        assert len(WP2_EVENT_TYPES) == 12

    def test_new_member_string_values_are_unique(self):
        string_values = list(WP2_EVENT_TYPES.values())
        assert len(string_values) == len(set(string_values))

    def test_no_collision_with_original_members(self):
        overlap = set(WP2_EVENT_TYPES.values()) & set(ORIGINAL_EVENT_TYPES.values())
        assert overlap == set(), f"Colliding values: {overlap}"

    @pytest.mark.parametrize(
        "member,expected_value",
        ORIGINAL_EVENT_TYPES.items(),
        ids=lambda item: item.name if isinstance(item, EventType) else item,
    )
    def test_original_members_unchanged(self, member, expected_value):
        assert member.value == expected_value

    def test_event_type_is_str_enum(self):
        """EventType members compare equal to their string values."""
        assert EventType.NODE_LOOP_STARTED == "node_loop_started"
        assert EventType.LLM_TEXT_DELTA == "llm_text_delta"
        assert EventType.LLM_TEXT_DELTA.value == "llm_text_delta"

    def test_event_type_accessible_by_name(self):
        assert EventType["NODE_LOOP_STARTED"] is EventType.NODE_LOOP_STARTED
        assert EventType["TOOL_CALL_COMPLETED"] is EventType.TOOL_CALL_COMPLETED

    def test_event_type_accessible_by_value(self):
        assert EventType("node_loop_started") is EventType.NODE_LOOP_STARTED
        assert EventType("tool_call_completed") is EventType.TOOL_CALL_COMPLETED
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2 Part B: AgentEvent.node_id and Subscription.filter_node
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2AgentEventNodeId:
    """AgentEvent carries node_id as a first-class field."""

    def test_node_id_defaults_to_none(self):
        evt = AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream-1")
        assert evt.node_id is None

    def test_node_id_can_be_set(self):
        evt = AgentEvent(
            type=EventType.LLM_TEXT_DELTA,
            stream_id="stream-1",
            node_id="email_composer",
        )
        assert evt.node_id == "email_composer"

    def test_node_id_in_to_dict(self):
        payload = AgentEvent(
            type=EventType.TOOL_CALL_STARTED,
            stream_id="stream-1",
            node_id="search_node",
        ).to_dict()
        assert payload["node_id"] == "search_node"

    def test_node_id_none_in_to_dict(self):
        payload = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="stream-1",
        ).to_dict()
        assert "node_id" in payload
        assert payload["node_id"] is None
|
||||
|
||||
|
||||
class TestWP2SubscriptionFilterNode:
    """Subscription supports filter_node for node-level routing."""

    @staticmethod
    async def _noop_handler(event: AgentEvent) -> None:
        pass

    def test_filter_node_defaults_to_none(self):
        subscription = Subscription(
            id="sub_1",
            event_types={EventType.LLM_TEXT_DELTA},
            handler=self._noop_handler,
        )
        assert subscription.filter_node is None

    def test_filter_node_can_be_set(self):
        subscription = Subscription(
            id="sub_1",
            event_types={EventType.LLM_TEXT_DELTA},
            handler=self._noop_handler,
            filter_node="email_composer",
        )
        assert subscription.filter_node == "email_composer"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2 Part B: Node-level event routing integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2NodeLevelRouting:
    """EventBus routes events by node_id using filter_node."""

    @pytest.fixture
    def bus(self):
        return EventBus()

    @staticmethod
    def _make_event(etype, stream_id, node_id, data=None):
        """Build an AgentEvent, passing data only when it was supplied."""
        kwargs = {"type": etype, "stream_id": stream_id, "node_id": node_id}
        if data is not None:
            kwargs["data"] = data
        return AgentEvent(**kwargs)

    @staticmethod
    def _capture(bus, event_types, **filters):
        """Subscribe a recording handler; return the list it appends to."""
        seen = []

        async def record(event):
            seen.append(event)

        bus.subscribe(event_types=event_types, handler=record, **filters)
        return seen

    @pytest.mark.asyncio
    async def test_filter_node_receives_matching_events(self, bus):
        """Subscriber with filter_node='node-A' receives events from node-A."""
        seen = self._capture(bus, [EventType.LLM_TEXT_DELTA], filter_node="node-A")

        await bus.publish(
            self._make_event(
                EventType.LLM_TEXT_DELTA,
                "stream-1",
                "node-A",
                data={"content": "hello"},
            )
        )

        assert len(seen) == 1
        assert seen[0].node_id == "node-A"
        assert seen[0].data["content"] == "hello"

    @pytest.mark.asyncio
    async def test_filter_node_rejects_non_matching_events(self, bus):
        """Subscriber with filter_node='node-B' does NOT receive node-A events."""
        seen = self._capture(bus, [EventType.LLM_TEXT_DELTA], filter_node="node-B")

        await bus.publish(
            self._make_event(
                EventType.LLM_TEXT_DELTA,
                "stream-1",
                "node-A",
                data={"content": "hello"},
            )
        )

        assert len(seen) == 0

    @pytest.mark.asyncio
    async def test_no_filter_node_receives_all_events(self, bus):
        """Subscriber with no filter_node receives events from all nodes."""
        seen = self._capture(bus, [EventType.LLM_TEXT_DELTA])

        # Events from two named nodes plus one with no node at all.
        for origin in ("node-A", "node-B", None):
            await bus.publish(
                self._make_event(EventType.LLM_TEXT_DELTA, "stream-1", origin)
            )

        assert len(seen) == 3

    @pytest.mark.asyncio
    async def test_interleaved_nodes_separated_by_filter(self, bus):
        """Two subscribers on different nodes get only their node's events."""
        sender_seen = self._capture(
            bus, [EventType.LLM_TEXT_DELTA], filter_node="email_sender"
        )
        scanner_seen = self._capture(
            bus, [EventType.LLM_TEXT_DELTA], filter_node="inbox_scanner"
        )

        # Interleaved events from both nodes
        for origin, text in (
            ("email_sender", "Dear Jo"),
            ("inbox_scanner", "RE: Meeting conf"),
            ("email_sender", "hn, Thank you for"),
            ("inbox_scanner", "irmed for Thursday"),
        ):
            await bus.publish(
                self._make_event(
                    EventType.LLM_TEXT_DELTA,
                    "webhook",
                    origin,
                    data={"content": text},
                )
            )

        assert len(sender_seen) == 2
        assert len(scanner_seen) == 2
        assert sender_seen[0].data["content"] == "Dear Jo"
        assert sender_seen[1].data["content"] == "hn, Thank you for"
        assert scanner_seen[0].data["content"] == "RE: Meeting conf"
        assert scanner_seen[1].data["content"] == "irmed for Thursday"

    @pytest.mark.asyncio
    async def test_filter_node_combined_with_filter_stream(self, bus):
        """filter_node and filter_stream work together."""
        seen = self._capture(
            bus,
            [EventType.TOOL_CALL_STARTED],
            filter_stream="webhook",
            filter_node="search_node",
        )

        # Matching both filters, then wrong stream, then wrong node.
        for stream, node in (
            ("webhook", "search_node"),
            ("api", "search_node"),
            ("webhook", "other_node"),
        ):
            await bus.publish(
                self._make_event(EventType.TOOL_CALL_STARTED, stream, node)
            )

        assert len(seen) == 1
        assert seen[0].stream_id == "webhook"
        assert seen[0].node_id == "search_node"

    @pytest.mark.asyncio
    async def test_wait_for_with_node_id(self, bus):
        """wait_for() accepts node_id parameter for filtering."""

        async def publish_later():
            await asyncio.sleep(0.01)
            await bus.publish(
                self._make_event(
                    EventType.NODE_LOOP_COMPLETED,
                    "stream-1",
                    "target_node",
                    data={"iterations": 3},
                )
            )

        pending = asyncio.create_task(publish_later())
        evt = await bus.wait_for(
            event_type=EventType.NODE_LOOP_COMPLETED,
            node_id="target_node",
            timeout=2.0,
        )
        await pending

        assert evt is not None
        assert evt.node_id == "target_node"
        assert evt.data["iterations"] == 3

    @pytest.mark.asyncio
    async def test_wait_for_ignores_wrong_node(self, bus):
        """wait_for() with node_id ignores events from other nodes."""

        async def publish_wrong_then_right():
            await asyncio.sleep(0.01)
            # Wrong node — should be ignored
            await bus.publish(
                self._make_event(
                    EventType.NODE_LOOP_COMPLETED, "stream-1", "wrong_node"
                )
            )
            await asyncio.sleep(0.01)
            # Right node
            await bus.publish(
                self._make_event(
                    EventType.NODE_LOOP_COMPLETED,
                    "stream-1",
                    "target_node",
                    data={"iterations": 5},
                )
            )

        pending = asyncio.create_task(publish_wrong_then_right())
        evt = await bus.wait_for(
            event_type=EventType.NODE_LOOP_COMPLETED,
            node_id="target_node",
            timeout=2.0,
        )
        await pending

        assert evt is not None
        assert evt.node_id == "target_node"
        assert evt.data["iterations"] == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2: Convenience publisher methods
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2ConveniencePublishers:
    """EventBus convenience methods for new WP-2 event types."""

    @pytest.fixture
    def bus(self):
        return EventBus()

    @staticmethod
    def _capture(bus, event_types, **filters):
        """Subscribe a recording handler; return the list it appends to."""
        seen = []

        async def record(event):
            seen.append(event)

        bus.subscribe(event_types=event_types, handler=record, **filters)
        return seen

    @pytest.mark.asyncio
    async def test_emit_node_loop_started(self, bus):
        seen = self._capture(bus, [EventType.NODE_LOOP_STARTED])
        await bus.emit_node_loop_started(
            stream_id="s1", node_id="n1", max_iterations=10
        )

        assert len(seen) == 1
        assert seen[0].node_id == "n1"
        assert seen[0].data["max_iterations"] == 10

    @pytest.mark.asyncio
    async def test_emit_node_loop_iteration(self, bus):
        seen = self._capture(bus, [EventType.NODE_LOOP_ITERATION])
        await bus.emit_node_loop_iteration(stream_id="s1", node_id="n1", iteration=3)

        assert len(seen) == 1
        assert seen[0].data["iteration"] == 3

    @pytest.mark.asyncio
    async def test_emit_node_loop_completed(self, bus):
        seen = self._capture(bus, [EventType.NODE_LOOP_COMPLETED])
        await bus.emit_node_loop_completed(stream_id="s1", node_id="n1", iterations=5)

        assert len(seen) == 1
        assert seen[0].data["iterations"] == 5

    @pytest.mark.asyncio
    async def test_emit_llm_text_delta(self, bus):
        seen = self._capture(bus, [EventType.LLM_TEXT_DELTA])
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="n1",
            content="hello",
            snapshot="hello world",
        )

        assert len(seen) == 1
        assert seen[0].data["content"] == "hello"
        assert seen[0].data["snapshot"] == "hello world"

    @pytest.mark.asyncio
    async def test_emit_tool_call_started(self, bus):
        seen = self._capture(bus, [EventType.TOOL_CALL_STARTED])
        await bus.emit_tool_call_started(
            stream_id="s1",
            node_id="n1",
            tool_use_id="call_1",
            tool_name="web_search",
            tool_input={"query": "test"},
        )

        assert len(seen) == 1
        assert seen[0].data["tool_name"] == "web_search"
        assert seen[0].data["tool_input"] == {"query": "test"}

    @pytest.mark.asyncio
    async def test_emit_tool_call_completed(self, bus):
        seen = self._capture(bus, [EventType.TOOL_CALL_COMPLETED])
        await bus.emit_tool_call_completed(
            stream_id="s1",
            node_id="n1",
            tool_use_id="call_1",
            tool_name="web_search",
            result="3 results found",
        )

        assert len(seen) == 1
        assert seen[0].data["result"] == "3 results found"
        assert seen[0].data["is_error"] is False

    @pytest.mark.asyncio
    async def test_emit_client_output_delta(self, bus):
        seen = self._capture(bus, [EventType.CLIENT_OUTPUT_DELTA])
        await bus.emit_client_output_delta(
            stream_id="s1",
            node_id="n1",
            content="chunk",
            snapshot="full chunk",
        )

        assert len(seen) == 1
        assert seen[0].data["content"] == "chunk"

    @pytest.mark.asyncio
    async def test_emit_node_stalled(self, bus):
        seen = self._capture(bus, [EventType.NODE_STALLED])
        await bus.emit_node_stalled(
            stream_id="s1",
            node_id="n1",
            reason="no progress after 10 iterations",
        )

        assert len(seen) == 1
        assert seen[0].data["reason"] == "no progress after 10 iterations"

    @pytest.mark.asyncio
    async def test_convenience_publishers_set_node_id(self, bus):
        """All WP-2 convenience publishers set node_id on the emitted event."""
        seen = self._capture(
            bus,
            [EventType.LLM_TEXT_DELTA, EventType.TOOL_CALL_STARTED],
            filter_node="my_node",
        )

        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="my_node",
            content="hi",
            snapshot="hi",
        )
        await bus.emit_tool_call_started(
            stream_id="s1",
            node_id="my_node",
            tool_use_id="c1",
            tool_name="calc",
        )
        # Wrong node — should not be received
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="other_node",
            content="bye",
            snapshot="bye",
        )

        assert len(seen) == 2
        assert all(evt.node_id == "my_node" for evt in seen)
|
||||
@@ -0,0 +1,389 @@
|
||||
"""Real-API streaming tests for LiteLLM provider.
|
||||
|
||||
Calls live LLM APIs and dumps stream events to JSON files for review.
|
||||
Results are saved to core/tests/stream_event_dumps/{provider}_{model}_{scenario}.json
|
||||
|
||||
Run with:
|
||||
cd core && python -m pytest tests/test_litellm_streaming.py -v -s -k "RealAPI"
|
||||
|
||||
Requires API keys set in environment:
|
||||
ANTHROPIC_API_KEY, OPENAI_API_KEY, GEMINI_API_KEY (or via credential store)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
from framework.llm.provider import Tool
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DUMP_DIR = Path(__file__).parent / "stream_event_dumps"
|
||||
|
||||
|
||||
def _serialize_event(index: int, event: StreamEvent) -> dict:
|
||||
"""Serialize a StreamEvent to a JSON-safe dict."""
|
||||
d = asdict(event) # type: ignore[arg-type]
|
||||
d["index"] = index
|
||||
# Move index to front for readability
|
||||
return {"index": index, **{k: v for k, v in d.items() if k != "index"}}
|
||||
|
||||
|
||||
def _dump_events(events: list[StreamEvent], filename: str) -> Path:
    """Persist a stream-event sequence as pretty-printed JSON.

    Creates DUMP_DIR if needed and writes one JSON array entry per event,
    each tagged with its stream index. Returns the path written.
    """
    DUMP_DIR.mkdir(parents=True, exist_ok=True)
    target = DUMP_DIR / filename
    payload = [_serialize_event(pos, evt) for pos, evt in enumerate(events)]
    target.write_text(json.dumps(payload, indent=2) + "\n")
    logger.info(f"Dumped {len(events)} events to {target}")
    return target
|
||||
|
||||
|
||||
async def _collect_stream(provider: LiteLLMProvider, **kwargs) -> list[StreamEvent]:
    """Drain provider.stream() into a list, logging each event on arrival."""
    collected: list[StreamEvent] = []
    async for evt in provider.stream(**kwargs):
        collected.append(evt)
        # Log each event type as it arrives
        logger.debug(f"  [{len(collected) - 1}] {evt.type}: {evt}")
    return collected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test matrix: (model_id, dump_prefix, env_var_for_skip)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Each entry: (model id passed to LiteLLM, filename-safe dump prefix,
# environment variable holding the provider's API key — used to skip
# tests when no key is available).
MODELS = [
    (
        "anthropic/claude-haiku-4-5-20251001",
        "anthropic_claude-haiku-4-5-20251001",
        "ANTHROPIC_API_KEY",
    ),
    ("gpt-4.1-nano", "gpt-4.1-nano", "OPENAI_API_KEY"),
    ("gemini/gemini-2.0-flash", "gemini_gemini-2.0-flash", "GEMINI_API_KEY"),
]
|
||||
|
||||
# Minimal tool with a single required parameter; used to trigger
# tool-call streaming against the live APIs.
WEATHER_TOOL = Tool(
    name="get_weather",
    description="Get the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {
            "city": {
                "type": "string",
                "description": "City name, e.g. 'Tokyo'",
            }
        },
        "required": ["city"],
    },
)
|
||||
|
||||
# Tool with one required and one optional parameter; exercises
# partially-specified tool inputs in streamed tool calls.
SEARCH_TOOL = Tool(
    name="web_search",
    description="Search the web for information.",
    parameters={
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query",
            },
            "num_results": {
                "type": "integer",
                "description": "Number of results to return (1-10)",
            },
        },
        "required": ["query"],
    },
)
|
||||
|
||||
# Simple single-parameter tool; a second option for multi-tool scenarios.
CALCULATOR_TOOL = Tool(
    name="calculator",
    description="Perform arithmetic calculations.",
    parameters={
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "Math expression to evaluate, e.g. '2 + 2'",
            }
        },
        "required": ["expression"],
    },
)
|
||||
|
||||
|
||||
def _has_api_key(env_var: str) -> bool:
|
||||
"""Check if an API key is available (env var or credential store)."""
|
||||
if os.environ.get(env_var):
|
||||
return True
|
||||
# Try credential store
|
||||
try:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
creds = CredentialStoreAdapter.with_env_storage()
|
||||
provider_name = env_var.replace("_API_KEY", "").lower()
|
||||
return creds.is_available(provider_name)
|
||||
except (ImportError, Exception):
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real API tests — text streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRealAPITextStreaming:
    """Stream a simple text response from each provider and dump events."""

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_text_stream(self, model: str, prefix: str, env_var: str):
        """Stream a multi-paragraph response to exercise chunked delivery."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")

        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "Explain in 3 numbered paragraphs how a CPU executes an instruction. "
                        "Cover fetch, decode, and execute stages. Be concise but thorough."
                    ),
                }
            ],
            system="You are a computer science teacher. Give clear, structured explanations.",
            max_tokens=512,
        )

        # Persist the raw event sequence for offline inspection.
        _dump_events(events, f"{prefix}_text.json")

        assert len(events) >= 4, f"Expected at least 4 events, got {len(events)}"

        # A longer answer must arrive as several incremental deltas.
        deltas = [evt for evt in events if isinstance(evt, TextDeltaEvent)]
        assert len(deltas) >= 3, f"Expected 3+ TextDeltaEvents, got {len(deltas)}"

        # Every successive snapshot must strictly extend the previous one.
        for i, (prev, curr) in enumerate(zip(deltas, deltas[1:]), start=1):
            assert len(curr.snapshot) > len(prev.snapshot), f"Snapshot did not grow at index {i}"

        # Exactly one terminal text event closes the text block...
        ends = [evt for evt in events if isinstance(evt, TextEndEvent)]
        assert len(ends) == 1, f"Expected 1 TextEndEvent, got {len(ends)}"

        # ...followed by exactly one finish event with a normal stop reason.
        finishes = [evt for evt in events if isinstance(evt, FinishEvent)]
        assert len(finishes) == 1, f"Expected 1 FinishEvent, got {len(finishes)}"
        assert finishes[0].stop_reason in ("stop", "end_turn")

        # The end event's full text must agree with the last snapshot.
        assert ends[0].full_text == deltas[-1].snapshot

        # Sanity-check that the model produced substantial content.
        full_text = ends[0].full_text
        assert len(full_text) > 200, f"Response too short ({len(full_text)} chars)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real API tests — tool call streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRealAPIToolCallStreaming:
    """Stream a tool call response from each provider and dump events."""

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_tool_call_stream(self, model: str, prefix: str, env_var: str):
        """Stream a single tool call with complex arguments."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")

        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": "Search the web for 'Python 3.13 release notes'.",
                }
            ],
            system="You have access to tools. Use the appropriate tool.",
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            max_tokens=512,
        )

        # Persist the raw event sequence for offline inspection.
        _dump_events(events, f"{prefix}_tool_call.json")

        assert len(events) >= 2, f"Expected at least 2 events, got {len(events)}"

        # The model must actually invoke the search tool.
        calls = [evt for evt in events if isinstance(evt, ToolCallEvent)]
        assert len(calls) >= 1, "No ToolCallEvent received"

        first_call = calls[0]
        assert first_call.tool_name == "web_search"
        assert "query" in first_call.tool_input
        assert first_call.tool_use_id != ""

        # Stream must terminate with a single finish event.
        finishes = [evt for evt in events if isinstance(evt, FinishEvent)]
        assert len(finishes) == 1
        assert finishes[0].stop_reason in ("tool_calls", "tool_use", "stop")

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_multi_tool_call_stream(self, model: str, prefix: str, env_var: str):
        """Stream a response that should invoke multiple tool calls."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")

        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "I need three things done in parallel: "
                        "1) Get the weather in London, "
                        "2) Get the weather in New York, "
                        "3) Calculate 1337 * 42. "
                        "Use the tools for all three."
                    ),
                }
            ],
            system=(
                "You have access to tools. When the user asks for multiple things, "
                "call all the needed tools. Always use tools, never guess results."
            ),
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            max_tokens=512,
        )

        # Persist the raw event sequence for offline inspection.
        _dump_events(events, f"{prefix}_multi_tool.json")

        # Parallel requests should produce several tool invocations.
        calls = [evt for evt in events if isinstance(evt, ToolCallEvent)]
        assert (
            len(calls) >= 2
        ), f"Expected 2+ ToolCallEvents for parallel requests, got {len(calls)}"

        # At least the weather tool must have been chosen.
        invoked = {call.tool_name for call in calls}
        assert "get_weather" in invoked, "Expected get_weather tool call"

        # Every call must carry an id and non-empty arguments.
        for call in calls:
            assert call.tool_use_id != "", f"Empty tool_use_id on {call.tool_name}"
            assert call.tool_input, f"Empty tool_input on {call.tool_name}"

        # Stream must terminate with a single finish event.
        finishes = [evt for evt in events if isinstance(evt, FinishEvent)]
        assert len(finishes) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience runner for manual invocation
|
||||
# ---------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
"""Run all streaming tests and dump results. Usage: python tests/test_litellm_streaming.py"""
|
||||
|
||||
ALL_TOOLS = [WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL]
|
||||
|
||||
async def _run_all():
|
||||
for model, prefix, env_var in MODELS:
|
||||
if not _has_api_key(env_var):
|
||||
print(f"SKIP {prefix}: {env_var} not set")
|
||||
continue
|
||||
|
||||
provider = LiteLLMProvider(model=model)
|
||||
|
||||
# Text streaming (multi-paragraph)
|
||||
print(f"\n--- {prefix} text ---")
|
||||
events = await _collect_stream(
|
||||
provider,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Explain in 3 numbered paragraphs how a CPU executes an instruction. "
|
||||
"Cover fetch, decode, and execute stages. Be concise but thorough."
|
||||
),
|
||||
}
|
||||
],
|
||||
system="You are a computer science teacher. Give clear, structured explanations.",
|
||||
max_tokens=512,
|
||||
)
|
||||
path = _dump_events(events, f"{prefix}_text.json")
|
||||
print(f" {len(events)} events -> {path}")
|
||||
for i, e in enumerate(events):
|
||||
print(f" [{i}] {e.type}: {e}")
|
||||
|
||||
# Tool call streaming
|
||||
print(f"\n--- {prefix} tool_call ---")
|
||||
events = await _collect_stream(
|
||||
provider,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Search the web for 'Python 3.13 release notes'.",
|
||||
}
|
||||
],
|
||||
system="You have access to tools. Use the appropriate tool.",
|
||||
tools=ALL_TOOLS,
|
||||
max_tokens=512,
|
||||
)
|
||||
path = _dump_events(events, f"{prefix}_tool_call.json")
|
||||
print(f" {len(events)} events -> {path}")
|
||||
for i, e in enumerate(events):
|
||||
print(f" [{i}] {e.type}: {e}")
|
||||
|
||||
# Multi-tool call streaming
|
||||
print(f"\n--- {prefix} multi_tool ---")
|
||||
events = await _collect_stream(
|
||||
provider,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"I need three things done in parallel: "
|
||||
"1) Get the weather in London, "
|
||||
"2) Get the weather in New York, "
|
||||
"3) Calculate 1337 * 42. "
|
||||
"Use the tools for all three."
|
||||
),
|
||||
}
|
||||
],
|
||||
system=(
|
||||
"You have access to tools. When the user asks for multiple things, "
|
||||
"call all the needed tools. Always use tools, never guess results."
|
||||
),
|
||||
tools=ALL_TOOLS,
|
||||
max_tokens=512,
|
||||
)
|
||||
path = _dump_events(events, f"{prefix}_multi_tool.json")
|
||||
print(f" {len(events)} events -> {path}")
|
||||
for i, e in enumerate(events):
|
||||
print(f" [{i}] {e.type}: {e}")
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
asyncio.run(_run_all())
|
||||
@@ -0,0 +1,318 @@
|
||||
"""Tests for stream event dataclasses.
|
||||
|
||||
Validates construction, defaults, immutability, serialization, and the
|
||||
StreamEvent discriminated union type.
|
||||
"""
|
||||
|
||||
from dataclasses import FrozenInstanceError, asdict, fields
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
ReasoningDeltaEvent,
|
||||
ReasoningStartEvent,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
|
||||
# All concrete event classes in the union
|
||||
ALL_EVENT_CLASSES = [
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
ReasoningStartEvent,
|
||||
ReasoningDeltaEvent,
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Construction & defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventDefaults:
    """Each event class should be constructible with zero arguments."""

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_default_construction(self, cls):
        # A zero-argument instance must still carry a non-empty discriminator.
        assert cls().type != ""

    def test_text_delta_defaults(self):
        event = TextDeltaEvent()
        assert event.type == "text_delta"
        assert event.content == ""
        assert event.snapshot == ""

    def test_text_end_defaults(self):
        event = TextEndEvent()
        assert event.type == "text_end"
        assert event.full_text == ""

    def test_tool_call_defaults(self):
        event = ToolCallEvent()
        assert event.type == "tool_call"
        assert event.tool_use_id == ""
        assert event.tool_name == ""
        assert event.tool_input == {}

    def test_tool_result_defaults(self):
        event = ToolResultEvent()
        assert event.type == "tool_result"
        assert event.tool_use_id == ""
        assert event.content == ""
        assert event.is_error is False

    def test_reasoning_start_defaults(self):
        assert ReasoningStartEvent().type == "reasoning_start"

    def test_reasoning_delta_defaults(self):
        event = ReasoningDeltaEvent()
        assert event.type == "reasoning_delta"
        assert event.content == ""

    def test_finish_defaults(self):
        event = FinishEvent()
        assert event.type == "finish"
        assert event.stop_reason == ""
        assert event.input_tokens == 0
        assert event.output_tokens == 0
        assert event.model == ""

    def test_stream_error_defaults(self):
        event = StreamErrorEvent()
        assert event.type == "error"
        assert event.error == ""
        assert event.recoverable is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Construction with values
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventConstruction:
    """Events should store provided field values correctly."""

    def test_text_delta_with_values(self):
        event = TextDeltaEvent(content="hello", snapshot="hello world")
        assert event.content == "hello"
        assert event.snapshot == "hello world"

    def test_text_end_with_values(self):
        assert TextEndEvent(full_text="the complete response").full_text == "the complete response"

    def test_tool_call_with_values(self):
        payload = {"query": "python", "num_results": 5}
        event = ToolCallEvent(
            tool_use_id="call_abc123",
            tool_name="web_search",
            tool_input=payload,
        )
        assert event.tool_use_id == "call_abc123"
        assert event.tool_name == "web_search"
        assert event.tool_input == payload

    def test_tool_result_with_values(self):
        event = ToolResultEvent(
            tool_use_id="call_abc123",
            content="search results here",
            is_error=False,
        )
        assert event.tool_use_id == "call_abc123"
        assert event.content == "search results here"
        assert event.is_error is False

    def test_tool_result_error(self):
        event = ToolResultEvent(tool_use_id="call_fail", content="timeout", is_error=True)
        assert event.is_error is True

    def test_reasoning_delta_with_content(self):
        assert ReasoningDeltaEvent(content="Let me think about this...").content == "Let me think about this..."

    def test_finish_with_values(self):
        event = FinishEvent(
            stop_reason="end_turn",
            input_tokens=150,
            output_tokens=300,
            model="claude-haiku-4-5",
        )
        assert event.stop_reason == "end_turn"
        assert event.input_tokens == 150
        assert event.output_tokens == 300
        assert event.model == "claude-haiku-4-5"

    def test_stream_error_with_values(self):
        event = StreamErrorEvent(error="rate limit exceeded", recoverable=True)
        assert event.error == "rate limit exceeded"
        assert event.recoverable is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frozen immutability
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventImmutability:
    """All events are frozen dataclasses — fields cannot be reassigned."""

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_frozen(self, cls):
        # Reassigning any field on a frozen dataclass must raise.
        instance = cls()
        with pytest.raises(FrozenInstanceError):
            instance.type = "modified"

    def test_text_delta_frozen_content(self):
        instance = TextDeltaEvent(content="hello")
        with pytest.raises(FrozenInstanceError):
            instance.content = "modified"

    def test_tool_call_frozen_input(self):
        # The dict-valued field reference itself is also frozen.
        instance = ToolCallEvent(tool_input={"key": "value"})
        with pytest.raises(FrozenInstanceError):
            instance.tool_input = {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type literal values
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTypeLiterals:
    """Each event's `type` field should match its Literal annotation."""

    # Expected discriminator value per event class.
    EXPECTED_TYPES = {
        TextDeltaEvent: "text_delta",
        TextEndEvent: "text_end",
        ToolCallEvent: "tool_call",
        ToolResultEvent: "tool_result",
        ReasoningStartEvent: "reasoning_start",
        ReasoningDeltaEvent: "reasoning_delta",
        FinishEvent: "finish",
        StreamErrorEvent: "error",
    }

    @pytest.mark.parametrize(
        "cls,expected_type",
        EXPECTED_TYPES.items(),
        ids=lambda x: x.__name__ if isinstance(x, type) else x,
    )
    def test_type_value(self, cls, expected_type):
        assert cls().type == expected_type

    def test_all_types_unique(self):
        # No two event classes may share a discriminator value.
        types = [cls().type for cls in ALL_EVENT_CLASSES]
        assert len(set(types)) == len(types), f"Duplicate type values: {types}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization via dataclasses.asdict
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventSerialization:
    """Events should round-trip through asdict for JSON serialization."""

    def test_text_delta_asdict(self):
        dumped = asdict(TextDeltaEvent(content="chunk", snapshot="full chunk"))
        assert dumped == {"type": "text_delta", "content": "chunk", "snapshot": "full chunk"}

    def test_tool_call_asdict(self):
        dumped = asdict(
            ToolCallEvent(
                tool_use_id="id_1",
                tool_name="calc",
                tool_input={"expression": "2+2"},
            )
        )
        assert dumped["tool_name"] == "calc"
        assert dumped["tool_input"] == {"expression": "2+2"}

    def test_finish_asdict(self):
        event = FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=20, model="gpt-4")
        expected = {
            "type": "finish",
            "stop_reason": "stop",
            "input_tokens": 10,
            "output_tokens": 20,
            "model": "gpt-4",
        }
        assert asdict(event) == expected

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_asdict_contains_type(self, cls):
        assert "type" in asdict(cls())

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_asdict_keys_match_fields(self, cls):
        # asdict output must expose exactly the declared dataclass fields.
        dumped = asdict(cls())
        declared = {f.name for f in fields(cls)}
        assert set(dumped.keys()) == declared
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StreamEvent union type
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestStreamEventUnion:
    """The StreamEvent union should include all event classes."""

    def test_union_contains_all_classes(self):
        # StreamEvent is a PEP 604 union (X | Y | Z); members live in __args__.
        members = StreamEvent.__args__  # type: ignore[attr-defined]
        for cls in ALL_EVENT_CLASSES:
            assert cls in members, f"{cls.__name__} not in StreamEvent union"

    def test_union_has_exactly_expected_members(self):
        # Neither extra nor missing members are allowed.
        members = set(StreamEvent.__args__)  # type: ignore[attr-defined]
        assert members == set(ALL_EVENT_CLASSES)

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_isinstance_check(self, cls):
        """Each event instance should be an instance of its class (basic sanity)."""
        assert isinstance(cls(), cls)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Equality & hashing (frozen dataclasses support both)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventEquality:
    """Frozen dataclasses support equality and hashing."""

    def test_equal_events(self):
        assert TextDeltaEvent(content="hi", snapshot="hi") == TextDeltaEvent(content="hi", snapshot="hi")

    def test_unequal_events(self):
        assert TextDeltaEvent(content="hi") != TextDeltaEvent(content="bye")

    def test_different_types_not_equal(self):
        # Same field values, different classes -> never equal.
        assert TextDeltaEvent(content="hi") != ReasoningDeltaEvent(content="hi")

    def test_hashable(self):
        event = FinishEvent(stop_reason="stop", model="gpt-4")
        bucket = {event}  # should be hashable since frozen
        assert event in bucket

    def test_equal_events_same_hash(self):
        first = FinishEvent(stop_reason="stop", model="gpt-4")
        second = FinishEvent(stop_reason="stop", model="gpt-4")
        assert hash(first) == hash(second)

    def test_events_with_dict_not_hashable(self):
        """Events containing dict fields (e.g. tool_input) are not hashable."""
        event = ToolCallEvent(tool_use_id="x", tool_name="y", tool_input={"key": "val"})
        with pytest.raises(TypeError, match="unhashable type"):
            hash(event)
|
||||
Reference in New Issue
Block a user