Compare commits

...

1 Commit

Author SHA1 Message Date
bryan 6da5f4d0d4 wp-5 implemented 2026-02-02 12:12:49 -08:00
4 changed files with 397 additions and 22 deletions
+7 -2
View File
@@ -42,8 +42,13 @@ except ImportError:
pass
try:
from framework.llm.mock import MockLLMProvider # noqa: F401
from framework.llm.mock import ( # noqa: F401
MockLLMProvider,
error_scenario,
text_scenario,
tool_call_scenario,
)
__all__.append("MockLLMProvider")
__all__.extend(["MockLLMProvider", "text_scenario", "tool_call_scenario", "error_scenario"])
except ImportError:
pass
+80 -14
View File
@@ -8,9 +8,11 @@ from typing import Any
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
StreamEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
)
@@ -32,14 +34,25 @@ class MockLLMProvider(LLMProvider):
# Returns: {"name": "mock_value", "age": "mock_value"}
"""
def __init__(self, model: str = "mock-model"):
def __init__(
self,
model: str = "mock-model",
scenarios: list[list[StreamEvent]] | None = None,
):
"""
Initialize the mock LLM provider.
Args:
model: Model name to report in responses (default: "mock-model")
scenarios: Optional list of pre-programmed StreamEvent sequences.
Each call to stream() consumes the next scenario, cycling
back when exhausted. If None, stream() falls back to the
default word-splitting behavior.
"""
self.model = model
self.scenarios = scenarios or []
self._call_index = 0
self.stream_calls: list[dict] = []
def _extract_output_keys(self, system: str) -> list[str]:
"""
@@ -189,20 +202,73 @@ class MockLLMProvider(LLMProvider):
tools: list[Tool] | None = None,
max_tokens: int = 4096,
) -> AsyncIterator[StreamEvent]:
"""Stream a mock completion as word-level TextDeltaEvents.
"""Stream a mock completion.
Splits the mock response into words and yields each as a separate
TextDeltaEvent with an accumulating snapshot, exercising the full
streaming pipeline without any API calls.
With scenarios: yield events from next scenario, cycling via
``_call_index % len(scenarios)``.
Without scenarios: fall back to word-splitting behavior (backward
compatible).
Every call is recorded in ``self.stream_calls`` for test assertions.
"""
content = self._generate_mock_response(system=system, json_mode=False)
words = content.split(" ")
accumulated = ""
self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
for i, word in enumerate(words):
chunk = word if i == 0 else " " + word
accumulated += chunk
yield TextDeltaEvent(content=chunk, snapshot=accumulated)
if self.scenarios:
events = self.scenarios[self._call_index % len(self.scenarios)]
self._call_index += 1
for event in events:
yield event
else:
# Original default behavior preserved
content = self._generate_mock_response(system=system, json_mode=False)
words = content.split(" ")
accumulated = ""
for i, word in enumerate(words):
chunk = word if i == 0 else " " + word
accumulated += chunk
yield TextDeltaEvent(content=chunk, snapshot=accumulated)
yield TextEndEvent(full_text=accumulated)
yield FinishEvent(stop_reason="mock_complete", model=self.model)
yield TextEndEvent(full_text=accumulated)
yield FinishEvent(stop_reason="mock_complete", model=self.model)
# ---------------------------------------------------------------------------
# Scenario helpers — convenience builders for common stream event sequences
# ---------------------------------------------------------------------------
def text_scenario(text: str, input_tokens: int = 10, output_tokens: int = 5) -> list[StreamEvent]:
    """Build a two-event stream scenario: one full-text delta, then a finish.

    Args:
        text: Complete response text, emitted as a single TextDeltaEvent
            whose content and snapshot are both the full string.
        input_tokens: Input token count to report on the FinishEvent.
        output_tokens: Output token count to report on the FinishEvent.

    Returns:
        An event list suitable for ``MockLLMProvider(scenarios=[...])``.
    """
    delta = TextDeltaEvent(content=text, snapshot=text)
    finish = FinishEvent(
        stop_reason="stop",
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        model="mock",
    )
    return [delta, finish]
def tool_call_scenario(
    tool_name: str,
    tool_input: dict,
    tool_use_id: str = "call_1",
    text: str = "",
) -> list[StreamEvent]:
    """Build a stream scenario ending in a tool call.

    Args:
        tool_name: Name reported on the ToolCallEvent.
        tool_input: Input payload reported on the ToolCallEvent.
        tool_use_id: Identifier for the tool invocation.
        text: Optional assistant text emitted as a TextDeltaEvent before
            the tool call; omitted when empty.

    Returns:
        An event list suitable for ``MockLLMProvider(scenarios=[...])``.
    """
    # Empty text means no leading delta — mirrors providers that call tools immediately.
    prefix: list[StreamEvent] = [TextDeltaEvent(content=text, snapshot=text)] if text else []
    return prefix + [
        ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input),
        FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
    ]
def error_scenario(error: str = "Connection lost", recoverable: bool = False) -> list[StreamEvent]:
    """Build a single-event stream scenario that surfaces a stream error.

    Args:
        error: Error message carried by the StreamErrorEvent.
        recoverable: Whether the simulated error is marked recoverable.

    Returns:
        A one-element event list suitable for ``MockLLMProvider(scenarios=[...])``.
    """
    failure = StreamErrorEvent(error=error, recoverable=recoverable)
    return [failure]
+6 -6
View File
@@ -181,9 +181,9 @@ class TestRealAPITextStreaming:
# Snapshot must accumulate monotonically
for i in range(1, len(text_deltas)):
assert len(text_deltas[i].snapshot) > len(
text_deltas[i - 1].snapshot
), f"Snapshot did not grow at index {i}"
assert len(text_deltas[i].snapshot) > len(text_deltas[i - 1].snapshot), (
f"Snapshot did not grow at index {i}"
)
# Must end with TextEndEvent then FinishEvent
text_ends = [e for e in events if isinstance(e, TextEndEvent)]
@@ -283,9 +283,9 @@ class TestRealAPIToolCallStreaming:
# Must have multiple tool call events
tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
assert (
len(tool_calls) >= 2
), f"Expected 2+ ToolCallEvents for parallel requests, got {len(tool_calls)}"
assert len(tool_calls) >= 2, (
f"Expected 2+ ToolCallEvents for parallel requests, got {len(tool_calls)}"
)
# Verify tool names used
tool_names = {tc.tool_name for tc in tool_calls}
+304
View File
@@ -0,0 +1,304 @@
"""Tests for MockLLMProvider streaming scenarios and helper functions.
Proves that the enhanced MockLLMProvider can deterministically simulate
text-only, tool-call, error, and multi-turn streaming sequences for CI
without any real API calls.
"""
from __future__ import annotations
import pytest
from framework.llm.mock import (
MockLLMProvider,
error_scenario,
text_scenario,
tool_call_scenario,
)
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _collect(provider, **kwargs):
    """Drain one provider.stream() call into a list, using a fixed one-message history."""
    return [
        event
        async for event in provider.stream(
            messages=[{"role": "user", "content": "hi"}], **kwargs
        )
    ]
# ===========================================================================
# Default (no scenarios) — backward compatibility
# ===========================================================================
class TestDefaultNoScenarios:
    """Backward compatibility: without scenarios, stream() word-splits its mock text."""

    @pytest.mark.asyncio
    async def test_default_no_scenarios(self):
        """No scenarios = word-split TextDeltaEvents + TextEndEvent + FinishEvent."""
        mock = MockLLMProvider()
        got = await _collect(mock)
        deltas = [ev for ev in got if isinstance(ev, TextDeltaEvent)]
        ends = [ev for ev in got if isinstance(ev, TextEndEvent)]
        finishes = [ev for ev in got if isinstance(ev, FinishEvent)]
        assert len(deltas) >= 1
        assert len(ends) == 1
        assert len(finishes) == 1
        assert finishes[0].stop_reason == "mock_complete"
        # The last delta's snapshot must agree with the announced full text.
        assert deltas[-1].snapshot == ends[0].full_text

    @pytest.mark.asyncio
    async def test_event_type_sequence(self):
        """Events follow TextDeltaEvent*, TextEndEvent, FinishEvent pattern."""
        mock = MockLLMProvider(model="mock-test")
        got = await _collect(mock)
        *leading, end, finish = got
        assert all(isinstance(ev, TextDeltaEvent) for ev in leading)
        assert isinstance(end, TextEndEvent)
        assert isinstance(finish, FinishEvent)

    @pytest.mark.asyncio
    async def test_snapshot_monotonic_growth(self):
        """Each TextDeltaEvent.snapshot is a prefix of the next."""
        mock = MockLLMProvider(model="mock-test")
        got = await _collect(mock)
        deltas = [ev for ev in got if isinstance(ev, TextDeltaEvent)]
        for earlier, later in zip(deltas, deltas[1:]):
            assert later.snapshot.startswith(earlier.snapshot)

    @pytest.mark.asyncio
    async def test_full_text_matches_chunks(self):
        """TextEndEvent.full_text == concatenation of all chunk contents."""
        mock = MockLLMProvider(model="mock-test")
        got = await _collect(mock)
        chunks = [ev.content for ev in got if isinstance(ev, TextDeltaEvent)]
        end = next(ev for ev in got if isinstance(ev, TextEndEvent))
        assert end.full_text == "".join(chunks)
# ===========================================================================
# text_scenario
# ===========================================================================
class TestTextScenario:
    """Behavior of scenarios built by text_scenario()."""

    @pytest.mark.asyncio
    async def test_text_scenario(self):
        """text_scenario yields TextDeltaEvent then FinishEvent(stop_reason='stop')."""
        mock = MockLLMProvider(scenarios=[text_scenario("Hello world")])
        got = await _collect(mock)
        assert len(got) == 2
        delta, finish = got
        assert isinstance(delta, TextDeltaEvent)
        assert delta.content == "Hello world"
        assert delta.snapshot == "Hello world"
        assert isinstance(finish, FinishEvent)
        assert finish.stop_reason == "stop"

    @pytest.mark.asyncio
    async def test_text_scenario_custom_tokens(self):
        """text_scenario respects custom input/output token counts."""
        mock = MockLLMProvider(
            scenarios=[text_scenario("ok", input_tokens=100, output_tokens=50)]
        )
        got = await _collect(mock)
        # The finish event is always last in a text scenario.
        assert got[-1].input_tokens == 100
        assert got[-1].output_tokens == 50
# ===========================================================================
# tool_call_scenario
# ===========================================================================
class TestToolCallScenario:
    """Behavior of scenarios built by tool_call_scenario()."""

    @pytest.mark.asyncio
    async def test_tool_call_scenario(self):
        """tool_call_scenario yields ToolCallEvent + FinishEvent(stop_reason='tool_calls')."""
        mock = MockLLMProvider(scenarios=[tool_call_scenario("search", {"query": "test"})])
        got = await _collect(mock)
        assert len(got) == 2
        call_event, finish_event = got
        assert isinstance(call_event, ToolCallEvent)
        assert call_event.tool_name == "search"
        assert call_event.tool_input == {"query": "test"}
        assert call_event.tool_use_id == "call_1"
        assert isinstance(finish_event, FinishEvent)
        assert finish_event.stop_reason == "tool_calls"

    @pytest.mark.asyncio
    async def test_tool_call_with_text(self):
        """tool_call_scenario with text= yields TextDeltaEvent before ToolCallEvent."""
        mock = MockLLMProvider(
            scenarios=[tool_call_scenario("run", {"cmd": "ls"}, text="Let me check")]
        )
        got = await _collect(mock)
        assert len(got) == 3
        assert isinstance(got[0], TextDeltaEvent)
        assert got[0].content == "Let me check"
        assert isinstance(got[1], ToolCallEvent)
        assert isinstance(got[2], FinishEvent)

    @pytest.mark.asyncio
    async def test_tool_call_custom_id(self):
        """tool_call_scenario respects custom tool_use_id."""
        mock = MockLLMProvider(
            scenarios=[tool_call_scenario("search", {}, tool_use_id="custom_123")]
        )
        got = await _collect(mock)
        assert got[0].tool_use_id == "custom_123"
# ===========================================================================
# error_scenario
# ===========================================================================
class TestErrorScenario:
    """Behavior of scenarios built by error_scenario()."""

    @pytest.mark.asyncio
    async def test_error_scenario(self):
        """error_scenario yields StreamErrorEvent with correct fields."""
        mock = MockLLMProvider(scenarios=[error_scenario()])
        got = await _collect(mock)
        assert len(got) == 1
        failure = got[0]
        assert isinstance(failure, StreamErrorEvent)
        assert failure.error == "Connection lost"
        assert failure.recoverable is False

    @pytest.mark.asyncio
    async def test_error_scenario_custom(self):
        """error_scenario respects custom error and recoverable."""
        mock = MockLLMProvider(
            scenarios=[error_scenario(error="Rate limit", recoverable=True)]
        )
        got = await _collect(mock)
        assert got[0].error == "Rate limit"
        assert got[0].recoverable is True
# ===========================================================================
# Multi-turn cycling
# ===========================================================================
class TestMultiTurnCycling:
    """Scenario lists are consumed in order and cycle when exhausted."""

    @pytest.mark.asyncio
    async def test_multi_turn_cycling(self):
        """Two scenarios, two stream() calls — first yields scenario 1, second scenario 2."""
        mock = MockLLMProvider(
            scenarios=[text_scenario("first"), text_scenario("second")]
        )
        first_run = await _collect(mock)
        second_run = await _collect(mock)
        assert first_run[0].content == "first"
        assert second_run[0].content == "second"

    @pytest.mark.asyncio
    async def test_scenario_wraps_around(self):
        """One scenario, three stream() calls — all yield the same events."""
        mock = MockLLMProvider(scenarios=[text_scenario("only")])
        runs = [await _collect(mock) for _ in range(3)]
        for run in runs:
            assert run[0].content == "only"
        # Index keeps counting even while cycling over a single scenario.
        assert mock._call_index == 3
# ===========================================================================
# Call recording
# ===========================================================================
class TestCallRecording:
    """Every stream() call is recorded on stream_calls for assertions."""

    @pytest.mark.asyncio
    async def test_call_recording(self):
        """stream_calls captures messages, system, tools from each call."""
        mock = MockLLMProvider(scenarios=[text_scenario("ok")])
        tool_defs = [{"name": "search"}]
        await _collect(mock, system="Be helpful", tools=tool_defs)
        assert len(mock.stream_calls) == 1
        recorded = mock.stream_calls[0]
        # _collect always sends this fixed single-message history.
        assert recorded["messages"] == [{"role": "user", "content": "hi"}]
        assert recorded["system"] == "Be helpful"
        assert recorded["tools"] == tool_defs

    @pytest.mark.asyncio
    async def test_call_recording_default_fallback(self):
        """stream_calls works in default (no scenario) mode too."""
        mock = MockLLMProvider()
        await _collect(mock)
        assert len(mock.stream_calls) == 1
        first = mock.stream_calls[0]
        assert first["system"] == ""
        assert first["tools"] is None
# ===========================================================================
# Model preservation
# ===========================================================================
class TestModelPreserved:
    """The configured model name flows through to the FinishEvent."""

    @pytest.mark.asyncio
    async def test_model_preserved(self):
        """MockLLMProvider(model='gpt-test') default stream uses that model in FinishEvent."""
        mock = MockLLMProvider(model="gpt-test")
        got = await _collect(mock)
        finish = next(ev for ev in got if isinstance(ev, FinishEvent))
        assert finish.model == "gpt-test"
# ===========================================================================
# No-arg construction (SpyLLMProvider compat)
# ===========================================================================
class TestNoArgConstruction:
    """MockLLMProvider must remain constructible with zero arguments."""

    @pytest.mark.asyncio
    async def test_no_arg_construction(self):
        """MockLLMProvider() works with no arguments (SpyLLMProvider compat)."""
        mock = MockLLMProvider()
        assert mock.model == "mock-model"
        assert mock.scenarios == []
        assert mock._call_index == 0
        assert mock.stream_calls == []
        # Streaming must still succeed: at least one TextDelta + FinishEvent.
        produced = await _collect(mock)
        assert len(produced) >= 2