From 22e816bf86b38709d8d0d21859b893268097de64 Mon Sep 17 00:00:00 2001 From: Timothy Date: Mon, 2 Feb 2026 10:30:03 -0800 Subject: [PATCH] chore: update gitignore --- .gitignore | 2 +- core/tests/test_event_loop_node.py | 746 ++++++++++++++++++ core/tests/test_event_type_extension.py | 978 ++++++++++++++++++++++++ core/tests/test_litellm_streaming.py | 389 ++++++++++ core/tests/test_stream_events.py | 318 ++++++++ 5 files changed, 2432 insertions(+), 1 deletion(-) create mode 100644 core/tests/test_event_loop_node.py create mode 100644 core/tests/test_event_type_extension.py create mode 100644 core/tests/test_litellm_streaming.py create mode 100644 core/tests/test_stream_events.py diff --git a/.gitignore b/.gitignore index 53c37fc3..9e74ef3c 100644 --- a/.gitignore +++ b/.gitignore @@ -72,4 +72,4 @@ exports/* .venv docs/github-issues/* -core/tests/* +core/tests/*dumps/* diff --git a/core/tests/test_event_loop_node.py b/core/tests/test_event_loop_node.py new file mode 100644 index 00000000..5b65f316 --- /dev/null +++ b/core/tests/test_event_loop_node.py @@ -0,0 +1,746 @@ +"""WP-8: Tests for EventLoopNode, OutputAccumulator, LoopConfig, JudgeProtocol. + +Uses real FileConversationStore (no mocks for storage) and a MockStreamingLLM +that yields pre-programmed StreamEvents to control the loop deterministically. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from framework.graph.conversation import NodeConversation +from framework.graph.event_loop_node import ( + EventLoopNode, + JudgeProtocol, + JudgeVerdict, + LoopConfig, + OutputAccumulator, +) +from framework.graph.node import NodeContext, NodeProtocol, NodeSpec, SharedMemory +from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse +from framework.llm.stream_events import ( + FinishEvent, + StreamErrorEvent, + TextDeltaEvent, + ToolCallEvent, +) +from framework.runtime.core import Runtime +from framework.runtime.event_bus import EventBus, EventType +from framework.storage.conversation_store import FileConversationStore + +# --------------------------------------------------------------------------- +# Mock LLM that yields pre-programmed stream events +# --------------------------------------------------------------------------- + + +class MockStreamingLLM(LLMProvider): + """Mock LLM that yields pre-programmed StreamEvent sequences. + + Each call to stream() consumes the next scenario from the list. + Cycles back to the beginning if more calls are made than scenarios. + """ + + def __init__(self, scenarios: list[list] | None = None): + self.scenarios = scenarios or [] + self._call_index = 0 + self.stream_calls: list[dict] = [] + + async def stream( + self, + messages: list[dict[str, Any]], + system: str = "", + tools: list[Tool] | None = None, + max_tokens: int = 4096, + ) -> AsyncIterator: + self.stream_calls.append({"messages": messages, "system": system, "tools": tools}) + if not self.scenarios: + return + events = self.scenarios[self._call_index % len(self.scenarios)] + self._call_index += 1 + for event in events: + yield event + + def complete(self, messages, system="", **kwargs) -> LLMResponse: + return LLMResponse(content="Summary of conversation.", model="mock", stop_reason="stop") + + def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse: + return LLMResponse(content="", model="mock", stop_reason="stop") + + +# --------------------------------------------------------------------------- +# Helper: build a simple text-only scenario +# --------------------------------------------------------------------------- + + +def text_scenario(text: str, input_tokens: int = 10, output_tokens: int = 5) -> list: + """Build a stream scenario that produces text and finishes.""" + return [ + TextDeltaEvent(content=text, snapshot=text), + FinishEvent( + stop_reason="stop", input_tokens=input_tokens, output_tokens=output_tokens, model="mock" + ), + ] + + +def tool_call_scenario( + tool_name: str, + tool_input: dict, + tool_use_id: str = "call_1", + text: str = "", +) -> list: + """Build a stream scenario that produces a tool call.""" + events = [] + if text: + events.append(TextDeltaEvent(content=text, snapshot=text)) + events.append( + ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input) + ) + events.append( + FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock") + ) + return events + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def runtime(): + rt = MagicMock(spec=Runtime) + rt.start_run = MagicMock(return_value="run_1") + rt.decide = MagicMock(return_value="dec_1") + rt.record_outcome = MagicMock() + rt.end_run = MagicMock() + rt.report_problem = MagicMock() + rt.set_node = MagicMock() + return rt + + +@pytest.fixture +def node_spec(): + return NodeSpec( + id="test_loop", + name="Test Loop", + description="A test event loop node", + node_type="event_loop", + output_keys=["result"], + system_prompt="You are a test assistant.", + ) + + +@pytest.fixture +def memory(): + return SharedMemory() + + +def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal_context=""): + """Build a NodeContext for testing.""" + return NodeContext( + runtime=runtime, + node_id=node_spec.id, + node_spec=node_spec, + memory=memory, + input_data=input_data or {}, + llm=llm, + available_tools=tools or [], + goal_context=goal_context, + ) + + +# =========================================================================== +# NodeProtocol conformance +# =========================================================================== + + +class TestNodeProtocolConformance: + def test_subclasses_node_protocol(self): + """EventLoopNode must be a subclass of NodeProtocol.""" + assert issubclass(EventLoopNode, NodeProtocol) + + def test_has_execute_method(self): + node = EventLoopNode() + assert hasattr(node, "execute") + assert asyncio.iscoroutinefunction(node.execute) + + def test_has_validate_input(self): + node = EventLoopNode() + assert hasattr(node, "validate_input") + + +# =========================================================================== +# Basic loop execution +# =========================================================================== + + +class TestBasicLoop: + @pytest.mark.asyncio + async def test_basic_text_only_implicit_accept(self, runtime, node_spec, memory): + """No tools, no judge. LLM produces text, implicit accept on stop.""" + # Override to no output_keys so implicit judge accepts immediately + node_spec.output_keys = [] + llm = MockStreamingLLM(scenarios=[text_scenario("Hello world")]) + ctx = build_ctx(runtime, node_spec, memory, llm) + + node = EventLoopNode(config=LoopConfig(max_iterations=5)) + result = await node.execute(ctx) + + assert result.success is True + assert result.tokens_used > 0 + + @pytest.mark.asyncio + async def test_no_llm_returns_failure(self, runtime, node_spec, memory): + """ctx.llm=None should return failure immediately.""" + ctx = build_ctx(runtime, node_spec, memory, llm=None) + + node = EventLoopNode() + result = await node.execute(ctx) + + assert result.success is False + assert "LLM" in result.error + + @pytest.mark.asyncio + async def test_max_iterations_failure(self, runtime, node_spec, memory): + """When max_iterations is reached without acceptance, should fail.""" + # LLM always produces text but never calls set_output, so implicit + # judge retries asking for missing keys + llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")]) + ctx = build_ctx(runtime, node_spec, memory, llm) + + node = EventLoopNode(config=LoopConfig(max_iterations=2)) + result = await node.execute(ctx) + + assert result.success is False + assert "Max iterations" in result.error + + +# =========================================================================== +# Judge integration +# =========================================================================== + + +class TestJudgeIntegration: + @pytest.mark.asyncio + async def test_judge_accept(self, runtime, node_spec, memory): + """Mock judge ACCEPT -> success.""" + node_spec.output_keys = [] + llm = MockStreamingLLM(scenarios=[text_scenario("Done!")]) + + judge = AsyncMock(spec=JudgeProtocol) + judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT")) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5)) + result = await node.execute(ctx) + + assert result.success is True + judge.evaluate.assert_called_once() + + @pytest.mark.asyncio + async def test_judge_escalate(self, runtime, node_spec, memory): + """Mock judge ESCALATE -> failure.""" + node_spec.output_keys = [] + llm = MockStreamingLLM(scenarios=[text_scenario("Attempt")]) + + judge = AsyncMock(spec=JudgeProtocol) + judge.evaluate = AsyncMock( + return_value=JudgeVerdict(action="ESCALATE", feedback="Tone violation") + ) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5)) + result = await node.execute(ctx) + + assert result.success is False + assert "escalated" in result.error.lower() + assert "Tone violation" in result.error + + @pytest.mark.asyncio + async def test_judge_retry_then_accept(self, runtime, node_spec, memory): + """RETRY twice, then ACCEPT. Should run 3 iterations.""" + node_spec.output_keys = [] + llm = MockStreamingLLM( + scenarios=[ + text_scenario("attempt 1"), + text_scenario("attempt 2"), + text_scenario("attempt 3"), + ] + ) + + call_count = 0 + + async def evaluate_fn(context): + nonlocal call_count + call_count += 1 + if call_count < 3: + return JudgeVerdict(action="RETRY", feedback="Try harder") + return JudgeVerdict(action="ACCEPT") + + judge = AsyncMock(spec=JudgeProtocol) + judge.evaluate = AsyncMock(side_effect=evaluate_fn) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=10)) + result = await node.execute(ctx) + + assert result.success is True + assert call_count == 3 + + +# =========================================================================== +# set_output tool +# =========================================================================== + + +class TestSetOutput: + @pytest.mark.asyncio + async def test_set_output_accumulates(self, runtime, node_spec, memory): + """LLM calls set_output -> values appear in NodeResult.output.""" + llm = MockStreamingLLM( + scenarios=[ + # Turn 1: call set_output + tool_call_scenario("set_output", {"key": "result", "value": "42"}), + # Turn 2: text response (triggers implicit judge) + text_scenario("Done, result is 42"), + ] + ) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode(config=LoopConfig(max_iterations=5)) + result = await node.execute(ctx) + + assert result.success is True + assert result.output["result"] == "42" + + @pytest.mark.asyncio + async def test_set_output_rejects_invalid_key(self, runtime, node_spec, memory): + """set_output with key not in output_keys -> is_error=True.""" + llm = MockStreamingLLM( + scenarios=[ + # Turn 1: call set_output with bad key + tool_call_scenario("set_output", {"key": "bad_key", "value": "x"}), + # Turn 2: call set_output with good key + tool_call_scenario("set_output", {"key": "result", "value": "ok"}), + # Turn 3: text done + text_scenario("Done"), + ] + ) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode(config=LoopConfig(max_iterations=5)) + result = await node.execute(ctx) + + assert result.success is True + assert result.output["result"] == "ok" + assert "bad_key" not in result.output + + @pytest.mark.asyncio + async def test_missing_keys_triggers_retry(self, runtime, node_spec, memory): + """Judge accepts but output keys are missing -> retry with hint.""" + judge = AsyncMock(spec=JudgeProtocol) + judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT")) + + llm = MockStreamingLLM( + scenarios=[ + # Turn 1: text without set_output -> judge accepts but keys missing -> retry + text_scenario("I'll get to it"), + # Turn 2: set_output + tool_call_scenario("set_output", {"key": "result", "value": "done"}), + # Turn 3: text -> judge accepts, keys present -> success + text_scenario("All done"), + ] + ) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5)) + result = await node.execute(ctx) + + assert result.success is True + assert result.output["result"] == "done" + + +# =========================================================================== +# Stall detection +# =========================================================================== + + +class TestStallDetection: + @pytest.mark.asyncio + async def test_stall_detection(self, runtime, node_spec, memory): + """3 identical responses should trigger stall detection.""" + node_spec.output_keys = [] # so implicit judge would accept + # But we need the judge to RETRY so we actually get 3 identical responses + judge = AsyncMock(spec=JudgeProtocol) + judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY")) + + llm = MockStreamingLLM(scenarios=[text_scenario("same answer")]) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode( + judge=judge, + config=LoopConfig(max_iterations=10, stall_detection_threshold=3), + ) + result = await node.execute(ctx) + + assert result.success is False + assert "stalled" in result.error.lower() + + +# =========================================================================== +# EventBus lifecycle events +# =========================================================================== + + +class TestEventBusLifecycle: + @pytest.mark.asyncio + async def test_lifecycle_events_published(self, runtime, node_spec, memory): + """NODE_LOOP_STARTED, NODE_LOOP_ITERATION, NODE_LOOP_COMPLETED should be published.""" + node_spec.output_keys = [] + llm = MockStreamingLLM(scenarios=[text_scenario("ok")]) + bus = EventBus() + + received_events = [] + bus.subscribe( + event_types=[ + EventType.NODE_LOOP_STARTED, + EventType.NODE_LOOP_ITERATION, + EventType.NODE_LOOP_COMPLETED, + ], + handler=lambda e: received_events.append(e.type), + ) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5)) + result = await node.execute(ctx) + + assert result.success is True + assert EventType.NODE_LOOP_STARTED in received_events + assert EventType.NODE_LOOP_ITERATION in received_events + assert EventType.NODE_LOOP_COMPLETED in received_events + + @pytest.mark.asyncio + async def test_client_facing_uses_client_output_delta(self, runtime, memory): + """client_facing=True should emit CLIENT_OUTPUT_DELTA instead of LLM_TEXT_DELTA.""" + spec = NodeSpec( + id="ui_node", + name="UI Node", + description="Streams to user", + node_type="event_loop", + output_keys=[], + client_facing=True, + ) + llm = MockStreamingLLM(scenarios=[text_scenario("visible to user")]) + bus = EventBus() + + received_types = [] + bus.subscribe( + event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.LLM_TEXT_DELTA], + handler=lambda e: received_types.append(e.type), + ) + + ctx = build_ctx(runtime, spec, memory, llm) + node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5)) + await node.execute(ctx) + + assert EventType.CLIENT_OUTPUT_DELTA in received_types + assert EventType.LLM_TEXT_DELTA not in received_types + + +# =========================================================================== +# Tool execution +# =========================================================================== + + +class TestToolExecution: + @pytest.mark.asyncio + async def test_tool_execution_feedback(self, runtime, node_spec, memory): + """Tool call -> result fed back to conversation via stream loop.""" + node_spec.output_keys = [] + + def my_tool_executor(tool_use: ToolUse) -> ToolResult: + return ToolResult( + tool_use_id=tool_use.id, + content=f"Result for {tool_use.name}", + is_error=False, + ) + + llm = MockStreamingLLM( + scenarios=[ + # Turn 1: call a tool + tool_call_scenario("search", {"query": "test"}, tool_use_id="call_search"), + # Turn 2: text response after seeing tool result + text_scenario("Found the answer"), + ] + ) + + ctx = build_ctx( + runtime, + node_spec, + memory, + llm, + tools=[Tool(name="search", description="Search", parameters={})], + ) + node = EventLoopNode( + tool_executor=my_tool_executor, + config=LoopConfig(max_iterations=5), + ) + result = await node.execute(ctx) + + assert result.success is True + # stream() should have been called twice (tool call turn + final text turn) + assert llm._call_index >= 2 + + +# =========================================================================== +# Write-through persistence with real FileConversationStore +# =========================================================================== + + +class TestWriteThroughPersistence: + @pytest.mark.asyncio + async def test_messages_written_to_store(self, tmp_path, runtime, node_spec, memory): + """Messages should be persisted immediately via write-through.""" + store = FileConversationStore(tmp_path / "conv") + node_spec.output_keys = [] + llm = MockStreamingLLM(scenarios=[text_scenario("Hello")]) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode( + conversation_store=store, + config=LoopConfig(max_iterations=5), + ) + result = await node.execute(ctx) + + assert result.success is True + + # Verify parts were written to disk + parts = await store.read_parts() + assert len(parts) >= 2 # at least initial user msg + assistant msg + + @pytest.mark.asyncio + async def test_output_accumulator_write_through(self, tmp_path, runtime, node_spec, memory): + """set_output values should be persisted in cursor immediately.""" + store = FileConversationStore(tmp_path / "conv") + llm = MockStreamingLLM( + scenarios=[ + tool_call_scenario("set_output", {"key": "result", "value": "persisted_value"}), + text_scenario("Done"), + ] + ) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode( + conversation_store=store, + config=LoopConfig(max_iterations=5), + ) + result = await node.execute(ctx) + + assert result.success is True + assert result.output["result"] == "persisted_value" + + # Verify output was written to cursor on disk + cursor = await store.read_cursor() + assert cursor is not None + assert cursor["outputs"]["result"] == "persisted_value" + + +# =========================================================================== +# Crash recovery (restore from real FileConversationStore) +# =========================================================================== + + +class TestCrashRecovery: + @pytest.mark.asyncio + async def test_restore_from_checkpoint(self, tmp_path, runtime, node_spec, memory): + """Populate a store with state, then verify EventLoopNode restores from it.""" + store = FileConversationStore(tmp_path / "conv") + + # Simulate a previous run that wrote conversation + cursor + conv = NodeConversation( + system_prompt="You are a test assistant.", + output_keys=["result"], + store=store, + ) + await conv.add_user_message("Initial input") + await conv.add_assistant_message("Working on it...") + + # Write cursor with iteration and outputs + await store.write_cursor( + { + "iteration": 1, + "next_seq": conv.next_seq, + "outputs": {"result": "partial_value"}, + } + ) + + # Now create a new EventLoopNode and execute -- it should restore + node_spec.output_keys = [] # no required keys so implicit accept works + llm = MockStreamingLLM(scenarios=[text_scenario("Continuing...")]) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode( + conversation_store=store, + config=LoopConfig(max_iterations=5), + ) + result = await node.execute(ctx) + + assert result.success is True + # Should have the restored output + assert result.output.get("result") == "partial_value" + + +# =========================================================================== +# External event injection +# =========================================================================== + + +class TestEventInjection: + @pytest.mark.asyncio + async def test_inject_event(self, runtime, node_spec, memory): + """inject_event() content should appear as user message in next iteration.""" + node_spec.output_keys = [] + + judge_calls = [] + + async def evaluate_fn(context): + judge_calls.append(context) + if len(judge_calls) >= 2: + return JudgeVerdict(action="ACCEPT") + return JudgeVerdict(action="RETRY") + + judge = AsyncMock(spec=JudgeProtocol) + judge.evaluate = AsyncMock(side_effect=evaluate_fn) + + llm = MockStreamingLLM( + scenarios=[ + text_scenario("iteration 1"), + text_scenario("iteration 2"), + ] + ) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode( + judge=judge, + config=LoopConfig(max_iterations=5), + ) + + # Pre-inject an event before execute runs + await node.inject_event("Priority: CEO wants meeting rescheduled") + + result = await node.execute(ctx) + assert result.success is True + + # Verify the injected content made it into the LLM messages + all_messages = [] + for call in llm.stream_calls: + all_messages.extend(call["messages"]) + injected_found = any("[External event]" in str(m.get("content", "")) for m in all_messages) + assert injected_found + + +# =========================================================================== +# Pause/resume +# =========================================================================== + + +class TestPauseResume: + @pytest.mark.asyncio + async def test_pause_returns_early(self, runtime, node_spec, memory): + """pause_requested in input_data should trigger early return.""" + node_spec.output_keys = [] + llm = MockStreamingLLM(scenarios=[text_scenario("should not run")]) + + ctx = build_ctx( + runtime, + node_spec, + memory, + llm, + input_data={"pause_requested": True}, + ) + node = EventLoopNode(config=LoopConfig(max_iterations=10)) + result = await node.execute(ctx) + + # Should return success (paused, not failed) + assert result.success is True + # LLM should not have been called (paused before first turn) + assert llm._call_index == 0 + + +# =========================================================================== +# Stream errors +# =========================================================================== + + +class TestStreamErrors: + @pytest.mark.asyncio + async def test_non_recoverable_stream_error_raises(self, runtime, node_spec, memory): + """Non-recoverable StreamErrorEvent should raise RuntimeError.""" + node_spec.output_keys = [] + llm = MockStreamingLLM( + scenarios=[ + [StreamErrorEvent(error="Connection lost", recoverable=False)], + ] + ) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode(config=LoopConfig(max_iterations=5)) + + with pytest.raises(RuntimeError, match="Stream error"): + await node.execute(ctx) + + +# =========================================================================== +# OutputAccumulator unit tests +# =========================================================================== + + +class TestOutputAccumulator: + @pytest.mark.asyncio + async def test_set_and_get(self): + acc = OutputAccumulator() + await acc.set("key1", "value1") + assert acc.get("key1") == "value1" + assert acc.get("nonexistent") is None + + @pytest.mark.asyncio + async def test_to_dict(self): + acc = OutputAccumulator() + await acc.set("a", 1) + await acc.set("b", 2) + assert acc.to_dict() == {"a": 1, "b": 2} + + @pytest.mark.asyncio + async def test_has_all_keys(self): + acc = OutputAccumulator() + assert acc.has_all_keys([]) is True + assert acc.has_all_keys(["x"]) is False + await acc.set("x", "val") + assert acc.has_all_keys(["x"]) is True + + @pytest.mark.asyncio + async def test_write_through_to_real_store(self, tmp_path): + """OutputAccumulator should write through to FileConversationStore cursor.""" + store = FileConversationStore(tmp_path / "acc_test") + acc = OutputAccumulator(store=store) + + await acc.set("result", "hello") + + cursor = await store.read_cursor() + assert cursor["outputs"]["result"] == "hello" + + @pytest.mark.asyncio + async def test_restore_from_real_store(self, tmp_path): + """OutputAccumulator.restore() should rebuild from FileConversationStore.""" + store = FileConversationStore(tmp_path / "acc_restore") + await store.write_cursor({"outputs": {"key1": "val1", "key2": "val2"}}) + + acc = await OutputAccumulator.restore(store) + assert acc.get("key1") == "val1" + assert acc.get("key2") == "val2" + assert acc.has_all_keys(["key1", "key2"]) is True diff --git a/core/tests/test_event_type_extension.py b/core/tests/test_event_type_extension.py new file mode 100644 index 00000000..5ca5ce1b --- /dev/null +++ b/core/tests/test_event_type_extension.py @@ -0,0 +1,978 @@ +"""Tests for extending the stream event type system. + +Validates that the StreamEvent discriminated union pattern supports: +- Type-based dispatch (matching on event.type) +- Pattern matching / isinstance branching +- Custom event subclasses following the same frozen-dataclass convention +- Serialization of mixed event sequences + +WP-2 tests validate EventType enum extension and node-level event routing: +- All 12 new EventType enum members with correct string values +- node_id routing on AgentEvent +- filter_node on Subscription +- Backward compatibility with existing enum members +""" + +import asyncio +from dataclasses import FrozenInstanceError, asdict, dataclass, field +from typing import Any, Literal + +import pytest + +from framework.llm.stream_events import ( + FinishEvent, + ReasoningDeltaEvent, + ReasoningStartEvent, + StreamErrorEvent, + TextDeltaEvent, + TextEndEvent, + ToolCallEvent, + ToolResultEvent, +) +from framework.runtime.event_bus import AgentEvent, EventBus, EventType, Subscription + + +# --------------------------------------------------------------------------- +# Helpers: type-based dispatch +# --------------------------------------------------------------------------- +def dispatch_event(event) -> str: + """Dispatch an event by its type field, returning a label.""" + handlers = { + "text_delta": lambda e: f"text:{e.content}", + "text_end": lambda e: f"end:{len(e.full_text)}chars", + "tool_call": lambda e: f"call:{e.tool_name}", + "tool_result": lambda e: f"result:{e.tool_use_id}", + "reasoning_start": lambda _: "reasoning:start", + "reasoning_delta": lambda e: f"reasoning:{e.content[:20]}", + "finish": lambda e: f"finish:{e.stop_reason}", + "error": lambda e: f"error:{e.error}", + } + handler = handlers.get(event.type) + if handler is None: + return f"unknown:{event.type}" + return handler(event) + + +def collect_text(events: list) -> str: + """Accumulate full text from a stream of events.""" + for event in reversed(events): + if isinstance(event, TextEndEvent): + return event.full_text + if isinstance(event, TextDeltaEvent): + return event.snapshot + return "" + + +def extract_tool_calls(events: list) -> list[dict[str, Any]]: + """Extract tool call info from a stream of events.""" + return [ + {"id": e.tool_use_id, "name": e.tool_name, "input": e.tool_input} + for e in events + if isinstance(e, ToolCallEvent) + ] + + +# --------------------------------------------------------------------------- +# Type-based dispatch tests +# --------------------------------------------------------------------------- +class TestTypeDispatch: + """Dispatch on event.type string for handler routing.""" + + def test_dispatch_text_delta(self): + e = TextDeltaEvent(content="hello") + assert dispatch_event(e) == "text:hello" + + def test_dispatch_text_end(self): + e = TextEndEvent(full_text="hello world") + assert dispatch_event(e) == "end:11chars" + + def test_dispatch_tool_call(self): + e = ToolCallEvent(tool_name="web_search") + assert dispatch_event(e) == "call:web_search" + + def test_dispatch_tool_result(self): + e = ToolResultEvent(tool_use_id="abc") + assert dispatch_event(e) == "result:abc" + + def test_dispatch_reasoning_start(self): + e = ReasoningStartEvent() + assert dispatch_event(e) == "reasoning:start" + + def test_dispatch_reasoning_delta(self): + e = ReasoningDeltaEvent(content="Let me think step by step") + assert dispatch_event(e) == "reasoning:Let me think step by" + + def test_dispatch_finish(self): + e = FinishEvent(stop_reason="end_turn") + assert dispatch_event(e) == "finish:end_turn" + + def test_dispatch_error(self): + e = StreamErrorEvent(error="timeout") + assert dispatch_event(e) == "error:timeout" + + +# --------------------------------------------------------------------------- +# isinstance-based filtering +# --------------------------------------------------------------------------- +class TestInstanceFiltering: + """Filter event streams using isinstance for each event type.""" + + @pytest.fixture + def text_stream(self) -> list: + """Simulate a text-only stream.""" + return [ + TextDeltaEvent(content="Hello", snapshot="Hello"), + TextDeltaEvent(content=" world", snapshot="Hello world"), + TextDeltaEvent(content="!", snapshot="Hello world!"), + TextEndEvent(full_text="Hello world!"), + FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=3, model="test"), + ] + + @pytest.fixture + def tool_stream(self) -> list: + """Simulate a tool call stream.""" + return [ + ToolCallEvent( + tool_use_id="call_1", + tool_name="get_weather", + tool_input={"city": "London"}, + ), + ToolCallEvent( + tool_use_id="call_2", + tool_name="calculator", + tool_input={"expression": "2+2"}, + ), + FinishEvent(stop_reason="tool_calls"), + ] + + @pytest.fixture + def reasoning_stream(self) -> list: + """Simulate a stream with reasoning blocks.""" + return [ + ReasoningStartEvent(), + ReasoningDeltaEvent(content="Let me analyze this..."), + ReasoningDeltaEvent(content="The answer is 42."), + TextDeltaEvent(content="The answer is 42.", snapshot="The answer is 42."), + TextEndEvent(full_text="The answer is 42."), + FinishEvent(stop_reason="end_turn"), + ] + + def test_collect_text(self, text_stream): + assert collect_text(text_stream) == "Hello world!" + + def test_collect_text_from_tool_stream(self, tool_stream): + assert collect_text(tool_stream) == "" + + def test_extract_tool_calls(self, tool_stream): + calls = extract_tool_calls(tool_stream) + assert len(calls) == 2 + assert calls[0]["name"] == "get_weather" + assert calls[1]["name"] == "calculator" + + def test_extract_tool_calls_from_text_stream(self, text_stream): + assert extract_tool_calls(text_stream) == [] + + def test_filter_text_deltas(self, text_stream): + deltas = [e for e in text_stream if isinstance(e, TextDeltaEvent)] + assert len(deltas) == 3 + + def test_filter_finish(self, text_stream): + finishes = [e for e in text_stream if isinstance(e, FinishEvent)] + assert len(finishes) == 1 + assert finishes[0].stop_reason == "stop" + + def test_reasoning_then_text(self, reasoning_stream): + reasoning = [e for e in reasoning_stream if isinstance(e, ReasoningDeltaEvent)] + text = collect_text(reasoning_stream) + assert len(reasoning) == 2 + assert text == "The answer is 42." + + def test_mixed_stream_type_counts(self, reasoning_stream): + type_counts = {} + for e in reasoning_stream: + type_counts[e.type] = type_counts.get(e.type, 0) + 1 + assert type_counts == { + "reasoning_start": 1, + "reasoning_delta": 2, + "text_delta": 1, + "text_end": 1, + "finish": 1, + } + + +# --------------------------------------------------------------------------- +# Custom event extension pattern +# --------------------------------------------------------------------------- +@dataclass(frozen=True) +class CustomMetricsEvent: + """Example custom event following the same pattern.""" + + type: Literal["custom_metrics"] = "custom_metrics" + latency_ms: float = 0.0 + tokens_per_second: float = 0.0 + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class CustomCitationEvent: + """Example citation event extending the pattern.""" + + type: Literal["citation"] = "citation" + source_url: str = "" + quote: str = "" + confidence: float = 0.0 + + +class TestCustomEventExtension: + """Custom events should follow the same frozen-dataclass convention.""" + + def test_custom_event_construction(self): + e = CustomMetricsEvent(latency_ms=150.5, tokens_per_second=42.3) + assert e.type == "custom_metrics" + assert e.latency_ms == 150.5 + + def test_custom_event_frozen(self): + e = CustomMetricsEvent() + with pytest.raises(FrozenInstanceError): + e.type = "modified" + + def test_custom_event_serialization(self): + e = CustomMetricsEvent( + latency_ms=100.0, + tokens_per_second=50.0, + metadata={"provider": "anthropic"}, + ) + d = asdict(e) + assert d["type"] == "custom_metrics" + assert d["metadata"] == {"provider": "anthropic"} + + def test_custom_event_dispatch(self): + """Custom events can extend the dispatch map.""" + e = CustomMetricsEvent(latency_ms=200.0) + # Falls through to "unknown" in our dispatch_event + assert dispatch_event(e) == "unknown:custom_metrics" + + def test_custom_event_in_mixed_stream(self): + """Custom events can coexist with standard events in a list.""" + stream = [ + TextDeltaEvent(content="hi", snapshot="hi"), + CustomMetricsEvent(latency_ms=50.0), + TextEndEvent(full_text="hi"), + CustomCitationEvent(source_url="https://example.com", quote="hi"), + FinishEvent(stop_reason="stop"), + ] + standard = [ + e + for e in stream + if hasattr(e, "type") + and e.type + in { + "text_delta", + "text_end", + "tool_call", + "tool_result", + "reasoning_start", + "reasoning_delta", + "finish", + "error", + } + ] + custom = [ + e + for e in stream + if e.type + not in { + "text_delta", + "text_end", + "tool_call", + "tool_result", + "reasoning_start", + "reasoning_delta", + "finish", + "error", + } + ] + assert len(standard) == 3 + assert len(custom) == 2 + + +# --------------------------------------------------------------------------- +# Serialization of full event sequences +# --------------------------------------------------------------------------- +class TestSequenceSerialization: + """Serialize entire event sequences, as done by the dump tests.""" + + def test_serialize_text_sequence(self): + events = [ + TextDeltaEvent(content="Hello", snapshot="Hello"), + TextDeltaEvent(content=" world", snapshot="Hello world"), + TextEndEvent(full_text="Hello world"), + FinishEvent(stop_reason="stop", model="test-model"), + ] + serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)] + assert len(serialized) == 4 + assert serialized[0]["index"] == 0 + assert serialized[0]["type"] == "text_delta" + assert serialized[-1]["type"] == "finish" + assert serialized[-1]["model"] == "test-model" + + def test_serialize_tool_sequence(self): + events = [ + ToolCallEvent( + tool_use_id="call_1", + tool_name="search", + tool_input={"query": "test"}, + ), + FinishEvent(stop_reason="tool_calls"), + ] + serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)] + assert serialized[0]["tool_input"] == {"query": "test"} + assert serialized[1]["stop_reason"] == "tool_calls" + + def test_serialize_error_sequence(self): + events = [ + TextDeltaEvent(content="partial"), + StreamErrorEvent(error="connection reset", recoverable=True), + FinishEvent(stop_reason="error"), + ] + serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)] + assert serialized[1]["type"] == "error" + assert serialized[1]["recoverable"] is True + + def test_roundtrip_snapshot_accumulation(self): + """Verify snapshot grows monotonically through serialization.""" + chunks = ["Hello", " beautiful", " world", "!"] + events = [] + snapshot = "" + for chunk in chunks: + snapshot += chunk + events.append(TextDeltaEvent(content=chunk, snapshot=snapshot)) + + serialized = [asdict(e) for e in events] + for i in range(1, len(serialized)): + assert len(serialized[i]["snapshot"]) > len(serialized[i - 1]["snapshot"]) + assert serialized[-1]["snapshot"] == "Hello beautiful world!" + + +# =========================================================================== +# WP-2: EventType Enum Extension + Node-Level Event Routing +# =========================================================================== + +# The 12 new EventType members added by WP-2 +WP2_EVENT_TYPES = { + # Node event-loop lifecycle + EventType.NODE_LOOP_STARTED: "node_loop_started", + EventType.NODE_LOOP_ITERATION: "node_loop_iteration", + EventType.NODE_LOOP_COMPLETED: "node_loop_completed", + # LLM streaming observability + EventType.LLM_TEXT_DELTA: "llm_text_delta", + EventType.LLM_REASONING_DELTA: "llm_reasoning_delta", + # Tool lifecycle + EventType.TOOL_CALL_STARTED: "tool_call_started", + EventType.TOOL_CALL_COMPLETED: "tool_call_completed", + # Client I/O + EventType.CLIENT_OUTPUT_DELTA: "client_output_delta", + EventType.CLIENT_INPUT_REQUESTED: "client_input_requested", + # Internal node observability + EventType.NODE_INTERNAL_OUTPUT: "node_internal_output", + EventType.NODE_INPUT_BLOCKED: "node_input_blocked", + EventType.NODE_STALLED: "node_stalled", +} + +# Pre-existing enum members that must remain unchanged +ORIGINAL_EVENT_TYPES = { + EventType.EXECUTION_STARTED: "execution_started", + EventType.EXECUTION_COMPLETED: "execution_completed", + EventType.EXECUTION_FAILED: "execution_failed", + EventType.EXECUTION_PAUSED: "execution_paused", + EventType.EXECUTION_RESUMED: "execution_resumed", + EventType.STATE_CHANGED: "state_changed", + EventType.STATE_CONFLICT: "state_conflict", + EventType.GOAL_PROGRESS: "goal_progress", + EventType.GOAL_ACHIEVED: "goal_achieved", + EventType.CONSTRAINT_VIOLATION: "constraint_violation", + EventType.STREAM_STARTED: "stream_started", + EventType.STREAM_STOPPED: "stream_stopped", + EventType.CUSTOM: "custom", +} + + +# --------------------------------------------------------------------------- +# WP-2 Part A: EventType enum members +# --------------------------------------------------------------------------- +class TestWP2EventTypeEnumMembers: + """All 12 new EventType members exist with correct string values.""" + + @pytest.mark.parametrize( + "member,expected_value", + WP2_EVENT_TYPES.items(), + ids=lambda x: x.name if isinstance(x, EventType) else x, + ) + def test_new_member_value(self, member, expected_value): + assert member.value == expected_value + + def test_all_12_new_members_exist(self): + assert len(WP2_EVENT_TYPES) == 12 + + def test_new_member_string_values_are_unique(self): + values = list(WP2_EVENT_TYPES.values()) + assert len(values) == len(set(values)) + + def test_no_collision_with_original_members(self): + new_values = set(WP2_EVENT_TYPES.values()) + old_values = set(ORIGINAL_EVENT_TYPES.values()) + overlap = new_values & old_values + assert overlap == set(), f"Colliding values: {overlap}" + + @pytest.mark.parametrize( + "member,expected_value", + ORIGINAL_EVENT_TYPES.items(), + ids=lambda x: x.name if isinstance(x, EventType) else x, + ) + def test_original_members_unchanged(self, member, expected_value): + assert member.value == expected_value + + def test_event_type_is_str_enum(self): + """EventType members compare equal to their string values.""" + assert EventType.NODE_LOOP_STARTED == "node_loop_started" + assert EventType.LLM_TEXT_DELTA == "llm_text_delta" + assert EventType.LLM_TEXT_DELTA.value == "llm_text_delta" + + def test_event_type_accessible_by_name(self): + assert EventType["NODE_LOOP_STARTED"] is EventType.NODE_LOOP_STARTED + assert EventType["TOOL_CALL_COMPLETED"] is EventType.TOOL_CALL_COMPLETED + + def test_event_type_accessible_by_value(self): + assert EventType("node_loop_started") is EventType.NODE_LOOP_STARTED + assert EventType("tool_call_completed") is EventType.TOOL_CALL_COMPLETED + + +# --------------------------------------------------------------------------- +# WP-2 Part B: AgentEvent.node_id and Subscription.filter_node +# --------------------------------------------------------------------------- +class TestWP2AgentEventNodeId: + """AgentEvent supports node_id as a first-class field.""" + + def test_node_id_defaults_to_none(self): + event = AgentEvent( + type=EventType.EXECUTION_STARTED, + stream_id="stream-1", + ) + assert event.node_id is None + + def test_node_id_can_be_set(self): + event = AgentEvent( + type=EventType.LLM_TEXT_DELTA, + stream_id="stream-1", + node_id="email_composer", + ) + assert event.node_id == "email_composer" + + def test_node_id_in_to_dict(self): + event = AgentEvent( + type=EventType.TOOL_CALL_STARTED, + stream_id="stream-1", + node_id="search_node", + ) + d = event.to_dict() + assert d["node_id"] == "search_node" + + def test_node_id_none_in_to_dict(self): + event = AgentEvent( + type=EventType.EXECUTION_STARTED, + stream_id="stream-1", + ) + d = event.to_dict() + assert "node_id" in d + assert d["node_id"] is None + + +class TestWP2SubscriptionFilterNode: + """Subscription supports filter_node for node-level routing.""" + + @staticmethod + async def _noop_handler(event: AgentEvent) -> None: + pass + + def test_filter_node_defaults_to_none(self): + sub = Subscription( + id="sub_1", + event_types={EventType.LLM_TEXT_DELTA}, + handler=self._noop_handler, + ) + assert sub.filter_node is None + + def test_filter_node_can_be_set(self): + sub = Subscription( + id="sub_1", + event_types={EventType.LLM_TEXT_DELTA}, + handler=self._noop_handler, + filter_node="email_composer", + ) + assert sub.filter_node == "email_composer" + + +# --------------------------------------------------------------------------- +# WP-2 Part B: Node-level event routing integration tests +# --------------------------------------------------------------------------- +class TestWP2NodeLevelRouting: + """EventBus routes events by node_id using filter_node.""" + + @pytest.fixture + def bus(self): + return EventBus() + + @pytest.mark.asyncio + async def test_filter_node_receives_matching_events(self, bus): + """Subscriber with filter_node='node-A' receives events from node-A.""" + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe( + event_types=[EventType.LLM_TEXT_DELTA], + handler=handler, + filter_node="node-A", + ) + + await bus.publish( + AgentEvent( + type=EventType.LLM_TEXT_DELTA, + stream_id="stream-1", + node_id="node-A", + data={"content": "hello"}, + ) + ) + + assert len(received) == 1 + assert received[0].node_id == "node-A" + assert received[0].data["content"] == "hello" + + @pytest.mark.asyncio + async def test_filter_node_rejects_non_matching_events(self, bus): + """Subscriber with filter_node='node-B' does NOT receive node-A events.""" + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe( + event_types=[EventType.LLM_TEXT_DELTA], + handler=handler, + filter_node="node-B", + ) + + await bus.publish( + AgentEvent( + type=EventType.LLM_TEXT_DELTA, + stream_id="stream-1", + node_id="node-A", + data={"content": "hello"}, + ) + ) + + assert len(received) == 0 + + @pytest.mark.asyncio + async def test_no_filter_node_receives_all_events(self, bus): + """Subscriber with no filter_node receives events from all nodes.""" + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe( + event_types=[EventType.LLM_TEXT_DELTA], + handler=handler, + ) + + await bus.publish( + AgentEvent( + type=EventType.LLM_TEXT_DELTA, + stream_id="stream-1", + node_id="node-A", + ) + ) + await bus.publish( + AgentEvent( + type=EventType.LLM_TEXT_DELTA, + stream_id="stream-1", + node_id="node-B", + ) + ) + await bus.publish( + AgentEvent( + type=EventType.LLM_TEXT_DELTA, + stream_id="stream-1", + node_id=None, + ) + ) + + assert len(received) == 3 + + @pytest.mark.asyncio + async def test_interleaved_nodes_separated_by_filter(self, bus): + """Two subscribers on different nodes get only their node's events.""" + node_a_events = [] + node_b_events = [] + + async def handler_a(event): + node_a_events.append(event) + + async def handler_b(event): + node_b_events.append(event) + + bus.subscribe( + event_types=[EventType.LLM_TEXT_DELTA], + handler=handler_a, + filter_node="email_sender", + ) + bus.subscribe( + event_types=[EventType.LLM_TEXT_DELTA], + handler=handler_b, + filter_node="inbox_scanner", + ) + + # Interleaved events from both nodes + await bus.publish( + AgentEvent( + type=EventType.LLM_TEXT_DELTA, + stream_id="webhook", + node_id="email_sender", + data={"content": "Dear Jo"}, + ) + ) + await bus.publish( + AgentEvent( + type=EventType.LLM_TEXT_DELTA, + stream_id="webhook", + node_id="inbox_scanner", + data={"content": "RE: Meeting conf"}, + ) + ) + await bus.publish( + AgentEvent( + type=EventType.LLM_TEXT_DELTA, + stream_id="webhook", + node_id="email_sender", + data={"content": "hn, Thank you for"}, + ) + ) + await bus.publish( + AgentEvent( + type=EventType.LLM_TEXT_DELTA, + stream_id="webhook", + node_id="inbox_scanner", + data={"content": "irmed for Thursday"}, + ) + ) + + assert len(node_a_events) == 2 + assert len(node_b_events) == 2 + assert node_a_events[0].data["content"] == "Dear Jo" + assert node_a_events[1].data["content"] == "hn, Thank you for" + assert node_b_events[0].data["content"] == "RE: Meeting conf" + assert node_b_events[1].data["content"] == "irmed for Thursday" + + @pytest.mark.asyncio + async def test_filter_node_combined_with_filter_stream(self, bus): + """filter_node and filter_stream work together.""" + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe( + event_types=[EventType.TOOL_CALL_STARTED], + handler=handler, + filter_stream="webhook", + filter_node="search_node", + ) + + # Matching both filters + await bus.publish( + AgentEvent( + type=EventType.TOOL_CALL_STARTED, + stream_id="webhook", + node_id="search_node", + ) + ) + # Wrong stream + await bus.publish( + AgentEvent( + type=EventType.TOOL_CALL_STARTED, + stream_id="api", + node_id="search_node", + ) + ) + # Wrong node + await bus.publish( + AgentEvent( + type=EventType.TOOL_CALL_STARTED, + stream_id="webhook", + node_id="other_node", + ) + ) + + assert len(received) == 1 + assert received[0].stream_id == "webhook" + assert received[0].node_id == "search_node" + + @pytest.mark.asyncio + async def test_wait_for_with_node_id(self, bus): + """wait_for() accepts node_id parameter for filtering.""" + + async def publish_later(): + await asyncio.sleep(0.01) + await bus.publish( + AgentEvent( + type=EventType.NODE_LOOP_COMPLETED, + stream_id="stream-1", + node_id="target_node", + data={"iterations": 3}, + ) + ) + + task = asyncio.create_task(publish_later()) + event = await bus.wait_for( + event_type=EventType.NODE_LOOP_COMPLETED, + node_id="target_node", + timeout=2.0, + ) + await task + + assert event is not None + assert event.node_id == "target_node" + assert event.data["iterations"] == 3 + + @pytest.mark.asyncio + async def test_wait_for_ignores_wrong_node(self, bus): + """wait_for() with node_id ignores events from other nodes.""" + + async def publish_wrong_then_right(): + await asyncio.sleep(0.01) + # Wrong node — should be ignored + await bus.publish( + AgentEvent( + type=EventType.NODE_LOOP_COMPLETED, + stream_id="stream-1", + node_id="wrong_node", + ) + ) + await asyncio.sleep(0.01) + # Right node + await bus.publish( + AgentEvent( + type=EventType.NODE_LOOP_COMPLETED, + stream_id="stream-1", + node_id="target_node", + data={"iterations": 5}, + ) + ) + + task = asyncio.create_task(publish_wrong_then_right()) + event = await bus.wait_for( + event_type=EventType.NODE_LOOP_COMPLETED, + node_id="target_node", + timeout=2.0, + ) + await task + + assert event is not None + assert event.node_id == "target_node" + assert event.data["iterations"] == 5 + + +# --------------------------------------------------------------------------- +# WP-2: Convenience publisher methods +# --------------------------------------------------------------------------- +class TestWP2ConveniencePublishers: + """EventBus convenience methods for new WP-2 event types.""" + + @pytest.fixture + def bus(self): + return EventBus() + + @pytest.mark.asyncio + async def test_emit_node_loop_started(self, bus): + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe(event_types=[EventType.NODE_LOOP_STARTED], handler=handler) + await bus.emit_node_loop_started( + stream_id="s1", + node_id="n1", + max_iterations=10, + ) + + assert len(received) == 1 + assert received[0].node_id == "n1" + assert received[0].data["max_iterations"] == 10 + + @pytest.mark.asyncio + async def test_emit_node_loop_iteration(self, bus): + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe(event_types=[EventType.NODE_LOOP_ITERATION], handler=handler) + await bus.emit_node_loop_iteration( + stream_id="s1", + node_id="n1", + iteration=3, + ) + + assert len(received) == 1 + assert received[0].data["iteration"] == 3 + + @pytest.mark.asyncio + async def test_emit_node_loop_completed(self, bus): + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe(event_types=[EventType.NODE_LOOP_COMPLETED], handler=handler) + await bus.emit_node_loop_completed( + stream_id="s1", + node_id="n1", + iterations=5, + ) + + assert len(received) == 1 + assert received[0].data["iterations"] == 5 + + @pytest.mark.asyncio + async def test_emit_llm_text_delta(self, bus): + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe(event_types=[EventType.LLM_TEXT_DELTA], handler=handler) + await bus.emit_llm_text_delta( + stream_id="s1", + node_id="n1", + content="hello", + snapshot="hello world", + ) + + assert len(received) == 1 + assert received[0].data["content"] == "hello" + assert received[0].data["snapshot"] == "hello world" + + @pytest.mark.asyncio + async def test_emit_tool_call_started(self, bus): + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe(event_types=[EventType.TOOL_CALL_STARTED], handler=handler) + await bus.emit_tool_call_started( + stream_id="s1", + node_id="n1", + tool_use_id="call_1", + tool_name="web_search", + tool_input={"query": "test"}, + ) + + assert len(received) == 1 + assert received[0].data["tool_name"] == "web_search" + assert received[0].data["tool_input"] == {"query": "test"} + + @pytest.mark.asyncio + async def test_emit_tool_call_completed(self, bus): + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe(event_types=[EventType.TOOL_CALL_COMPLETED], handler=handler) + await bus.emit_tool_call_completed( + stream_id="s1", + node_id="n1", + tool_use_id="call_1", + tool_name="web_search", + result="3 results found", + ) + + assert len(received) == 1 + assert received[0].data["result"] == "3 results found" + assert received[0].data["is_error"] is False + + @pytest.mark.asyncio + async def test_emit_client_output_delta(self, bus): + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe(event_types=[EventType.CLIENT_OUTPUT_DELTA], handler=handler) + await bus.emit_client_output_delta( + stream_id="s1", + node_id="n1", + content="chunk", + snapshot="full chunk", + ) + + assert len(received) == 1 + assert received[0].data["content"] == "chunk" + + @pytest.mark.asyncio + async def test_emit_node_stalled(self, bus): + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe(event_types=[EventType.NODE_STALLED], handler=handler) + await bus.emit_node_stalled( + stream_id="s1", + node_id="n1", + reason="no progress after 10 iterations", + ) + + assert len(received) == 1 + assert received[0].data["reason"] == "no progress after 10 iterations" + + @pytest.mark.asyncio + async def test_convenience_publishers_set_node_id(self, bus): + """All WP-2 convenience publishers set node_id on the emitted event.""" + received = [] + + async def handler(event): + received.append(event) + + bus.subscribe( + event_types=[EventType.LLM_TEXT_DELTA, EventType.TOOL_CALL_STARTED], + handler=handler, + filter_node="my_node", + ) + + await bus.emit_llm_text_delta( + stream_id="s1", + node_id="my_node", + content="hi", + snapshot="hi", + ) + await bus.emit_tool_call_started( + stream_id="s1", + node_id="my_node", + tool_use_id="c1", + tool_name="calc", + ) + # Wrong node — should not be received + await bus.emit_llm_text_delta( + stream_id="s1", + node_id="other_node", + content="bye", + snapshot="bye", + ) + + assert len(received) == 2 + assert all(e.node_id == "my_node" for e in received) diff --git a/core/tests/test_litellm_streaming.py b/core/tests/test_litellm_streaming.py new file mode 100644 index 00000000..d068f865 --- /dev/null +++ b/core/tests/test_litellm_streaming.py @@ -0,0 +1,389 @@ +"""Real-API streaming tests for LiteLLM provider. + +Calls live LLM APIs and dumps stream events to JSON files for review. +Results are saved to core/tests/stream_event_dumps/{provider}_{model}_{scenario}.json + +Run with: + cd core && python -m pytest tests/test_litellm_streaming.py -v -s -k "RealAPI" + +Requires API keys set in environment: + ANTHROPIC_API_KEY, OPENAI_API_KEY, GEMINI_API_KEY (or via credential store) +""" + +import asyncio +import json +import logging +import os +from dataclasses import asdict +from pathlib import Path + +import pytest + +from framework.llm.litellm import LiteLLMProvider +from framework.llm.provider import Tool +from framework.llm.stream_events import ( + FinishEvent, + StreamEvent, + TextDeltaEvent, + TextEndEvent, + ToolCallEvent, +) + +logger = logging.getLogger(__name__) + +DUMP_DIR = Path(__file__).parent / "stream_event_dumps" + + +def _serialize_event(index: int, event: StreamEvent) -> dict: + """Serialize a StreamEvent to a JSON-safe dict.""" + d = asdict(event) # type: ignore[arg-type] + d["index"] = index + # Move index to front for readability + return {"index": index, **{k: v for k, v in d.items() if k != "index"}} + + +def _dump_events(events: list[StreamEvent], filename: str) -> Path: + """Write stream events to a JSON file in the dump directory.""" + DUMP_DIR.mkdir(parents=True, exist_ok=True) + filepath = DUMP_DIR / filename + serialized = [_serialize_event(i, e) for i, e in enumerate(events)] + filepath.write_text(json.dumps(serialized, indent=2) + "\n") + logger.info(f"Dumped {len(events)} events to {filepath}") + return filepath + + +async def _collect_stream(provider: LiteLLMProvider, **kwargs) -> list[StreamEvent]: + """Collect all stream events from a provider.stream() call.""" + events: list[StreamEvent] = [] + async for event in provider.stream(**kwargs): + events.append(event) + # Log each event type as it arrives + logger.debug(f" [{len(events) - 1}] {event.type}: {event}") + return events + + +# --------------------------------------------------------------------------- +# Test matrix: (model_id, dump_prefix, env_var_for_skip) +# --------------------------------------------------------------------------- +MODELS = [ + ( + "anthropic/claude-haiku-4-5-20251001", + "anthropic_claude-haiku-4-5-20251001", + "ANTHROPIC_API_KEY", + ), + ("gpt-4.1-nano", "gpt-4.1-nano", "OPENAI_API_KEY"), + ("gemini/gemini-2.0-flash", "gemini_gemini-2.0-flash", "GEMINI_API_KEY"), +] + +WEATHER_TOOL = Tool( + name="get_weather", + description="Get the current weather for a city.", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name, e.g. 'Tokyo'", + } + }, + "required": ["city"], + }, +) + +SEARCH_TOOL = Tool( + name="web_search", + description="Search the web for information.", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query", + }, + "num_results": { + "type": "integer", + "description": "Number of results to return (1-10)", + }, + }, + "required": ["query"], + }, +) + +CALCULATOR_TOOL = Tool( + name="calculator", + description="Perform arithmetic calculations.", + parameters={ + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Math expression to evaluate, e.g. '2 + 2'", + } + }, + "required": ["expression"], + }, +) + + +def _has_api_key(env_var: str) -> bool: + """Check if an API key is available (env var or credential store).""" + if os.environ.get(env_var): + return True + # Try credential store + try: + from aden_tools.credentials import CredentialStoreAdapter + + creds = CredentialStoreAdapter.with_env_storage() + provider_name = env_var.replace("_API_KEY", "").lower() + return creds.is_available(provider_name) + except (ImportError, Exception): + return False + + +# --------------------------------------------------------------------------- +# Real API tests — text streaming +# --------------------------------------------------------------------------- +class TestRealAPITextStreaming: + """Stream a simple text response from each provider and dump events.""" + + @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS]) + @pytest.mark.asyncio + async def test_text_stream(self, model: str, prefix: str, env_var: str): + """Stream a multi-paragraph response to exercise chunked delivery.""" + if not _has_api_key(env_var): + pytest.skip(f"{env_var} not set") + + provider = LiteLLMProvider(model=model) + events = await _collect_stream( + provider, + messages=[ + { + "role": "user", + "content": ( + "Explain in 3 numbered paragraphs how a CPU executes an instruction. " + "Cover fetch, decode, and execute stages. Be concise but thorough." + ), + } + ], + system="You are a computer science teacher. Give clear, structured explanations.", + max_tokens=512, + ) + + # Dump to file + _dump_events(events, f"{prefix}_text.json") + + # Basic structural assertions + assert len(events) >= 4, f"Expected at least 4 events, got {len(events)}" + + # Must have multiple text deltas for a longer response + text_deltas = [e for e in events if isinstance(e, TextDeltaEvent)] + assert len(text_deltas) >= 3, f"Expected 3+ TextDeltaEvents, got {len(text_deltas)}" + + # Snapshot must accumulate monotonically + for i in range(1, len(text_deltas)): + assert len(text_deltas[i].snapshot) > len( + text_deltas[i - 1].snapshot + ), f"Snapshot did not grow at index {i}" + + # Must end with TextEndEvent then FinishEvent + text_ends = [e for e in events if isinstance(e, TextEndEvent)] + assert len(text_ends) == 1, f"Expected 1 TextEndEvent, got {len(text_ends)}" + + finish_events = [e for e in events if isinstance(e, FinishEvent)] + assert len(finish_events) == 1, f"Expected 1 FinishEvent, got {len(finish_events)}" + assert finish_events[0].stop_reason in ("stop", "end_turn") + + # TextEndEvent.full_text should match last snapshot + assert text_ends[0].full_text == text_deltas[-1].snapshot + + # Response should actually contain multi-paragraph content + full_text = text_ends[0].full_text + assert len(full_text) > 200, f"Response too short ({len(full_text)} chars)" + + +# --------------------------------------------------------------------------- +# Real API tests — tool call streaming +# --------------------------------------------------------------------------- +class TestRealAPIToolCallStreaming: + """Stream a tool call response from each provider and dump events.""" + + @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS]) + @pytest.mark.asyncio + async def test_tool_call_stream(self, model: str, prefix: str, env_var: str): + """Stream a single tool call with complex arguments.""" + if not _has_api_key(env_var): + pytest.skip(f"{env_var} not set") + + provider = LiteLLMProvider(model=model) + events = await _collect_stream( + provider, + messages=[ + { + "role": "user", + "content": "Search the web for 'Python 3.13 release notes'.", + } + ], + system="You have access to tools. Use the appropriate tool.", + tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL], + max_tokens=512, + ) + + # Dump to file + _dump_events(events, f"{prefix}_tool_call.json") + + # Basic structural assertions + assert len(events) >= 2, f"Expected at least 2 events, got {len(events)}" + + # Must have a tool call event + tool_calls = [e for e in events if isinstance(e, ToolCallEvent)] + assert len(tool_calls) >= 1, "No ToolCallEvent received" + + tc = tool_calls[0] + assert tc.tool_name == "web_search" + assert "query" in tc.tool_input + assert tc.tool_use_id != "" + + # Must end with FinishEvent + finish_events = [e for e in events if isinstance(e, FinishEvent)] + assert len(finish_events) == 1 + assert finish_events[0].stop_reason in ("tool_calls", "tool_use", "stop") + + @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS]) + @pytest.mark.asyncio + async def test_multi_tool_call_stream(self, model: str, prefix: str, env_var: str): + """Stream a response that should invoke multiple tool calls.""" + if not _has_api_key(env_var): + pytest.skip(f"{env_var} not set") + + provider = LiteLLMProvider(model=model) + events = await _collect_stream( + provider, + messages=[ + { + "role": "user", + "content": ( + "I need three things done in parallel: " + "1) Get the weather in London, " + "2) Get the weather in New York, " + "3) Calculate 1337 * 42. " + "Use the tools for all three." + ), + } + ], + system=( + "You have access to tools. When the user asks for multiple things, " + "call all the needed tools. Always use tools, never guess results." + ), + tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL], + max_tokens=512, + ) + + # Dump to file + _dump_events(events, f"{prefix}_multi_tool.json") + + # Must have multiple tool call events + tool_calls = [e for e in events if isinstance(e, ToolCallEvent)] + assert ( + len(tool_calls) >= 2 + ), f"Expected 2+ ToolCallEvents for parallel requests, got {len(tool_calls)}" + + # Verify tool names used + tool_names = {tc.tool_name for tc in tool_calls} + assert "get_weather" in tool_names, "Expected get_weather tool call" + + # All tool calls should have non-empty IDs + for tc in tool_calls: + assert tc.tool_use_id != "", f"Empty tool_use_id on {tc.tool_name}" + assert tc.tool_input, f"Empty tool_input on {tc.tool_name}" + + # Must end with FinishEvent + finish_events = [e for e in events if isinstance(e, FinishEvent)] + assert len(finish_events) == 1 + + +# --------------------------------------------------------------------------- +# Convenience runner for manual invocation +# --------------------------------------------------------------------------- +if __name__ == "__main__": + """Run all streaming tests and dump results. Usage: python tests/test_litellm_streaming.py""" + + ALL_TOOLS = [WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL] + + async def _run_all(): + for model, prefix, env_var in MODELS: + if not _has_api_key(env_var): + print(f"SKIP {prefix}: {env_var} not set") + continue + + provider = LiteLLMProvider(model=model) + + # Text streaming (multi-paragraph) + print(f"\n--- {prefix} text ---") + events = await _collect_stream( + provider, + messages=[ + { + "role": "user", + "content": ( + "Explain in 3 numbered paragraphs how a CPU executes an instruction. " + "Cover fetch, decode, and execute stages. Be concise but thorough." + ), + } + ], + system="You are a computer science teacher. Give clear, structured explanations.", + max_tokens=512, + ) + path = _dump_events(events, f"{prefix}_text.json") + print(f" {len(events)} events -> {path}") + for i, e in enumerate(events): + print(f" [{i}] {e.type}: {e}") + + # Tool call streaming + print(f"\n--- {prefix} tool_call ---") + events = await _collect_stream( + provider, + messages=[ + { + "role": "user", + "content": "Search the web for 'Python 3.13 release notes'.", + } + ], + system="You have access to tools. Use the appropriate tool.", + tools=ALL_TOOLS, + max_tokens=512, + ) + path = _dump_events(events, f"{prefix}_tool_call.json") + print(f" {len(events)} events -> {path}") + for i, e in enumerate(events): + print(f" [{i}] {e.type}: {e}") + + # Multi-tool call streaming + print(f"\n--- {prefix} multi_tool ---") + events = await _collect_stream( + provider, + messages=[ + { + "role": "user", + "content": ( + "I need three things done in parallel: " + "1) Get the weather in London, " + "2) Get the weather in New York, " + "3) Calculate 1337 * 42. " + "Use the tools for all three." + ), + } + ], + system=( + "You have access to tools. When the user asks for multiple things, " + "call all the needed tools. Always use tools, never guess results." + ), + tools=ALL_TOOLS, + max_tokens=512, + ) + path = _dump_events(events, f"{prefix}_multi_tool.json") + print(f" {len(events)} events -> {path}") + for i, e in enumerate(events): + print(f" [{i}] {e.type}: {e}") + + logging.basicConfig(level=logging.DEBUG) + asyncio.run(_run_all()) diff --git a/core/tests/test_stream_events.py b/core/tests/test_stream_events.py new file mode 100644 index 00000000..08a30a9a --- /dev/null +++ b/core/tests/test_stream_events.py @@ -0,0 +1,318 @@ +"""Tests for stream event dataclasses. + +Validates construction, defaults, immutability, serialization, and the +StreamEvent discriminated union type. +""" + +from dataclasses import FrozenInstanceError, asdict, fields + +import pytest + +from framework.llm.stream_events import ( + FinishEvent, + ReasoningDeltaEvent, + ReasoningStartEvent, + StreamErrorEvent, + StreamEvent, + TextDeltaEvent, + TextEndEvent, + ToolCallEvent, + ToolResultEvent, +) + +# All concrete event classes in the union +ALL_EVENT_CLASSES = [ + TextDeltaEvent, + TextEndEvent, + ToolCallEvent, + ToolResultEvent, + ReasoningStartEvent, + ReasoningDeltaEvent, + FinishEvent, + StreamErrorEvent, +] + + +# --------------------------------------------------------------------------- +# Construction & defaults +# --------------------------------------------------------------------------- +class TestEventDefaults: + """Each event class should be constructible with zero arguments.""" + + @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__) + def test_default_construction(self, cls): + event = cls() + assert event.type != "" + + def test_text_delta_defaults(self): + e = TextDeltaEvent() + assert e.type == "text_delta" + assert e.content == "" + assert e.snapshot == "" + + def test_text_end_defaults(self): + e = TextEndEvent() + assert e.type == "text_end" + assert e.full_text == "" + + def test_tool_call_defaults(self): + e = ToolCallEvent() + assert e.type == "tool_call" + assert e.tool_use_id == "" + assert e.tool_name == "" + assert e.tool_input == {} + + def test_tool_result_defaults(self): + e = ToolResultEvent() + assert e.type == "tool_result" + assert e.tool_use_id == "" + assert e.content == "" + assert e.is_error is False + + def test_reasoning_start_defaults(self): + e = ReasoningStartEvent() + assert e.type == "reasoning_start" + + def test_reasoning_delta_defaults(self): + e = ReasoningDeltaEvent() + assert e.type == "reasoning_delta" + assert e.content == "" + + def test_finish_defaults(self): + e = FinishEvent() + assert e.type == "finish" + assert e.stop_reason == "" + assert e.input_tokens == 0 + assert e.output_tokens == 0 + assert e.model == "" + + def test_stream_error_defaults(self): + e = StreamErrorEvent() + assert e.type == "error" + assert e.error == "" + assert e.recoverable is False + + +# --------------------------------------------------------------------------- +# Construction with values +# --------------------------------------------------------------------------- +class TestEventConstruction: + """Events should store provided field values correctly.""" + + def test_text_delta_with_values(self): + e = TextDeltaEvent(content="hello", snapshot="hello world") + assert e.content == "hello" + assert e.snapshot == "hello world" + + def test_text_end_with_values(self): + e = TextEndEvent(full_text="the complete response") + assert e.full_text == "the complete response" + + def test_tool_call_with_values(self): + e = ToolCallEvent( + tool_use_id="call_abc123", + tool_name="web_search", + tool_input={"query": "python", "num_results": 5}, + ) + assert e.tool_use_id == "call_abc123" + assert e.tool_name == "web_search" + assert e.tool_input == {"query": "python", "num_results": 5} + + def test_tool_result_with_values(self): + e = ToolResultEvent( + tool_use_id="call_abc123", + content="search results here", + is_error=False, + ) + assert e.tool_use_id == "call_abc123" + assert e.content == "search results here" + assert e.is_error is False + + def test_tool_result_error(self): + e = ToolResultEvent( + tool_use_id="call_fail", + content="timeout", + is_error=True, + ) + assert e.is_error is True + + def test_reasoning_delta_with_content(self): + e = ReasoningDeltaEvent(content="Let me think about this...") + assert e.content == "Let me think about this..." + + def test_finish_with_values(self): + e = FinishEvent( + stop_reason="end_turn", + input_tokens=150, + output_tokens=300, + model="claude-haiku-4-5", + ) + assert e.stop_reason == "end_turn" + assert e.input_tokens == 150 + assert e.output_tokens == 300 + assert e.model == "claude-haiku-4-5" + + def test_stream_error_with_values(self): + e = StreamErrorEvent(error="rate limit exceeded", recoverable=True) + assert e.error == "rate limit exceeded" + assert e.recoverable is True + + +# --------------------------------------------------------------------------- +# Frozen immutability +# --------------------------------------------------------------------------- +class TestEventImmutability: + """All events are frozen dataclasses — fields cannot be reassigned.""" + + @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__) + def test_frozen(self, cls): + event = cls() + with pytest.raises(FrozenInstanceError): + event.type = "modified" + + def test_text_delta_frozen_content(self): + e = TextDeltaEvent(content="hello") + with pytest.raises(FrozenInstanceError): + e.content = "modified" + + def test_tool_call_frozen_input(self): + e = ToolCallEvent(tool_input={"key": "value"}) + with pytest.raises(FrozenInstanceError): + e.tool_input = {} + + +# --------------------------------------------------------------------------- +# Type literal values +# --------------------------------------------------------------------------- +class TestTypeLiterals: + """Each event's `type` field should match its Literal annotation.""" + + EXPECTED_TYPES = { + TextDeltaEvent: "text_delta", + TextEndEvent: "text_end", + ToolCallEvent: "tool_call", + ToolResultEvent: "tool_result", + ReasoningStartEvent: "reasoning_start", + ReasoningDeltaEvent: "reasoning_delta", + FinishEvent: "finish", + StreamErrorEvent: "error", + } + + @pytest.mark.parametrize( + "cls,expected_type", + EXPECTED_TYPES.items(), + ids=lambda x: x.__name__ if isinstance(x, type) else x, + ) + def test_type_value(self, cls, expected_type): + assert cls().type == expected_type + + def test_all_types_unique(self): + types = [cls().type for cls in ALL_EVENT_CLASSES] + assert len(types) == len(set(types)), f"Duplicate type values: {types}" + + +# --------------------------------------------------------------------------- +# Serialization via dataclasses.asdict +# --------------------------------------------------------------------------- +class TestEventSerialization: + """Events should round-trip through asdict for JSON serialization.""" + + def test_text_delta_asdict(self): + e = TextDeltaEvent(content="chunk", snapshot="full chunk") + d = asdict(e) + assert d == {"type": "text_delta", "content": "chunk", "snapshot": "full chunk"} + + def test_tool_call_asdict(self): + e = ToolCallEvent( + tool_use_id="id_1", + tool_name="calc", + tool_input={"expression": "2+2"}, + ) + d = asdict(e) + assert d["tool_name"] == "calc" + assert d["tool_input"] == {"expression": "2+2"} + + def test_finish_asdict(self): + e = FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=20, model="gpt-4") + d = asdict(e) + assert d == { + "type": "finish", + "stop_reason": "stop", + "input_tokens": 10, + "output_tokens": 20, + "model": "gpt-4", + } + + @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__) + def test_asdict_contains_type(self, cls): + d = asdict(cls()) + assert "type" in d + + @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__) + def test_asdict_keys_match_fields(self, cls): + event = cls() + d = asdict(event) + field_names = {f.name for f in fields(cls)} + assert set(d.keys()) == field_names + + +# --------------------------------------------------------------------------- +# StreamEvent union type +# --------------------------------------------------------------------------- +class TestStreamEventUnion: + """The StreamEvent union should include all event classes.""" + + def test_union_contains_all_classes(self): + # StreamEvent is a UnionType (PEP 604 syntax: X | Y | Z) + union_args = StreamEvent.__args__ # type: ignore[attr-defined] + for cls in ALL_EVENT_CLASSES: + assert cls in union_args, f"{cls.__name__} not in StreamEvent union" + + def test_union_has_exactly_expected_members(self): + union_args = set(StreamEvent.__args__) # type: ignore[attr-defined] + expected = set(ALL_EVENT_CLASSES) + assert union_args == expected + + @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__) + def test_isinstance_check(self, cls): + """Each event instance should be an instance of its class (basic sanity).""" + event = cls() + assert isinstance(event, cls) + + +# --------------------------------------------------------------------------- +# Equality & hashing (frozen dataclasses support both) +# --------------------------------------------------------------------------- +class TestEventEquality: + """Frozen dataclasses support equality and hashing.""" + + def test_equal_events(self): + a = TextDeltaEvent(content="hi", snapshot="hi") + b = TextDeltaEvent(content="hi", snapshot="hi") + assert a == b + + def test_unequal_events(self): + a = TextDeltaEvent(content="hi") + b = TextDeltaEvent(content="bye") + assert a != b + + def test_different_types_not_equal(self): + a = TextDeltaEvent(content="hi") + b = ReasoningDeltaEvent(content="hi") + assert a != b + + def test_hashable(self): + e = FinishEvent(stop_reason="stop", model="gpt-4") + s = {e} # should be hashable since frozen + assert e in s + + def test_equal_events_same_hash(self): + a = FinishEvent(stop_reason="stop", model="gpt-4") + b = FinishEvent(stop_reason="stop", model="gpt-4") + assert hash(a) == hash(b) + + def test_events_with_dict_not_hashable(self): + """Events containing dict fields (e.g. tool_input) are not hashable.""" + e = ToolCallEvent(tool_use_id="x", tool_name="y", tool_input={"key": "val"}) + with pytest.raises(TypeError, match="unhashable type"): + hash(e)