Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6da5f4d0d4 |
@@ -42,8 +42,13 @@ except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from framework.llm.mock import MockLLMProvider # noqa: F401
|
||||
from framework.llm.mock import ( # noqa: F401
|
||||
MockLLMProvider,
|
||||
error_scenario,
|
||||
text_scenario,
|
||||
tool_call_scenario,
|
||||
)
|
||||
|
||||
__all__.append("MockLLMProvider")
|
||||
__all__.extend(["MockLLMProvider", "text_scenario", "tool_call_scenario", "error_scenario"])
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
+80
-14
@@ -8,9 +8,11 @@ from typing import Any
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
|
||||
@@ -32,14 +34,25 @@ class MockLLMProvider(LLMProvider):
|
||||
# Returns: {"name": "mock_value", "age": "mock_value"}
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "mock-model"):
|
||||
def __init__(
    self,
    model: str = "mock-model",
    scenarios: list[list[StreamEvent]] | None = None,
):
    """
    Create a mock provider, optionally scripted with stream scenarios.

    Args:
        model: Name stored on the provider as ``self.model``
            (default: "mock-model").
        scenarios: Optional list of StreamEvent sequences, one per
            stream() call. A falsy value (``None`` or an empty list)
            is stored as a fresh empty list.
    """
    # Bookkeeping starts empty: no recorded stream() calls, index at zero.
    self.stream_calls: list[dict] = []
    self._call_index = 0
    self.scenarios = scenarios if scenarios else []
    self.model = model
|
||||
|
||||
def _extract_output_keys(self, system: str) -> list[str]:
|
||||
"""
|
||||
@@ -189,20 +202,73 @@ class MockLLMProvider(LLMProvider):
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream a mock completion as word-level TextDeltaEvents.
|
||||
"""Stream a mock completion.
|
||||
|
||||
Splits the mock response into words and yields each as a separate
|
||||
TextDeltaEvent with an accumulating snapshot, exercising the full
|
||||
streaming pipeline without any API calls.
|
||||
With scenarios: yield events from next scenario, cycling via
|
||||
``_call_index % len(scenarios)``.
|
||||
|
||||
Without scenarios: fall back to word-splitting behavior (backward
|
||||
compatible).
|
||||
|
||||
Every call is recorded in ``self.stream_calls`` for test assertions.
|
||||
"""
|
||||
content = self._generate_mock_response(system=system, json_mode=False)
|
||||
words = content.split(" ")
|
||||
accumulated = ""
|
||||
self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
|
||||
|
||||
for i, word in enumerate(words):
|
||||
chunk = word if i == 0 else " " + word
|
||||
accumulated += chunk
|
||||
yield TextDeltaEvent(content=chunk, snapshot=accumulated)
|
||||
if self.scenarios:
|
||||
events = self.scenarios[self._call_index % len(self.scenarios)]
|
||||
self._call_index += 1
|
||||
for event in events:
|
||||
yield event
|
||||
else:
|
||||
# Original default behavior preserved
|
||||
content = self._generate_mock_response(system=system, json_mode=False)
|
||||
words = content.split(" ")
|
||||
accumulated = ""
|
||||
for i, word in enumerate(words):
|
||||
chunk = word if i == 0 else " " + word
|
||||
accumulated += chunk
|
||||
yield TextDeltaEvent(content=chunk, snapshot=accumulated)
|
||||
yield TextEndEvent(full_text=accumulated)
|
||||
yield FinishEvent(stop_reason="mock_complete", model=self.model)
|
||||
|
||||
yield TextEndEvent(full_text=accumulated)
|
||||
yield FinishEvent(stop_reason="mock_complete", model=self.model)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario helpers — convenience builders for common stream event sequences
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def text_scenario(text: str, input_tokens: int = 10, output_tokens: int = 5) -> list[StreamEvent]:
    """Return a two-event scenario: one text delta, then a finish event.

    Args:
        text: Used as both the content and the snapshot of the single
            TextDeltaEvent.
        input_tokens: Input token count placed on the FinishEvent.
        output_tokens: Output token count placed on the FinishEvent.

    Returns:
        ``[TextDeltaEvent, FinishEvent]``; the finish event carries
        stop_reason "stop" and model "mock".
    """
    delta = TextDeltaEvent(content=text, snapshot=text)
    finish = FinishEvent(
        stop_reason="stop",
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        model="mock",
    )
    return [delta, finish]
|
||||
|
||||
|
||||
def tool_call_scenario(
    tool_name: str,
    tool_input: dict,
    tool_use_id: str = "call_1",
    text: str = "",
    input_tokens: int = 10,
    output_tokens: int = 5,
) -> list[StreamEvent]:
    """Build a stream scenario that produces a tool call (optionally preceded by text).

    Args:
        tool_name: Name reported on the ToolCallEvent.
        tool_input: Input payload reported on the ToolCallEvent.
        tool_use_id: Identifier reported on the ToolCallEvent.
        text: When non-empty, a single TextDeltaEvent carrying this text
            (as both content and snapshot) is emitted before the tool call.
        input_tokens: Input token count placed on the FinishEvent. Defaults
            to 10, the value that was previously hard-coded.
        output_tokens: Output token count placed on the FinishEvent. Defaults
            to 5, the value that was previously hard-coded.

    Returns:
        The event sequence: optional TextDeltaEvent, ToolCallEvent, then a
        FinishEvent with stop_reason "tool_calls" and model "mock".
    """
    events: list[StreamEvent] = []
    if text:
        events.append(TextDeltaEvent(content=text, snapshot=text))
    events.append(
        ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input)
    )
    # Token counts are parameters, matching text_scenario's signature.
    events.append(
        FinishEvent(
            stop_reason="tool_calls",
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            model="mock",
        )
    )
    return events
|
||||
|
||||
|
||||
def error_scenario(error: str = "Connection lost", recoverable: bool = False) -> list[StreamEvent]:
    """Return a one-event scenario consisting of a single StreamErrorEvent.

    Args:
        error: Error text placed on the event.
        recoverable: Recoverable flag placed on the event.
    """
    failure = StreamErrorEvent(error=error, recoverable=recoverable)
    return [failure]
|
||||
|
||||
@@ -181,9 +181,9 @@ class TestRealAPITextStreaming:
|
||||
|
||||
# Snapshot must accumulate monotonically
|
||||
for i in range(1, len(text_deltas)):
|
||||
assert len(text_deltas[i].snapshot) > len(
|
||||
text_deltas[i - 1].snapshot
|
||||
), f"Snapshot did not grow at index {i}"
|
||||
assert len(text_deltas[i].snapshot) > len(text_deltas[i - 1].snapshot), (
|
||||
f"Snapshot did not grow at index {i}"
|
||||
)
|
||||
|
||||
# Must end with TextEndEvent then FinishEvent
|
||||
text_ends = [e for e in events if isinstance(e, TextEndEvent)]
|
||||
@@ -283,9 +283,9 @@ class TestRealAPIToolCallStreaming:
|
||||
|
||||
# Must have multiple tool call events
|
||||
tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
|
||||
assert (
|
||||
len(tool_calls) >= 2
|
||||
), f"Expected 2+ ToolCallEvents for parallel requests, got {len(tool_calls)}"
|
||||
assert len(tool_calls) >= 2, (
|
||||
f"Expected 2+ ToolCallEvents for parallel requests, got {len(tool_calls)}"
|
||||
)
|
||||
|
||||
# Verify tool names used
|
||||
tool_names = {tc.tool_name for tc in tool_calls}
|
||||
|
||||
@@ -0,0 +1,304 @@
|
||||
"""Tests for MockLLMProvider streaming scenarios and helper functions.
|
||||
|
||||
Proves that the enhanced MockLLMProvider can deterministically simulate
|
||||
text-only, tool-call, error, and multi-turn streaming sequences for CI
|
||||
without any real API calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.mock import (
|
||||
MockLLMProvider,
|
||||
error_scenario,
|
||||
text_scenario,
|
||||
tool_call_scenario,
|
||||
)
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _collect(provider, **kwargs):
    """Drain one ``provider.stream()`` call and return its events as a list.

    The stream is always invoked with a single fixed user message
    (``{"role": "user", "content": "hi"}``); extra keyword arguments are
    forwarded to ``stream()`` unchanged.
    """
    prompt = [{"role": "user", "content": "hi"}]
    return [event async for event in provider.stream(messages=prompt, **kwargs)]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Default (no scenarios) — backward compatibility
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestDefaultNoScenarios:
    """Backward-compatibility checks for a provider built without scenarios."""

    @pytest.mark.asyncio
    async def test_default_no_scenarios(self):
        """Default stream: 1+ text deltas, exactly one TextEndEvent and one FinishEvent."""
        stream_events = await _collect(MockLLMProvider())

        deltas = [ev for ev in stream_events if isinstance(ev, TextDeltaEvent)]
        ends = [ev for ev in stream_events if isinstance(ev, TextEndEvent)]
        finishes = [ev for ev in stream_events if isinstance(ev, FinishEvent)]

        assert len(deltas) >= 1
        assert len(ends) == 1
        assert len(finishes) == 1
        assert finishes[0].stop_reason == "mock_complete"

        # The final delta's snapshot is the complete text reported at the end.
        assert deltas[-1].snapshot == ends[0].full_text

    @pytest.mark.asyncio
    async def test_event_type_sequence(self):
        """Stream shape is TextDeltaEvent*, then TextEndEvent, then FinishEvent."""
        stream_events = await _collect(MockLLMProvider(model="mock-test"))

        # Everything ahead of the two trailing events must be a text delta.
        for leading in stream_events[:-2]:
            assert isinstance(leading, TextDeltaEvent)
        assert isinstance(stream_events[-2], TextEndEvent)
        assert isinstance(stream_events[-1], FinishEvent)

    @pytest.mark.asyncio
    async def test_snapshot_monotonic_growth(self):
        """Every delta's snapshot starts with the snapshot of the delta before it."""
        stream_events = await _collect(MockLLMProvider(model="mock-test"))

        deltas = [ev for ev in stream_events if isinstance(ev, TextDeltaEvent)]
        for earlier, later in zip(deltas, deltas[1:]):
            assert later.snapshot.startswith(earlier.snapshot)

    @pytest.mark.asyncio
    async def test_full_text_matches_chunks(self):
        """Joining all delta contents reproduces TextEndEvent.full_text."""
        stream_events = await _collect(MockLLMProvider(model="mock-test"))

        pieces = [ev.content for ev in stream_events if isinstance(ev, TextDeltaEvent)]
        joined = "".join(pieces)
        end_event = next(ev for ev in stream_events if isinstance(ev, TextEndEvent))
        assert end_event.full_text == joined
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# text_scenario
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestTextScenario:
    """Checks for scenarios built by ``text_scenario``."""

    @pytest.mark.asyncio
    async def test_text_scenario(self):
        """A text scenario streams one TextDeltaEvent, then a 'stop' FinishEvent."""
        scripted = text_scenario("Hello world")
        stream_events = await _collect(MockLLMProvider(scenarios=[scripted]))

        assert len(stream_events) == 2
        delta, finish = stream_events

        assert isinstance(delta, TextDeltaEvent)
        assert delta.content == "Hello world"
        assert delta.snapshot == "Hello world"
        assert isinstance(finish, FinishEvent)
        assert finish.stop_reason == "stop"

    @pytest.mark.asyncio
    async def test_text_scenario_custom_tokens(self):
        """Token counts passed to text_scenario appear on the final event."""
        scripted = text_scenario("ok", input_tokens=100, output_tokens=50)
        stream_events = await _collect(MockLLMProvider(scenarios=[scripted]))

        last = stream_events[-1]
        assert last.input_tokens == 100
        assert last.output_tokens == 50
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# tool_call_scenario
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestToolCallScenario:
    """Checks for scenarios built by ``tool_call_scenario``."""

    @pytest.mark.asyncio
    async def test_tool_call_scenario(self):
        """Without text: a ToolCallEvent followed by a 'tool_calls' FinishEvent."""
        scripted = tool_call_scenario("search", {"query": "test"})
        stream_events = await _collect(MockLLMProvider(scenarios=[scripted]))

        assert len(stream_events) == 2
        call, finish = stream_events

        assert isinstance(call, ToolCallEvent)
        assert call.tool_name == "search"
        assert call.tool_input == {"query": "test"}
        assert call.tool_use_id == "call_1"
        assert isinstance(finish, FinishEvent)
        assert finish.stop_reason == "tool_calls"

    @pytest.mark.asyncio
    async def test_tool_call_with_text(self):
        """With text=: the TextDeltaEvent comes first, then the call, then finish."""
        scripted = tool_call_scenario("run", {"cmd": "ls"}, text="Let me check")
        stream_events = await _collect(MockLLMProvider(scenarios=[scripted]))

        assert len(stream_events) == 3
        preamble, call, finish = stream_events

        assert isinstance(preamble, TextDeltaEvent)
        assert preamble.content == "Let me check"
        assert isinstance(call, ToolCallEvent)
        assert isinstance(finish, FinishEvent)

    @pytest.mark.asyncio
    async def test_tool_call_custom_id(self):
        """A caller-supplied tool_use_id is carried on the first event."""
        scripted = tool_call_scenario("search", {}, tool_use_id="custom_123")
        stream_events = await _collect(MockLLMProvider(scenarios=[scripted]))

        assert stream_events[0].tool_use_id == "custom_123"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# error_scenario
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestErrorScenario:
    """Checks for scenarios built by ``error_scenario``."""

    @pytest.mark.asyncio
    async def test_error_scenario(self):
        """Default error scenario: one non-recoverable 'Connection lost' event."""
        stream_events = await _collect(MockLLMProvider(scenarios=[error_scenario()]))

        assert len(stream_events) == 1
        failure = stream_events[0]

        assert isinstance(failure, StreamErrorEvent)
        assert failure.error == "Connection lost"
        assert failure.recoverable is False

    @pytest.mark.asyncio
    async def test_error_scenario_custom(self):
        """Custom error text and recoverable flag are carried on the event."""
        scripted = error_scenario(error="Rate limit", recoverable=True)
        stream_events = await _collect(MockLLMProvider(scenarios=[scripted]))

        failure = stream_events[0]
        assert failure.error == "Rate limit"
        assert failure.recoverable is True
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Multi-turn cycling
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestMultiTurnCycling:
    """Checks that successive stream() calls step through the scenario list."""

    @pytest.mark.asyncio
    async def test_multi_turn_cycling(self):
        """With two scenarios, call one plays the first and call two the second."""
        scripted = [text_scenario("first"), text_scenario("second")]
        provider = MockLLMProvider(scenarios=scripted)

        first_turn = await _collect(provider)
        second_turn = await _collect(provider)

        assert first_turn[0].content == "first"
        assert second_turn[0].content == "second"

    @pytest.mark.asyncio
    async def test_scenario_wraps_around(self):
        """With one scenario, three calls all replay it and the index reaches 3."""
        provider = MockLLMProvider(scenarios=[text_scenario("only")])

        turns = [await _collect(provider) for _ in range(3)]

        for turn in turns:
            assert turn[0].content == "only"
        assert provider._call_index == 3
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Call recording
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestCallRecording:
    """Checks that each stream() call is logged in ``stream_calls``."""

    @pytest.mark.asyncio
    async def test_call_recording(self):
        """The recorded entry holds the messages, system and tools of the call."""
        provider = MockLLMProvider(scenarios=[text_scenario("ok")])
        tool_specs = [{"name": "search"}]

        await _collect(provider, system="Be helpful", tools=tool_specs)

        assert len(provider.stream_calls) == 1
        recorded = provider.stream_calls[0]
        # _collect always sends this single fixed user message.
        assert recorded["messages"] == [{"role": "user", "content": "hi"}]
        assert recorded["system"] == "Be helpful"
        assert recorded["tools"] == tool_specs

    @pytest.mark.asyncio
    async def test_call_recording_default_fallback(self):
        """Calls are recorded in scenario-less mode too."""
        provider = MockLLMProvider()

        await _collect(provider)

        assert len(provider.stream_calls) == 1
        recorded = provider.stream_calls[0]
        assert recorded["system"] == ""
        assert recorded["tools"] is None
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Model preservation
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestModelPreserved:
    """Checks that the configured model name reaches the default FinishEvent."""

    @pytest.mark.asyncio
    async def test_model_preserved(self):
        """A provider built with model='gpt-test' reports it on the FinishEvent."""
        stream_events = await _collect(MockLLMProvider(model="gpt-test"))

        finish_events = [ev for ev in stream_events if isinstance(ev, FinishEvent)]
        first_finish = finish_events[0]
        assert first_finish.model == "gpt-test"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# No-arg construction (SpyLLMProvider compat)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestNoArgConstruction:
    """Checks that MockLLMProvider can be built with no arguments."""

    @pytest.mark.asyncio
    async def test_no_arg_construction(self):
        """MockLLMProvider() has the default state and still streams (SpyLLMProvider compat)."""
        provider = MockLLMProvider()

        # Default state straight after construction.
        assert provider.model == "mock-model"
        assert provider.scenarios == []
        assert provider._call_index == 0
        assert provider.stream_calls == []

        # Streaming must work without further setup.
        stream_events = await _collect(provider)
        assert len(stream_events) >= 2  # at least one TextDelta + FinishEvent
|
||||
Reference in New Issue
Block a user