"""Tests for LiteLLM provider.
Run with:
cd core
uv pip install litellm pytest
pytest tests/test_litellm_provider.py -v
For live tests (requires API keys):
OPENAI_API_KEY=sk-... pytest tests/test_litellm_provider.py -v -m live
"""
import asyncio
import os
import threading
import time
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from framework.config import get_llm_extra_kwargs
from framework.llm.anthropic import AnthropicProvider
from framework.llm.litellm import (
    OPENROUTER_TOOL_COMPAT_MODEL_CACHE,
    LiteLLMProvider,
    _build_system_message,
    _compute_retry_delay,
    _cost_from_tokens,
    _ensure_ollama_chat_prefix,
    _extract_cache_tokens,
    _extract_cost,
    _is_ollama_model,
    _model_supports_cache_control,
    _summarize_request_for_log,
)
from framework.llm.provider import LLMProvider, LLMResponse, Tool


class TestLiteLLMProviderInit:
    """Test LiteLLMProvider initialization."""

    def test_init_with_defaults(self):
        """Test initialization with default parameters."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            provider = LiteLLMProvider()
            assert provider.model == "gpt-4o-mini"
            assert provider.api_key is None
            assert provider.api_base is None

    def test_init_with_custom_model(self):
        """Test initialization with custom model."""
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            provider = LiteLLMProvider(model="claude-3-haiku-20240307")
            assert provider.model == "claude-3-haiku-20240307"

    def test_init_deepseek_model(self):
        """Test initialization with DeepSeek model."""
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
            provider = LiteLLMProvider(model="deepseek/deepseek-chat")
            assert provider.model == "deepseek/deepseek-chat"

    def test_init_with_api_key(self):
        """Test initialization with explicit API key."""
        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="my-api-key")
        assert provider.api_key == "my-api-key"

    def test_init_with_api_base(self):
        """Test initialization with custom API base."""
        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="my-key", api_base="https://my-proxy.com/v1")
        assert provider.api_base == "https://my-proxy.com/v1"

    def test_init_minimax_defaults_api_base(self):
        """MiniMax should default to the official OpenAI-compatible endpoint."""
        provider = LiteLLMProvider(model="minimax/MiniMax-M2.1", api_key="my-key")
        assert provider.api_base == "https://api.minimax.io/v1"

    def test_init_minimax_keeps_custom_api_base(self):
        """Explicit api_base should win over MiniMax defaults."""
        provider = LiteLLMProvider(
            model="minimax/MiniMax-M2.1",
            api_key="my-key",
            api_base="https://proxy.example/v1",
        )
        assert provider.api_base == "https://proxy.example/v1"

    def test_init_openrouter_defaults_api_base(self):
        """OpenRouter should default to the official OpenAI-compatible endpoint."""
        provider = LiteLLMProvider(model="openrouter/x-ai/grok-4.20-beta", api_key="my-key")
        assert provider.api_base == "https://openrouter.ai/api/v1"

    def test_init_openrouter_keeps_custom_api_base(self):
        """Explicit api_base should win over OpenRouter defaults."""
        provider = LiteLLMProvider(
            model="openrouter/x-ai/grok-4.20-beta",
            api_key="my-key",
            api_base="https://proxy.example/v1",
        )
        assert provider.api_base == "https://proxy.example/v1"

    def test_init_ollama_no_key_needed(self):
        """Test that Ollama models don't require API key."""
        with patch.dict(os.environ, {}, clear=True):
            # Should not raise; ollama/ is normalised to ollama_chat/ for tool-call support.
            provider = LiteLLMProvider(model="ollama/llama3")
            assert provider.model == "ollama_chat/llama3"

    def test_summarize_request_flags_system_only_payload(self):
        """Request summaries should make system-only payloads obvious in logs."""
        summary = _summarize_request_for_log(
            {
                "model": "openai/glm-5",
                "api_base": "https://api.z.ai/api/coding/paas/v4",
                "messages": [{"role": "system", "content": "You are helpful."}],
                "tools": [{"type": "function", "function": {"name": "read_file"}}],
                "stream": True,
                "max_tokens": 8192,
            }
        )
        assert summary["message_count"] == 1
        assert summary["non_system_message_count"] == 0
        assert summary["first_non_system_role"] is None
        assert summary["last_non_system_role"] is None
        assert summary["system_only"] is True


class TestLiteLLMProviderComplete:
    """Test LiteLLMProvider.complete() method."""

    @patch("litellm.completion")
    def test_complete_basic(self, mock_completion):
        """Test basic completion call."""
        # Mock response
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Hello! I'm an AI assistant."
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 20
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        result = provider.complete(messages=[{"role": "user", "content": "Hello"}])

        assert result.content == "Hello! I'm an AI assistant."
        assert result.model == "gpt-4o-mini"
        assert result.input_tokens == 10
        assert result.output_tokens == 20
        assert result.stop_reason == "stop"

        # Verify litellm.completion was called correctly
        mock_completion.assert_called_once()
        call_kwargs = mock_completion.call_args[1]
        assert call_kwargs["model"] == "gpt-4o-mini"
        assert call_kwargs["api_key"] == "test-key"

    @patch("litellm.completion")
    def test_complete_with_system_prompt(self, mock_completion):
        """Test completion with system prompt."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 15
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        provider.complete(messages=[{"role": "user", "content": "Hello"}], system="You are a helpful assistant.")

        call_kwargs = mock_completion.call_args[1]
        messages = call_kwargs["messages"]
        assert messages[0]["role"] == "system"
        assert messages[0]["content"] == "You are a helpful assistant."

    @patch("litellm.completion")
    def test_complete_with_tools(self, mock_completion):
        """Test completion with tools."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 20
        mock_response.usage.completion_tokens = 10
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        tools = [
            Tool(
                name="get_weather",
                description="Get the weather for a location",
                parameters={
                    "properties": {"location": {"type": "string", "description": "City name"}},
                    "required": ["location"],
                },
            )
        ]
        provider.complete(messages=[{"role": "user", "content": "What's the weather?"}], tools=tools)

        call_kwargs = mock_completion.call_args[1]
        assert "tools" in call_kwargs
        assert call_kwargs["tools"][0]["type"] == "function"
        assert call_kwargs["tools"][0]["function"]["name"] == "get_weather"


class TestToolConversion:
    """Test tool format conversion."""

    def test_tool_to_openai_format(self):
        """Test converting Tool to OpenAI format."""
        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        tool = Tool(
            name="search",
            description="Search the web",
            parameters={
                "properties": {"query": {"type": "string", "description": "Search query"}},
                "required": ["query"],
            },
        )

        result = provider._tool_to_openai_format(tool)

        assert result["type"] == "function"
        assert result["function"]["name"] == "search"
        assert result["function"]["description"] == "Search the web"
        assert result["function"]["parameters"]["properties"]["query"]["type"] == "string"
        assert result["function"]["parameters"]["required"] == ["query"]

    def test_parse_tool_call_arguments_repairs_truncated_json(self):
        """Truncated JSON fragments should be repaired into valid tool inputs."""
        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        parsed = provider._parse_tool_call_arguments(
            (
                '{"question":"What story structure should the agent use?",'
                '"options":["3-act structure","Beginning-Middle-End","Random paragraph"'
            ),
            "ask_user",
        )
        assert parsed == {
            "question": "What story structure should the agent use?",
            "options": [
                "3-act structure",
                "Beginning-Middle-End",
                "Random paragraph",
            ],
        }

    def test_parse_tool_call_arguments_raises_when_unrepairable(self):
        """Completely invalid JSON should fail fast instead of producing _raw loops."""
        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        with pytest.raises(ValueError, match="Failed to parse tool call arguments"):
            provider._parse_tool_call_arguments('{"question": foo', "ask_user")


class TestAnthropicProviderBackwardCompatibility:
    """Test AnthropicProvider backward compatibility with LiteLLM backend."""

    def test_anthropic_provider_is_llm_provider(self):
        """Test that AnthropicProvider implements LLMProvider interface."""
        provider = AnthropicProvider(api_key="test-key")
        assert isinstance(provider, LLMProvider)

    def test_anthropic_provider_init_defaults(self):
        """Test AnthropicProvider initialization with defaults."""
        provider = AnthropicProvider(api_key="test-key")
        assert provider.model == "claude-haiku-4-5-20251001"
        assert provider.api_key == "test-key"

    def test_anthropic_provider_init_custom_model(self):
        """Test AnthropicProvider initialization with custom model."""
        provider = AnthropicProvider(api_key="test-key", model="claude-3-haiku-20240307")
        assert provider.model == "claude-3-haiku-20240307"

    def test_anthropic_provider_uses_litellm_internally(self):
        """Test that AnthropicProvider delegates to LiteLLMProvider."""
        provider = AnthropicProvider(api_key="test-key", model="claude-3-haiku-20240307")
        assert isinstance(provider._provider, LiteLLMProvider)
        assert provider._provider.model == "claude-3-haiku-20240307"
        assert provider._provider.api_key == "test-key"

    @patch("litellm.completion")
    def test_anthropic_provider_complete(self, mock_completion):
        """Test AnthropicProvider.complete() delegates to LiteLLM."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Hello from Claude!"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "claude-3-haiku-20240307"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = AnthropicProvider(api_key="test-key", model="claude-3-haiku-20240307")
        result = provider.complete(
            messages=[{"role": "user", "content": "Hello"}],
            system="You are helpful.",
            max_tokens=100,
        )

        assert result.content == "Hello from Claude!"
        assert result.model == "claude-3-haiku-20240307"
        assert result.input_tokens == 10
        assert result.output_tokens == 5

        mock_completion.assert_called_once()
        call_kwargs = mock_completion.call_args[1]
        assert call_kwargs["model"] == "claude-3-haiku-20240307"
        assert call_kwargs["api_key"] == "test-key"

    @patch("litellm.completion")
    def test_anthropic_provider_passes_response_format(self, mock_completion):
        """Test that AnthropicProvider accepts and forwards response_format."""
        # Setup mock
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "{}"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "claude-3-haiku-20240307"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = AnthropicProvider(api_key="test-key")
        fmt = {"type": "json_object"}
        provider.complete(messages=[{"role": "user", "content": "hi"}], response_format=fmt)

        # Verify it was passed to litellm
        call_kwargs = mock_completion.call_args[1]
        assert call_kwargs["response_format"] == fmt


class TestJsonMode:
    """Test json_mode parameter for structured JSON output via prompt engineering."""

    @patch("litellm.completion")
    def test_json_mode_adds_instruction_to_system_prompt(self, mock_completion):
        """Test that json_mode=True adds JSON instruction to system prompt."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = '{"key": "value"}'
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        provider.complete(
            messages=[{"role": "user", "content": "Return JSON"}],
            system="You are helpful.",
            json_mode=True,
        )

        call_kwargs = mock_completion.call_args[1]
        # Should NOT use response_format (prompt engineering instead)
        assert "response_format" not in call_kwargs
        # Should have JSON instruction appended to system message
        messages = call_kwargs["messages"]
        assert messages[0]["role"] == "system"
        assert "You are helpful." in messages[0]["content"]
        assert "Please respond with a valid JSON object" in messages[0]["content"]

    @patch("litellm.completion")
    def test_json_mode_creates_system_prompt_if_none(self, mock_completion):
        """Test that json_mode=True creates system prompt if none provided."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = '{"key": "value"}'
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        provider.complete(messages=[{"role": "user", "content": "Return JSON"}], json_mode=True)

        call_kwargs = mock_completion.call_args[1]
        messages = call_kwargs["messages"]
        # Should insert a system message with JSON instruction
        assert messages[0]["role"] == "system"
        assert "Please respond with a valid JSON object" in messages[0]["content"]

    @patch("litellm.completion")
    def test_json_mode_false_no_instruction(self, mock_completion):
        """Test that json_mode=False does not add JSON instruction."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Hello"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        provider.complete(
            messages=[{"role": "user", "content": "Hello"}],
            system="You are helpful.",
            json_mode=False,
        )

        call_kwargs = mock_completion.call_args[1]
        assert "response_format" not in call_kwargs
        messages = call_kwargs["messages"]
        assert messages[0]["role"] == "system"
        assert "Please respond with a valid JSON object" not in messages[0]["content"]

    @patch("litellm.completion")
    def test_json_mode_default_is_false(self, mock_completion):
        """Test that json_mode defaults to False (no JSON instruction)."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Hello"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        provider.complete(messages=[{"role": "user", "content": "Hello"}], system="You are helpful.")

        call_kwargs = mock_completion.call_args[1]
        assert "response_format" not in call_kwargs
        messages = call_kwargs["messages"]
        # System prompt should be unchanged
        assert messages[0]["content"] == "You are helpful."

    @patch("litellm.completion")
    def test_anthropic_provider_passes_json_mode(self, mock_completion):
        """Test that AnthropicProvider passes json_mode through (prompt engineering)."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = '{"result": "ok"}'
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "claude-haiku-4-5-20251001"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = AnthropicProvider(api_key="test-key")
        provider.complete(
            messages=[{"role": "user", "content": "Return JSON"}],
            system="You are helpful.",
            json_mode=True,
        )

        call_kwargs = mock_completion.call_args[1]
        # Should NOT use response_format
        assert "response_format" not in call_kwargs
        # Should have JSON instruction in system prompt
        messages = call_kwargs["messages"]
        assert messages[0]["role"] == "system"
        assert "Please respond with a valid JSON object" in messages[0]["content"]


class TestComputeRetryDelay:
    """Test _compute_retry_delay() header parsing and fallback logic."""

    def test_fallback_exponential_backoff(self):
        """No exception -> exponential backoff."""
        assert _compute_retry_delay(0) == 2  # 2 * 2^0
        assert _compute_retry_delay(1) == 4  # 2 * 2^1
        assert _compute_retry_delay(2) == 8  # 2 * 2^2
        assert _compute_retry_delay(3) == 16  # 2 * 2^3

    def test_max_delay_cap(self):
        """Backoff should be capped at RATE_LIMIT_MAX_DELAY."""
        # 2 * 2^10 = 2048, should be capped at 120
        assert _compute_retry_delay(10) == 120

    def test_custom_max_delay(self):
        """Custom max_delay should be respected."""
        assert _compute_retry_delay(5, max_delay=10) == 10

    def test_retry_after_ms_header(self):
        """retry-after-ms header should be parsed as milliseconds."""
        exc = _make_exception_with_headers({"retry-after-ms": "5000"})
        assert _compute_retry_delay(0, exception=exc) == 5.0

    def test_retry_after_ms_fractional(self):
        """retry-after-ms should handle fractional values."""
        exc = _make_exception_with_headers({"retry-after-ms": "1500"})
        assert _compute_retry_delay(0, exception=exc) == 1.5

    def test_retry_after_seconds_header(self):
        """retry-after header as seconds should be parsed."""
        exc = _make_exception_with_headers({"retry-after": "3"})
        assert _compute_retry_delay(0, exception=exc) == 3.0

    def test_retry_after_seconds_fractional(self):
        """retry-after header should handle fractional seconds."""
        exc = _make_exception_with_headers({"retry-after": "2.5"})
        assert _compute_retry_delay(0, exception=exc) == 2.5

    def test_retry_after_ms_takes_priority(self):
        """retry-after-ms should take priority over retry-after."""
        exc = _make_exception_with_headers(
            {
                "retry-after-ms": "2000",
                "retry-after": "10",
            }
        )
        assert _compute_retry_delay(0, exception=exc) == 2.0

    def test_retry_after_http_date(self):
        """retry-after as HTTP-date should be parsed."""
        from email.utils import format_datetime

        future = datetime.now(UTC) + timedelta(seconds=5)
        date_str = format_datetime(future, usegmt=True)
        exc = _make_exception_with_headers({"retry-after": date_str})
        delay = _compute_retry_delay(0, exception=exc)
        assert 3.0 <= delay <= 6.0  # within tolerance

    def test_exception_without_response(self):
        """Exception with response=None should fall back to exponential."""
        exc = Exception("test")
        exc.response = None  # type: ignore[attr-defined]
        assert _compute_retry_delay(0, exception=exc) == 2  # exponential fallback

    def test_exception_without_response_attr(self):
        """Exception without .response attr should fall back to exponential."""
        exc = ValueError("no response attr")
        assert _compute_retry_delay(0, exception=exc) == 2

    def test_negative_retry_after_clamped_to_zero(self):
        """Negative retry-after should be clamped to 0."""
        exc = _make_exception_with_headers({"retry-after": "-5"})
        assert _compute_retry_delay(0, exception=exc) == 0

    def test_negative_retry_after_ms_clamped_to_zero(self):
        """Negative retry-after-ms should be clamped to 0."""
        exc = _make_exception_with_headers({"retry-after-ms": "-1000"})
        assert _compute_retry_delay(0, exception=exc) == 0

    def test_invalid_retry_after_falls_back(self):
        """Non-numeric, non-date retry-after should fall back to exponential."""
        exc = _make_exception_with_headers({"retry-after": "not-a-number-or-date"})
        assert _compute_retry_delay(0, exception=exc) == 2  # exponential fallback

    def test_invalid_retry_after_ms_falls_back_to_retry_after(self):
        """Invalid retry-after-ms should fall through to retry-after."""
        exc = _make_exception_with_headers(
            {
                "retry-after-ms": "garbage",
                "retry-after": "7",
            }
        )
        assert _compute_retry_delay(0, exception=exc) == 7.0

    def test_retry_after_capped_at_max_delay(self):
        """Server-provided delay should be capped at max_delay."""
        exc = _make_exception_with_headers({"retry-after": "3600"})
        assert _compute_retry_delay(0, exception=exc) == 120  # capped

    def test_retry_after_ms_capped_at_max_delay(self):
        """Server-provided ms delay should be capped at max_delay."""
        exc = _make_exception_with_headers({"retry-after-ms": "300000"})  # 300s
        assert _compute_retry_delay(0, exception=exc) == 120  # capped


def _make_exception_with_headers(headers: dict[str, str]) -> BaseException:
    """Create a mock exception with response headers for testing."""
    exc = Exception("rate limited")
    response = MagicMock()
    response.headers = headers
    exc.response = response  # type: ignore[attr-defined]
    return exc


# ---------------------------------------------------------------------------
# Async LLM methods — non-blocking event loop tests
# ---------------------------------------------------------------------------
class TestAsyncComplete:
    """Test that acomplete/acomplete_with_tools don't block the event loop."""

    @pytest.mark.asyncio
    @patch("litellm.acompletion")
    async def test_acomplete_uses_acompletion(self, mock_acompletion):
        """acomplete() should call litellm.acompletion (async), not litellm.completion."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "async hello"
        mock_response.choices[0].message.tool_calls = None
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5

        # acompletion is async, so mock must return a coroutine
        async def async_return(*args, **kwargs):
            return mock_response

        mock_acompletion.side_effect = async_return

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        result = await provider.acomplete(
            messages=[{"role": "user", "content": "Hello"}],
            system="You are helpful.",
        )

        assert result.content == "async hello"
        assert result.model == "gpt-4o-mini"
        assert result.input_tokens == 10
        assert result.output_tokens == 5
        mock_acompletion.assert_called_once()

    @pytest.mark.asyncio
    @patch("litellm.acompletion")
    async def test_acomplete_does_not_block_event_loop(self, mock_acompletion):
        """Verify event loop stays responsive during acomplete()."""
        heartbeat_ticks = []

        async def heartbeat():
            start = time.monotonic()
            for _ in range(10):
                heartbeat_ticks.append(time.monotonic() - start)
                await asyncio.sleep(0.05)

        async def slow_acompletion(*args, **kwargs):
            # Simulate a 300ms LLM call — async, so event loop should stay free
            await asyncio.sleep(0.3)
            resp = MagicMock()
            resp.choices = [MagicMock()]
            resp.choices[0].message.content = "done"
            resp.choices[0].message.tool_calls = None
            resp.choices[0].finish_reason = "stop"
            resp.model = "gpt-4o-mini"
            resp.usage.prompt_tokens = 5
            resp.usage.completion_tokens = 3
            return resp

        mock_acompletion.side_effect = slow_acompletion

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        # Run heartbeat + acomplete concurrently
        _, result = await asyncio.gather(
            heartbeat(),
            provider.acomplete(
                messages=[{"role": "user", "content": "hi"}],
            ),
        )

        assert result.content == "done"
        # Heartbeat should have ticked multiple times during the 300ms LLM call
        # (if the event loop were blocked, we'd see 0-1 ticks)
        assert len(heartbeat_ticks) >= 3, f"Event loop was blocked — only {len(heartbeat_ticks)} heartbeat ticks"

    @pytest.mark.asyncio
    async def test_mock_provider_acomplete(self):
        """MockLLMProvider.acomplete() should work without blocking."""
        from framework.llm.mock import MockLLMProvider

        provider = MockLLMProvider()
        result = await provider.acomplete(
            messages=[{"role": "user", "content": "test"}],
            system="Be helpful.",
        )
        assert result.content  # Should have some mock content
        assert result.model == "mock-model"

    @pytest.mark.asyncio
    async def test_base_provider_acomplete_offloads_to_executor(self):
        """Base LLMProvider.acomplete() should offload sync complete() to thread pool."""
        call_thread_ids = []

        class SlowSyncProvider(LLMProvider):
            model: str = "mock"

            def complete(
                self,
                messages,
                system="",
                tools=None,
                max_tokens=1024,
                response_format=None,
                json_mode=False,
                max_retries=None,
            ):
                call_thread_ids.append(threading.current_thread().ident)
                time.sleep(0.1)  # Sync blocking
                return LLMResponse(content="sync done", model="slow")

        provider = SlowSyncProvider()
        main_thread_id = threading.current_thread().ident

        result = await provider.acomplete(
            messages=[{"role": "user", "content": "hi"}],
        )

        assert result.content == "sync done"
        # The sync complete() should have run on a different thread
        assert call_thread_ids[0] != main_thread_id, "Base acomplete() should offload sync complete() to a thread pool"


class TestMiniMaxStreamFallback:
    """MiniMax models should use non-stream fallback due to parser incompatibility."""

    @pytest.mark.asyncio
    async def test_stream_uses_nonstream_fallback_for_minimax(self):
        """stream() should call acomplete() and synthesize stream events for MiniMax."""
        from framework.llm.stream_events import FinishEvent, TextDeltaEvent

        provider = LiteLLMProvider(model="minimax-text-01", api_key="test-key")
        mock_response = LLMResponse(
            content="hello from minimax",
            model="minimax-text-01",
            input_tokens=7,
            output_tokens=4,
            stop_reason="stop",
            raw_response=None,
        )
        provider.acomplete = AsyncMock(return_value=mock_response)

        events = []
        async for event in provider.stream(messages=[{"role": "user", "content": "hi"}]):
            events.append(event)

        assert provider.acomplete.await_count == 1
        assert any(isinstance(e, TextDeltaEvent) for e in events)
        finish = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish) == 1
        assert finish[0].model == "minimax-text-01"

    def test_is_minimax_model_variants(self):
        """Recognize both prefixed and plain MiniMax model names."""
        assert LiteLLMProvider(model="minimax-text-01", api_key="x")._is_minimax_model()
        assert LiteLLMProvider(model="minimax/minimax-text-01", api_key="x")._is_minimax_model()
        assert not LiteLLMProvider(model="gpt-4o-mini", api_key="x")._is_minimax_model()


class TestOpenRouterToolCompatFallback:
    """OpenRouter models should fall back when native tool use is unavailable."""

    def teardown_method(self):
        OPENROUTER_TOOL_COMPAT_MODEL_CACHE.clear()

    @pytest.mark.asyncio
    @patch("litellm.acompletion")
    async def test_stream_falls_back_to_json_tool_emulation(self, mock_acompletion):
        """OpenRouter tool-use 404s should emit synthetic ToolCallEvents instead of errors."""
        from framework.llm.stream_events import FinishEvent, ToolCallEvent

        provider = LiteLLMProvider(
            model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
            api_key="test-key",
        )
        tools = [
            Tool(
                name="web_search",
                description="Search the web",
                parameters={
                    "properties": {
                        "query": {"type": "string"},
                        "num_results": {"type": "integer"},
                    },
                    "required": ["query"],
                },
            )
        ]

        compat_response = MagicMock()
        compat_response.choices = [MagicMock()]
        compat_response.choices[0].message.content = (
            '{"assistant_response":"","tool_calls":['
            '{"name":"web_search","arguments":'
            '{"query":"Python 3.13 release notes","num_results":3}}'
            "]}"
        )
        compat_response.choices[0].finish_reason = "stop"
        compat_response.model = provider.model
        compat_response.usage.prompt_tokens = 18
        compat_response.usage.completion_tokens = 9

        async def side_effect(*args, **kwargs):
            if kwargs.get("stream"):
                raise RuntimeError(
                    'OpenrouterException - {"error":{"message":"No endpoints found '
                    "that support tool use. To learn more about provider routing, "
                    'visit: https://openrouter.ai/docs/guides/routing/provider-selection",'
                    '"code":404}}'
                )
            return compat_response

        mock_acompletion.side_effect = side_effect

        events = []
        async for event in provider.stream(
            messages=[{"role": "user", "content": "Search for the Python 3.13 release notes."}],
            system="Use tools when needed.",
            tools=tools,
            max_tokens=256,
        ):
            events.append(event)

        tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
        assert len(tool_calls) == 1
        assert tool_calls[0].tool_name == "web_search"
        assert tool_calls[0].tool_input == {
            "query": "Python 3.13 release notes",
            "num_results": 3,
        }
        assert tool_calls[0].tool_use_id.startswith("openrouter_compat_")

        finish_events = [event for event in events if isinstance(event, FinishEvent)]
        assert len(finish_events) == 1
        assert finish_events[0].stop_reason == "tool_calls"
        assert finish_events[0].input_tokens == 18
        assert finish_events[0].output_tokens == 9

        assert mock_acompletion.call_count == 2
        first_call = mock_acompletion.call_args_list[0].kwargs
        assert first_call["stream"] is True
        assert "tools" in first_call
        second_call = mock_acompletion.call_args_list[1].kwargs
        assert "tools" not in second_call
        assert "Tool compatibility mode is active" in second_call["messages"][0]["content"]
        assert provider.model in OPENROUTER_TOOL_COMPAT_MODEL_CACHE

    @pytest.mark.asyncio
    @patch("litellm.acompletion")
    async def test_stream_tool_compat_parses_textual_tool_calls_and_uses_cache(
        self,
        mock_acompletion,
    ):
        """Textual tool-call markers should become ToolCallEvents and skip repeat probing."""
        from framework.llm.stream_events import ToolCallEvent

        provider = LiteLLMProvider(
            model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
            api_key="test-key",
        )
        tools = [
            Tool(
                name="choose_one",
                description="Ask the user a multiple-choice question",
                parameters={
                    "properties": {
                        "options": {"type": "array"},
                        "question": {"type": "string"},
                        "prompt": {"type": "string"},
                    },
                    "required": ["options", "question", "prompt"],
                },
            )
        ]

        compat_response = MagicMock()
        compat_response.choices = [MagicMock()]
        compat_response.choices[0].message.content = (
            "<|tool_call_start|>"
            "[choose_one(options=['Quartet Collaborator', 'Project Advisor'], "
            "question='Who are you?', prompt='Who are you?')]"
            "<|tool_call_end|>"
        )
        compat_response.choices[0].finish_reason = "stop"
        compat_response.model = provider.model
        compat_response.usage.prompt_tokens = 10
        compat_response.usage.completion_tokens = 5

        call_state = {"count": 0}

        async def side_effect(*args, **kwargs):
            call_state["count"] += 1
            if kwargs.get("stream"):
                raise RuntimeError(
                    'OpenrouterException - {"error":{"message":"No endpoints found that support tool use.","code":404}}'
                )
            return compat_response

        mock_acompletion.side_effect = side_effect

        first_events = []
        async for event in provider.stream(
            messages=[{"role": "user", "content": "Who are you?"}],
            system="Use tools when needed.",
            tools=tools,
            max_tokens=128,
        ):
            first_events.append(event)

        tool_calls = [event for event in first_events if isinstance(event, ToolCallEvent)]
        assert len(tool_calls) == 1
        assert tool_calls[0].tool_name == "choose_one"
        assert tool_calls[0].tool_input == {
            "options": ["Quartet Collaborator", "Project Advisor"],
            "question": "Who are you?",
            "prompt": "Who are you?",
        }

        second_events = []
        async for event in provider.stream(
            messages=[{"role": "user", "content": "Who are you?"}],
            system="Use tools when needed.",
            tools=tools,
            max_tokens=128,
        ):
            second_events.append(event)

        second_tool_calls = [event for event in second_events if isinstance(event, ToolCallEvent)]
        assert len(second_tool_calls) == 1

        assert mock_acompletion.call_count == 3
        assert mock_acompletion.call_args_list[0].kwargs["stream"] is True
        assert "stream" not in mock_acompletion.call_args_list[1].kwargs
        assert "stream" not in mock_acompletion.call_args_list[2].kwargs

    @pytest.mark.asyncio
    @patch("litellm.acompletion")
    async def test_stream_tool_compat_parses_plain_text_tool_call_lines(
        self,
        mock_acompletion,
    ):
        """Plain textual tool-call lines should execute as tools, not user-visible text."""
        from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent

        provider = LiteLLMProvider(
            model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
            api_key="test-key",
        )
        tools = [
            Tool(
                name="ask_user",
                description="Ask the user a single multiple-choice question",
                parameters={
                    "properties": {
                        "question": {"type": "string"},
                        "options": {"type": "array"},
                    },
                    "required": ["question", "options"],
                },
            )
        ]

        compat_response = MagicMock()
        compat_response.choices = [MagicMock()]
        compat_response.choices[0].message.content = (
            "Queen has been loaded. It's ready to assist with your planning needs.\n\n"
            "ask_user('What would you like to do?', ['Define a new agent', "
            "'Diagnose an existing agent', 'Explore tools'])"
        )
        compat_response.choices[0].finish_reason = "stop"
        compat_response.model = provider.model
        compat_response.usage.prompt_tokens = 11
        compat_response.usage.completion_tokens = 7

        async def side_effect(*args, **kwargs):
            if kwargs.get("stream"):
                raise RuntimeError(
                    'OpenrouterException - {"error":{"message":"No endpoints found that support tool use.","code":404}}'
                )
            return compat_response

        mock_acompletion.side_effect = side_effect

        events = []
        async for event in provider.stream(
            messages=[{"role": "user", "content": "hello"}],
            system="Use tools when needed.",
            tools=tools,
            max_tokens=128,
        ):
            events.append(event)

        tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
        assert len(tool_calls) == 1
        assert tool_calls[0].tool_name == "ask_user"
        assert tool_calls[0].tool_input == {
            "question": "What would you like to do?",
            "options": ["Define a new agent", "Diagnose an existing agent", "Explore tools"],
        }

        text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
        assert len(text_events) == 1
        assert "ask_user(" not in text_events[0].snapshot
        assert text_events[0].snapshot == (
            "Queen has been loaded. It's ready to assist with your planning needs."
        )

        finish_events = [event for event in events if isinstance(event, FinishEvent)]
        assert len(finish_events) == 1
        assert finish_events[0].stop_reason == "tool_calls"

    @pytest.mark.asyncio
    @patch("litellm.acompletion")
    async def test_stream_tool_compat_treats_non_json_as_plain_text(self, mock_acompletion):
        """If fallback output is not valid JSON, preserve it as assistant text."""
        from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent

        provider = LiteLLMProvider(
            model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
            api_key="test-key",
        )
        tools = [
            Tool(
                name="web_search",
                description="Search the web",
                parameters={"properties": {"query": {"type": "string"}}, "required": ["query"]},
            )
        ]

        compat_response = MagicMock()
        compat_response.choices = [MagicMock()]
        compat_response.choices[0].message.content = "I can answer directly without tools."
        compat_response.choices[0].finish_reason = "stop"
        compat_response.model = provider.model
        compat_response.usage.prompt_tokens = 12
        compat_response.usage.completion_tokens = 6

        async def side_effect(*args, **kwargs):
            if kwargs.get("stream"):
                raise RuntimeError(
                    'OpenrouterException - {"error":{"message":"No endpoints found that support tool use.","code":404}}'
                )
            return compat_response

        mock_acompletion.side_effect = side_effect

        events = []
        async for event in provider.stream(
            messages=[{"role": "user", "content": "Say hello."}],
            system="Be concise.",
            tools=tools,
            max_tokens=128,
        ):
            events.append(event)

        text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
        assert len(text_events) == 1
        assert text_events[0].snapshot == "I can answer directly without tools."
        assert not any(isinstance(event, ToolCallEvent) for event in events)

        finish_events = [event for event in events if isinstance(event, FinishEvent)]
        assert len(finish_events) == 1
        assert finish_events[0].stop_reason == "stop"


# ---------------------------------------------------------------------------
# AgentLoader._is_local_model — parameterized tests
# ---------------------------------------------------------------------------
class TestIsLocalModel:
    """Parameterized tests for AgentLoader._is_local_model()."""

    @pytest.mark.parametrize(
        "model",
        [
            "ollama/llama3",
            "ollama/mistral",
            "ollama_chat/llama3",
            "vllm/mistral",
            "lm_studio/phi3",
            "llamacpp/llama-7b",
            "Ollama/Llama3",  # case-insensitive
            "VLLM/Mistral",
        ],
    )
    def test_local_models_return_true(self, model):
        """Local model prefixes should be recognized."""
        from framework.loader.agent_loader import AgentLoader

        assert AgentLoader._is_local_model(model) is True

    @pytest.mark.parametrize(
        "model",
        [
            "anthropic/claude-3-haiku",
            "openai/gpt-4o",
            "gpt-4o-mini",
            "claude-3-haiku-20240307",
            "gemini/gemini-1.5-flash",
            "groq/llama3-70b",
            "mistral/mistral-large",
            "azure/gpt-4",
            "cohere/command-r",
            "together/llama3-70b",
        ],
    )
    def test_cloud_models_return_false(self, model):
        """Cloud model prefixes should not be treated as local."""
        from framework.loader.agent_loader import AgentLoader

        assert AgentLoader._is_local_model(model) is False


# ---------------------------------------------------------------------------
# Ollama helper functions
# ---------------------------------------------------------------------------
class TestIsOllamaModel:
    """Tests for _is_ollama_model()."""

    @pytest.mark.parametrize(
        "model",
        [
            "ollama/llama3",
            "ollama/mistral:7b",
            "ollama_chat/llama3",
            "ollama_chat/qwen2.5:72b",
        ],
    )
    def test_ollama_models_return_true(self, model):
        assert _is_ollama_model(model) is True

    @pytest.mark.parametrize(
        "model",
        [
            "gpt-4o-mini",
            "anthropic/claude-3-haiku",
            "openai/gpt-4o",
            "gemini/gemini-1.5-flash",
            "llama3",
            "",
        ],
    )
    def test_non_ollama_models_return_false(self, model):
        assert _is_ollama_model(model) is False


class TestEnsureOllamaChatPrefix:
    """Tests for _ensure_ollama_chat_prefix()."""

    @pytest.mark.parametrize(
        ("input_model", "expected"),
        [
            ("ollama/llama3", "ollama_chat/llama3"),
            ("ollama/mistral:7b", "ollama_chat/mistral:7b"),
            ("ollama/qwen2.5:72b-instruct", "ollama_chat/qwen2.5:72b-instruct"),
        ],
    )
    def test_rewrites_ollama_to_ollama_chat(self, input_model, expected):
        assert _ensure_ollama_chat_prefix(input_model) == expected

    @pytest.mark.parametrize(
        "model",
        [
            "ollama_chat/llama3",
            "gpt-4o-mini",
            "anthropic/claude-3-haiku",
            "gemini/gemini-1.5-flash",
            "",
        ],
    )
    def test_leaves_non_ollama_prefix_unchanged(self, model):
        assert _ensure_ollama_chat_prefix(model) == model


class TestGetLlmExtraKwargsOllama:
    """Tests for num_ctx injection via get_llm_extra_kwargs() for Ollama."""

    def test_ollama_provider_returns_num_ctx(self):
        """Ollama config should inject num_ctx with default 16384."""
        config = {
            "llm": {"provider": "ollama", "model": "ollama/llama3"},
        }
        with patch("framework.config.get_hive_config", return_value=config):
            result = get_llm_extra_kwargs()
            assert result == {"num_ctx": 16384}

    def test_ollama_provider_respects_custom_num_ctx(self):
        """User-specified num_ctx in config should take precedence."""
        config = {
            "llm": {"provider": "ollama", "model": "ollama/llama3", "num_ctx": 32768},
        }
        with patch("framework.config.get_hive_config", return_value=config):
            result = get_llm_extra_kwargs()
            assert result == {"num_ctx": 32768}

    def test_non_ollama_provider_returns_empty(self):
        """Non-Ollama provider without subscriptions should return empty dict."""
        config = {
            "llm": {"provider": "anthropic", "model": "claude-3-haiku"},
        }
        with patch("framework.config.get_hive_config", return_value=config):
            result = get_llm_extra_kwargs()
            assert result == {}

    def test_empty_config_returns_empty(self):
        """Missing config should return empty dict."""
        with patch("framework.config.get_hive_config", return_value={}):
            result = get_llm_extra_kwargs()
            assert result == {}


class TestModelSupportsCacheControl:
    """`cache_control` allowlist covers native providers AND OpenRouter sub-providers
    whose upstream API honors the marker (Anthropic, Gemini, GLM, MiniMax).

    Auto-cache sub-providers (OpenAI, DeepSeek, Grok, Moonshot, Groq) are
    intentionally excluded: sending cache_control is a no-op and a false win.
    """

    @pytest.mark.parametrize(
        "model",
        [
            "anthropic/claude-opus-4-5",
            "claude-3-5-sonnet-20241022",
            "minimax/minimax-text-01",
            "MiniMax-Text-01",
            "zai-glm-4.6",
            "glm-4.6",
            "openrouter/anthropic/claude-opus-4.5",
            "openrouter/anthropic/claude-sonnet-4.5",
            "openrouter/google/gemini-2.5-pro",
            "openrouter/google/gemini-2.5-flash",
            "openrouter/z-ai/glm-5.1",
            "openrouter/z-ai/glm-4.6",
            "openrouter/minimax/minimax-text-01",
        ],
    )
    def test_supported(self, model):
        assert _model_supports_cache_control(model) is True

    @pytest.mark.parametrize(
        "model",
        [
            "gpt-4o-mini",
            "gemini/gemini-1.5-flash",
            "ollama_chat/llama3",
            "openrouter/openai/gpt-4o",
            "openrouter/deepseek/deepseek-chat",
            "openrouter/x-ai/grok-2",
            "openrouter/moonshotai/kimi-k2",
            "openrouter/liquid/lfm-2.5-1.2b-thinking:free",
        ],
    )
    def test_unsupported(self, model):
        assert _model_supports_cache_control(model) is False


class TestBuildSystemMessageOpenRouter:
    """`_build_system_message` should split static/dynamic blocks whenever
    the model — native OR OpenRouter-routed — supports cache_control."""

    def test_openrouter_anthropic_splits_into_two_blocks(self):
        msg = _build_system_message(
            system="static prefix",
            system_dynamic_suffix="dynamic tail",
            model="openrouter/anthropic/claude-opus-4.5",
        )
        assert msg == {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "static prefix",
                    "cache_control": {"type": "ephemeral"},
                },
                {"type": "text", "text": "dynamic tail"},
            ],
        }

    def test_openrouter_gemini_splits_into_two_blocks(self):
        msg = _build_system_message(
            system="static prefix",
            system_dynamic_suffix="dynamic tail",
            model="openrouter/google/gemini-2.5-pro",
        )
        assert isinstance(msg["content"], list)
        assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
        assert msg["content"][1] == {"type": "text", "text": "dynamic tail"}

    def test_openrouter_glm_splits_into_two_blocks(self):
        msg = _build_system_message(
            system="static prefix",
            system_dynamic_suffix="dynamic tail",
            model="openrouter/z-ai/glm-5.1",
        )
        assert isinstance(msg["content"], list)
        assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}

    def test_openrouter_openai_stays_concatenated(self):
        """OpenAI via OpenRouter auto-caches; sending cache_control is a no-op."""
        msg = _build_system_message(
            system="static prefix",
            system_dynamic_suffix="dynamic tail",
            model="openrouter/openai/gpt-4o",
        )
        assert msg == {
            "role": "system",
            "content": "static prefix\n\ndynamic tail",
        }

    def test_no_suffix_anthropic_gets_top_level_cache_control(self):
        msg = _build_system_message(
            system="static prefix",
            system_dynamic_suffix=None,
            model="openrouter/anthropic/claude-opus-4.5",
        )
        assert msg == {
            "role": "system",
            "content": "static prefix",
            "cache_control": {"type": "ephemeral"},
        }


class TestOpenRouterToolCompatCacheControl:
    """Tool-compat path must pass cache_control through when the routed
    sub-provider honors it. Before this fix, the queen persona + tool-list
    prefix was recomputed every turn on Anthropic/GLM via OpenRouter."""

    def test_tool_compat_messages_split_for_cache_capable_model(self):
        provider = LiteLLMProvider(
            model="openrouter/anthropic/claude-opus-4.5",
            api_key="test-key",
        )
        tools = [
            Tool(
                name="web_search",
                description="Search the web",
                parameters={"properties": {"query": {"type": "string"}}, "required": ["query"]},
            )
        ]

        full_messages = provider._build_openrouter_tool_compat_messages(
            messages=[{"role": "user", "content": "hi"}],
            system="You are a queen.",
            tools=tools,
            system_dynamic_suffix="Current time: 2026-04-23T00:00:00Z",
        )

        system_msg = full_messages[0]
        assert system_msg["role"] == "system"
        assert isinstance(system_msg["content"], list)
        assert len(system_msg["content"]) == 2

        static_block = system_msg["content"][0]
        assert static_block["cache_control"] == {"type": "ephemeral"}
        assert "You are a queen." in static_block["text"]
        assert "Tool compatibility mode is active" in static_block["text"]
        assert "web_search" in static_block["text"]
        assert "2026-04-23" not in static_block["text"]

        dynamic_block = system_msg["content"][1]
        assert "cache_control" not in dynamic_block
        assert dynamic_block["text"] == "Current time: 2026-04-23T00:00:00Z"

    def test_tool_compat_messages_stay_concatenated_for_liquid(self):
        """Liquid (and other non-cache-control OR sub-providers) keep legacy behavior."""
        provider = LiteLLMProvider(
            model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
            api_key="test-key",
        )
        tools = [
            Tool(
                name="web_search",
                description="Search the web",
                parameters={"properties": {"query": {"type": "string"}}, "required": ["query"]},
            )
        ]

        full_messages = provider._build_openrouter_tool_compat_messages(
            messages=[{"role": "user", "content": "hi"}],
            system="You are a queen.",
            tools=tools,
            system_dynamic_suffix="Current time: 2026-04-23T00:00:00Z",
        )

        system_msg = full_messages[0]
        assert isinstance(system_msg["content"], str)
        assert "2026-04-23" in system_msg["content"]
        assert "cache_control" not in system_msg


class TestExtractCacheTokens:
    """`_extract_cache_tokens` reads cache_read + cache_creation from the
    LiteLLM-normalized usage object. Both fields are subsets of
    ``prompt_tokens`` — the helper surfaces them for display; the call sites
    are responsible for never adding them to a total."""
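
    # Call shape exercised throughout: _extract_cache_tokens(usage) returns a
    # (cache_read, cache_creation) tuple; both values are display-only and
    # must never be added back into token totals.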

    def test_none_usage_returns_zero(self):
        assert _extract_cache_tokens(None) == (0, 0)

    def test_openai_shape(self):
        """Pure OpenAI responses expose cached reads via
        ``prompt_tokens_details.cached_tokens`` and have no cache write
        field at all (OpenAI's automatic caching is read-only from the
        client's perspective)."""
        usage = MagicMock(spec=["prompt_tokens_details", "cache_creation_input_tokens"])
        usage.prompt_tokens_details = MagicMock(
            spec=["cached_tokens"], cached_tokens=120,
        )
        usage.cache_creation_input_tokens = 0

        cache_read, cache_creation = _extract_cache_tokens(usage)

        assert cache_read == 120
        assert cache_creation == 0

    def test_openrouter_cache_write_tokens_shape(self):
        """OpenRouter normalizes cache writes into
        ``prompt_tokens_details.cache_write_tokens`` (verified empirically
        against openrouter/anthropic and openrouter/z-ai responses). The
        legacy ``usage.cache_creation_input_tokens`` field is NOT set on
        OpenRouter responses, so this is the path that matters in practice."""
        usage = MagicMock()
        usage.prompt_tokens_details = MagicMock(
            cached_tokens=80, cache_write_tokens=50,
        )
        # Explicitly set the Anthropic-native field to 0 to prove we don't
        # depend on it for OpenRouter responses.
        usage.cache_creation_input_tokens = 0

        cache_read, cache_creation = _extract_cache_tokens(usage)

        assert cache_read == 80
        assert cache_creation == 50

    def test_anthropic_native_cache_creation_field_still_works(self):
        """Direct Anthropic API responses (not via OpenRouter) put cache
        writes on the top-level ``cache_creation_input_tokens`` field. Keep
        the fallback so non-OpenRouter Anthropic continues to work."""
        usage = MagicMock(spec=["prompt_tokens_details", "cache_creation_input_tokens"])
        usage.prompt_tokens_details = MagicMock(
            spec=["cached_tokens"], cached_tokens=80,
        )
        usage.cache_creation_input_tokens = 50

        cache_read, cache_creation = _extract_cache_tokens(usage)

        assert cache_read == 80
        assert cache_creation == 50

    def test_raw_anthropic_shape_falls_back(self):
        """Raw Anthropic usage (no prompt_tokens_details) — fall back to
        cache_read_input_tokens."""
        # The spec'd mock has no prompt_tokens_details attribute at all,
        # forcing the helper down the raw-Anthropic fallback path.
        usage = MagicMock(spec=["cache_read_input_tokens", "cache_creation_input_tokens"])
        usage.cache_read_input_tokens = 200
        usage.cache_creation_input_tokens = 75

        cache_read, cache_creation = _extract_cache_tokens(usage)

        assert cache_read == 200
        assert cache_creation == 75

    def test_no_cache_fields_returns_zero(self):
        """A provider that doesn't report cache tokens at all (e.g. Gemini)
        returns (0, 0) — never raises."""
        usage = MagicMock(spec=["prompt_tokens", "completion_tokens"])
        cache_read, cache_creation = _extract_cache_tokens(usage)
        assert cache_read == 0
        assert cache_creation == 0


class TestStreamingChunksFallbackPreservesCacheFields:
    """Regression: when LiteLLM strips usage from yielded streaming chunks,
    we fall back to ``response.chunks`` to recover token totals. LiteLLM's
    own ``calculate_total_usage()`` aggregates ``prompt_tokens`` /
    ``completion_tokens`` correctly but DROPS ``prompt_tokens_details`` —
    which is where OpenRouter places ``cached_tokens`` and
    ``cache_write_tokens``. The fallback path must walk the raw chunks to
    recover those fields, otherwise streaming OpenRouter calls always
    report zero cache tokens. (Verified empirically against
    openrouter/anthropic/* and openrouter/z-ai/*.)"""

    def test_chunks_with_cache_fields_recovered(self):
        """Simulate the chunks-fallback hot path: build raw chunks where the
        last one carries cache_write_tokens, run the same recovery loop the
        streaming code uses, and assert we surface the cache fields."""
        # Three chunks: text deltas, then a final chunk with usage.
        empty_usage_chunk = MagicMock()
        empty_usage_chunk.usage = None
        last_chunk = MagicMock()
        last_chunk.usage = MagicMock()
        last_chunk.usage.prompt_tokens_details = MagicMock(
            cached_tokens=0, cache_write_tokens=5601,
        )
        last_chunk.usage.cache_creation_input_tokens = 0
        chunks = [empty_usage_chunk, empty_usage_chunk, last_chunk]

        # Mirror the production loop in litellm.py's chunks-fallback.
        cached, creation = 0, 0
        for raw in reversed(chunks):
            usage = getattr(raw, "usage", None)
            if usage is None:
                continue
            cr, cc = _extract_cache_tokens(usage)
            if cr or cc:
                cached, creation = cr, cc
                break

        assert cached == 0
        assert creation == 5601, (
            "chunks-fallback must recover cache_write_tokens from the raw "
            "chunk, not from calculate_total_usage which strips details"
        )

    def test_chunks_with_cache_read_recovered(self):
        """Same path, but for a cache HIT (cached_tokens populated)."""
        last_chunk = MagicMock()
        last_chunk.usage = MagicMock()
        last_chunk.usage.prompt_tokens_details = MagicMock(
            cached_tokens=5601, cache_write_tokens=0,
        )
        last_chunk.usage.cache_creation_input_tokens = 0

        cached, creation = 0, 0
        for raw in reversed([last_chunk]):
            usage = getattr(raw, "usage", None)
            if usage is None:
                continue
            cr, cc = _extract_cache_tokens(usage)
            if cr or cc:
                cached, creation = cr, cc
                break

        assert cached == 5601
        assert creation == 0


class TestExtractCost:
    """`_extract_cost` pulls USD cost from three sources in order:
    usage.cost (OpenRouter native / include_cost_in_streaming_usage) →
    response._hidden_params['response_cost'] (LiteLLM logging) →
    litellm.completion_cost() (pricing-table fallback)."""
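
    # Precedence sketch implied by the tests below: a falsy usage.cost falls
    # through to _hidden_params["response_cost"], which falls through to
    # litellm.completion_cost(); any exception there yields 0.0.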

    def test_none_response_returns_zero(self):
        assert _extract_cost(None, "gpt-4o-mini") == 0.0

    def test_openrouter_usage_cost_is_preferred(self):
        """OpenRouter returns authoritative per-call cost on usage.cost when
        the caller opts in (usage.include=true). That beats LiteLLM's
        pricing-table estimate because it reflects promo pricing and BYOK markup."""
        response = MagicMock()
        response.usage = MagicMock(cost=0.00123)
        response._hidden_params = {"response_cost": 99.99}  # should be ignored
        assert _extract_cost(response, "openrouter/anthropic/claude-opus-4.5") == 0.00123

    def test_hidden_params_response_cost_used_when_no_usage_cost(self):
        """LiteLLM's logging layer attaches response_cost after most
        completions — this is how OpenAI/Anthropic responses get costed
        without going back to the pricing table."""
        response = MagicMock()
        response.usage = MagicMock(spec=[])  # no .cost attribute
        response._hidden_params = {"response_cost": 0.0042}
        assert _extract_cost(response, "gpt-4o-mini") == 0.0042

    def test_falls_back_to_completion_cost_when_nothing_pre_populated(self):
        """For providers where LiteLLM didn't pre-populate cost, call
        litellm.completion_cost() against the pricing table. Mocked here
        because we don't want tests depending on the exact price of
        claude-sonnet-4.5 in LiteLLM's model map."""
        response = MagicMock()
        response.usage = MagicMock(spec=[])
        response._hidden_params = {}
        with patch("litellm.completion_cost", return_value=0.00789):
            assert _extract_cost(response, "anthropic/claude-sonnet-4.5") == 0.00789

    def test_completion_cost_exception_returns_zero(self):
        """Unpriced models (e.g. new OpenRouter routes not yet in LiteLLM's
        catalog) must not crash the hot path."""
        response = MagicMock()
        response.usage = MagicMock(spec=[])
        response._hidden_params = {}
        with patch("litellm.completion_cost", side_effect=Exception("no pricing")):
            assert _extract_cost(response, "openrouter/mystery/model") == 0.0

    def test_zero_cost_falls_through_to_next_source(self):
        """usage.cost == 0 should NOT short-circuit; fall through to
        _hidden_params / completion_cost so we don't cement a false zero."""
        response = MagicMock()
        response.usage = MagicMock(cost=0.0)
        response._hidden_params = {"response_cost": 0.0055}
        assert _extract_cost(response, "gpt-4o-mini") == 0.0055


class TestCostFromTokens:
    """`_cost_from_tokens` is the streaming-path cost helper: stream wrappers
    don't expose the full ModelResponse shape that completion_cost() expects,
    so we go through cost_per_token() with the already-extracted totals."""
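
    # cost_per_token() returns a (prompt_cost, completion_cost) pair; the
    # helper is expected to sum it (0.001 + 0.002 ~= 0.003 below).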

    def test_zero_tokens_returns_zero_without_calling_litellm(self):
        with patch("litellm.cost_per_token") as mock:
            assert _cost_from_tokens("claude-opus-4.5", 0, 0) == 0.0
            mock.assert_not_called()

    def test_empty_model_returns_zero(self):
        assert _cost_from_tokens("", 1000, 500) == 0.0

    def test_computes_from_tokens(self):
        with patch("litellm.cost_per_token", return_value=(0.001, 0.002)) as mock:
            cost = _cost_from_tokens(
                "anthropic/claude-opus-4.5",
                input_tokens=1000,
                output_tokens=500,
                cached_tokens=200,
                cache_creation_tokens=100,
            )
        assert cost == pytest.approx(0.003)

        # Verify the cache-aware kwargs are threaded through — Anthropic
        # needs these to apply the 1.25x write / 0.1x read multipliers.
        call_kwargs = mock.call_args.kwargs
        assert call_kwargs["prompt_tokens"] == 1000
        assert call_kwargs["completion_tokens"] == 500
        assert call_kwargs["cache_read_input_tokens"] == 200
        assert call_kwargs["cache_creation_input_tokens"] == 100

    def test_exception_returns_zero(self):
        with patch("litellm.cost_per_token", side_effect=Exception("unpriced")):
            assert _cost_from_tokens("mystery/model", 1000, 500) == 0.0

    def test_negative_or_none_components_coerce_to_zero(self):
        """LiteLLM returns (None, None) for unknown models in some versions;
        treat as 0 rather than crashing on None+None."""
        with patch("litellm.cost_per_token", return_value=(None, None)):
            assert _cost_from_tokens("some/model", 1, 1) == 0.0


class TestLLMResponseAndFinishEventHaveCostUsd:
    """Regression: both LLMResponse and FinishEvent must carry cost_usd so
    the agent loop → event bus → frontend pipeline doesn't lose cost."""

    def test_llm_response_defaults_cost_to_zero(self):
        from framework.llm.provider import LLMResponse

        r = LLMResponse(content="", model="m")
        assert r.cost_usd == 0.0

    def test_finish_event_defaults_cost_to_zero(self):
        from framework.llm.stream_events import FinishEvent

        e = FinishEvent()
        assert e.cost_usd == 0.0

    def test_finish_event_accepts_cost(self):
        from framework.llm.stream_events import FinishEvent

        e = FinishEvent(cost_usd=0.0123)
        assert e.cost_usd == 0.0123