1513 lines
58 KiB
Python
1513 lines
58 KiB
Python
"""Tests for LiteLLM provider.
|
|
|
|
Run with:
|
|
cd core
|
|
uv pip install litellm pytest
|
|
pytest tests/test_litellm_provider.py -v
|
|
|
|
For live tests (requires API keys):
|
|
OPENAI_API_KEY=sk-... pytest tests/test_litellm_provider.py -v -m live
|
|
"""
|
|
|
|
import asyncio
|
|
import os
|
|
import threading
|
|
import time
|
|
from datetime import UTC, datetime, timedelta
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from framework.config import get_llm_extra_kwargs
|
|
from framework.llm.anthropic import AnthropicProvider
|
|
from framework.llm.litellm import (
|
|
OPENROUTER_TOOL_COMPAT_MODEL_CACHE,
|
|
LiteLLMProvider,
|
|
_compute_retry_delay,
|
|
_ensure_ollama_chat_prefix,
|
|
_is_ollama_model,
|
|
)
|
|
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
|
|
|
|
|
class TestLiteLLMProviderInit:
|
|
"""Test LiteLLMProvider initialization."""
|
|
|
|
def test_init_with_defaults(self):
|
|
"""Test initialization with default parameters."""
|
|
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
|
|
provider = LiteLLMProvider()
|
|
assert provider.model == "gpt-4o-mini"
|
|
assert provider.api_key is None
|
|
assert provider.api_base is None
|
|
|
|
def test_init_with_custom_model(self):
|
|
"""Test initialization with custom model."""
|
|
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
|
provider = LiteLLMProvider(model="claude-3-haiku-20240307")
|
|
assert provider.model == "claude-3-haiku-20240307"
|
|
|
|
def test_init_deepseek_model(self):
|
|
"""Test initialization with DeepSeek model."""
|
|
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
|
|
provider = LiteLLMProvider(model="deepseek/deepseek-chat")
|
|
assert provider.model == "deepseek/deepseek-chat"
|
|
|
|
def test_init_with_api_key(self):
|
|
"""Test initialization with explicit API key."""
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="my-api-key")
|
|
assert provider.api_key == "my-api-key"
|
|
|
|
def test_init_with_api_base(self):
|
|
"""Test initialization with custom API base."""
|
|
provider = LiteLLMProvider(
|
|
model="gpt-4o-mini", api_key="my-key", api_base="https://my-proxy.com/v1"
|
|
)
|
|
assert provider.api_base == "https://my-proxy.com/v1"
|
|
|
|
def test_init_minimax_defaults_api_base(self):
|
|
"""MiniMax should default to the official OpenAI-compatible endpoint."""
|
|
provider = LiteLLMProvider(model="minimax/MiniMax-M2.1", api_key="my-key")
|
|
assert provider.api_base == "https://api.minimax.io/v1"
|
|
|
|
def test_init_minimax_keeps_custom_api_base(self):
|
|
"""Explicit api_base should win over MiniMax defaults."""
|
|
provider = LiteLLMProvider(
|
|
model="minimax/MiniMax-M2.1",
|
|
api_key="my-key",
|
|
api_base="https://proxy.example/v1",
|
|
)
|
|
assert provider.api_base == "https://proxy.example/v1"
|
|
|
|
def test_init_openrouter_defaults_api_base(self):
|
|
"""OpenRouter should default to the official OpenAI-compatible endpoint."""
|
|
provider = LiteLLMProvider(model="openrouter/x-ai/grok-4.20-beta", api_key="my-key")
|
|
assert provider.api_base == "https://openrouter.ai/api/v1"
|
|
|
|
def test_init_openrouter_keeps_custom_api_base(self):
|
|
"""Explicit api_base should win over OpenRouter defaults."""
|
|
provider = LiteLLMProvider(
|
|
model="openrouter/x-ai/grok-4.20-beta",
|
|
api_key="my-key",
|
|
api_base="https://proxy.example/v1",
|
|
)
|
|
assert provider.api_base == "https://proxy.example/v1"
|
|
|
|
def test_init_ollama_no_key_needed(self):
|
|
"""Test that Ollama models don't require API key."""
|
|
with patch.dict(os.environ, {}, clear=True):
|
|
# Should not raise; ollama/ is normalised to ollama_chat/ for tool-call support.
|
|
provider = LiteLLMProvider(model="ollama/llama3")
|
|
assert provider.model == "ollama_chat/llama3"
|
|
|
|
|
|
class TestLiteLLMProviderComplete:
|
|
"""Test LiteLLMProvider.complete() method."""
|
|
|
|
@patch("litellm.completion")
|
|
def test_complete_basic(self, mock_completion):
|
|
"""Test basic completion call."""
|
|
# Mock response
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "Hello! I'm an AI assistant."
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
mock_response.model = "gpt-4o-mini"
|
|
mock_response.usage.prompt_tokens = 10
|
|
mock_response.usage.completion_tokens = 20
|
|
mock_completion.return_value = mock_response
|
|
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
result = provider.complete(messages=[{"role": "user", "content": "Hello"}])
|
|
|
|
assert result.content == "Hello! I'm an AI assistant."
|
|
assert result.model == "gpt-4o-mini"
|
|
assert result.input_tokens == 10
|
|
assert result.output_tokens == 20
|
|
assert result.stop_reason == "stop"
|
|
|
|
# Verify litellm.completion was called correctly
|
|
mock_completion.assert_called_once()
|
|
call_kwargs = mock_completion.call_args[1]
|
|
assert call_kwargs["model"] == "gpt-4o-mini"
|
|
assert call_kwargs["api_key"] == "test-key"
|
|
|
|
@patch("litellm.completion")
|
|
def test_complete_with_system_prompt(self, mock_completion):
|
|
"""Test completion with system prompt."""
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "Response"
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
mock_response.model = "gpt-4o-mini"
|
|
mock_response.usage.prompt_tokens = 15
|
|
mock_response.usage.completion_tokens = 5
|
|
mock_completion.return_value = mock_response
|
|
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
provider.complete(
|
|
messages=[{"role": "user", "content": "Hello"}], system="You are a helpful assistant."
|
|
)
|
|
|
|
call_kwargs = mock_completion.call_args[1]
|
|
messages = call_kwargs["messages"]
|
|
assert messages[0]["role"] == "system"
|
|
assert messages[0]["content"] == "You are a helpful assistant."
|
|
|
|
@patch("litellm.completion")
|
|
def test_complete_with_tools(self, mock_completion):
|
|
"""Test completion with tools."""
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "Response"
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
mock_response.model = "gpt-4o-mini"
|
|
mock_response.usage.prompt_tokens = 20
|
|
mock_response.usage.completion_tokens = 10
|
|
mock_completion.return_value = mock_response
|
|
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
|
|
tools = [
|
|
Tool(
|
|
name="get_weather",
|
|
description="Get the weather for a location",
|
|
parameters={
|
|
"properties": {"location": {"type": "string", "description": "City name"}},
|
|
"required": ["location"],
|
|
},
|
|
)
|
|
]
|
|
|
|
provider.complete(
|
|
messages=[{"role": "user", "content": "What's the weather?"}], tools=tools
|
|
)
|
|
|
|
call_kwargs = mock_completion.call_args[1]
|
|
assert "tools" in call_kwargs
|
|
assert call_kwargs["tools"][0]["type"] == "function"
|
|
assert call_kwargs["tools"][0]["function"]["name"] == "get_weather"
|
|
|
|
|
|
class TestToolConversion:
|
|
"""Test tool format conversion."""
|
|
|
|
def test_tool_to_openai_format(self):
|
|
"""Test converting Tool to OpenAI format."""
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
|
|
tool = Tool(
|
|
name="search",
|
|
description="Search the web",
|
|
parameters={
|
|
"properties": {"query": {"type": "string", "description": "Search query"}},
|
|
"required": ["query"],
|
|
},
|
|
)
|
|
|
|
result = provider._tool_to_openai_format(tool)
|
|
|
|
assert result["type"] == "function"
|
|
assert result["function"]["name"] == "search"
|
|
assert result["function"]["description"] == "Search the web"
|
|
assert result["function"]["parameters"]["properties"]["query"]["type"] == "string"
|
|
assert result["function"]["parameters"]["required"] == ["query"]
|
|
|
|
def test_parse_tool_call_arguments_repairs_truncated_json(self):
|
|
"""Truncated JSON fragments should be repaired into valid tool inputs."""
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
|
|
parsed = provider._parse_tool_call_arguments(
|
|
(
|
|
'{"question":"What story structure should the agent use?",'
|
|
'"options":["3-act structure","Beginning-Middle-End","Random paragraph"'
|
|
),
|
|
"ask_user",
|
|
)
|
|
|
|
assert parsed == {
|
|
"question": "What story structure should the agent use?",
|
|
"options": [
|
|
"3-act structure",
|
|
"Beginning-Middle-End",
|
|
"Random paragraph",
|
|
],
|
|
}
|
|
|
|
def test_parse_tool_call_arguments_raises_when_unrepairable(self):
|
|
"""Completely invalid JSON should fail fast instead of producing _raw loops."""
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
|
|
with pytest.raises(ValueError, match="Failed to parse tool call arguments"):
|
|
provider._parse_tool_call_arguments('{"question": foo', "ask_user")
|
|
|
|
def test_parse_tool_call_arguments_recovers_pythonish_payloads(self):
|
|
"""Single-quoted and trailing-comma argument payloads should be recovered."""
|
|
provider = LiteLLMProvider(
|
|
model="openai/gpt-5.4",
|
|
api_key="test-key",
|
|
api_base="https://chatgpt.com/backend-api/codex",
|
|
)
|
|
|
|
parsed = provider._parse_tool_call_arguments(
|
|
"{'question': 'Continue?', 'options': ['Yes', 'No'],}",
|
|
"ask_user",
|
|
)
|
|
|
|
assert parsed == {
|
|
"question": "Continue?",
|
|
"options": ["Yes", "No"],
|
|
}
|
|
|
|
def test_parse_tool_call_arguments_keeps_null_inside_strings(self):
|
|
"""Literal normalization should not mutate quoted text values."""
|
|
provider = LiteLLMProvider(
|
|
model="openai/gpt-5.4",
|
|
api_key="test-key",
|
|
api_base="https://chatgpt.com/backend-api/codex",
|
|
)
|
|
|
|
parsed = provider._parse_tool_call_arguments(
|
|
"{'hypothesis': 'null hypothesis', 'approved': false}",
|
|
"summarize",
|
|
)
|
|
|
|
assert parsed == {
|
|
"hypothesis": "null hypothesis",
|
|
"approved": False,
|
|
}
|
|
|
|
def test_parse_tool_call_arguments_strips_json_code_fences(self):
|
|
"""Fence stripping should remove the language tag before JSON parsing."""
|
|
provider = LiteLLMProvider(
|
|
model="openai/gpt-5.4",
|
|
api_key="test-key",
|
|
api_base="https://chatgpt.com/backend-api/codex",
|
|
)
|
|
|
|
parsed = provider._parse_tool_call_arguments(
|
|
'```json\n{"question":"Continue?","options":["Yes","No"]}\n```',
|
|
"ask_user",
|
|
)
|
|
|
|
assert parsed == {
|
|
"question": "Continue?",
|
|
"options": ["Yes", "No"],
|
|
}
|
|
|
|
|
|
class TestAnthropicProviderBackwardCompatibility:
|
|
"""Test AnthropicProvider backward compatibility with LiteLLM backend."""
|
|
|
|
def test_anthropic_provider_is_llm_provider(self):
|
|
"""Test that AnthropicProvider implements LLMProvider interface."""
|
|
provider = AnthropicProvider(api_key="test-key")
|
|
assert isinstance(provider, LLMProvider)
|
|
|
|
def test_anthropic_provider_init_defaults(self):
|
|
"""Test AnthropicProvider initialization with defaults."""
|
|
provider = AnthropicProvider(api_key="test-key")
|
|
assert provider.model == "claude-haiku-4-5-20251001"
|
|
assert provider.api_key == "test-key"
|
|
|
|
def test_anthropic_provider_init_custom_model(self):
|
|
"""Test AnthropicProvider initialization with custom model."""
|
|
provider = AnthropicProvider(api_key="test-key", model="claude-3-haiku-20240307")
|
|
assert provider.model == "claude-3-haiku-20240307"
|
|
|
|
def test_anthropic_provider_uses_litellm_internally(self):
|
|
"""Test that AnthropicProvider delegates to LiteLLMProvider."""
|
|
provider = AnthropicProvider(api_key="test-key", model="claude-3-haiku-20240307")
|
|
assert isinstance(provider._provider, LiteLLMProvider)
|
|
assert provider._provider.model == "claude-3-haiku-20240307"
|
|
assert provider._provider.api_key == "test-key"
|
|
|
|
@patch("litellm.completion")
|
|
def test_anthropic_provider_complete(self, mock_completion):
|
|
"""Test AnthropicProvider.complete() delegates to LiteLLM."""
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "Hello from Claude!"
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
mock_response.model = "claude-3-haiku-20240307"
|
|
mock_response.usage.prompt_tokens = 10
|
|
mock_response.usage.completion_tokens = 5
|
|
mock_completion.return_value = mock_response
|
|
|
|
provider = AnthropicProvider(api_key="test-key", model="claude-3-haiku-20240307")
|
|
result = provider.complete(
|
|
messages=[{"role": "user", "content": "Hello"}],
|
|
system="You are helpful.",
|
|
max_tokens=100,
|
|
)
|
|
|
|
assert result.content == "Hello from Claude!"
|
|
assert result.model == "claude-3-haiku-20240307"
|
|
assert result.input_tokens == 10
|
|
assert result.output_tokens == 5
|
|
|
|
mock_completion.assert_called_once()
|
|
call_kwargs = mock_completion.call_args[1]
|
|
assert call_kwargs["model"] == "claude-3-haiku-20240307"
|
|
assert call_kwargs["api_key"] == "test-key"
|
|
|
|
@patch("litellm.completion")
|
|
def test_anthropic_provider_passes_response_format(self, mock_completion):
|
|
"""Test that AnthropicProvider accepts and forwards response_format."""
|
|
# Setup mock
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "{}"
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
mock_response.model = "claude-3-haiku-20240307"
|
|
mock_response.usage.prompt_tokens = 10
|
|
mock_response.usage.completion_tokens = 5
|
|
mock_completion.return_value = mock_response
|
|
|
|
provider = AnthropicProvider(api_key="test-key")
|
|
fmt = {"type": "json_object"}
|
|
|
|
provider.complete(messages=[{"role": "user", "content": "hi"}], response_format=fmt)
|
|
|
|
# Verify it was passed to litellm
|
|
call_kwargs = mock_completion.call_args[1]
|
|
assert call_kwargs["response_format"] == fmt
|
|
|
|
|
|
class TestJsonMode:
|
|
"""Test json_mode parameter for structured JSON output via prompt engineering."""
|
|
|
|
@patch("litellm.completion")
|
|
def test_json_mode_adds_instruction_to_system_prompt(self, mock_completion):
|
|
"""Test that json_mode=True adds JSON instruction to system prompt."""
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = '{"key": "value"}'
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
mock_response.model = "gpt-4o-mini"
|
|
mock_response.usage.prompt_tokens = 10
|
|
mock_response.usage.completion_tokens = 5
|
|
mock_completion.return_value = mock_response
|
|
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
provider.complete(
|
|
messages=[{"role": "user", "content": "Return JSON"}],
|
|
system="You are helpful.",
|
|
json_mode=True,
|
|
)
|
|
|
|
call_kwargs = mock_completion.call_args[1]
|
|
# Should NOT use response_format (prompt engineering instead)
|
|
assert "response_format" not in call_kwargs
|
|
# Should have JSON instruction appended to system message
|
|
messages = call_kwargs["messages"]
|
|
assert messages[0]["role"] == "system"
|
|
assert "You are helpful." in messages[0]["content"]
|
|
assert "Please respond with a valid JSON object" in messages[0]["content"]
|
|
|
|
@patch("litellm.completion")
|
|
def test_json_mode_creates_system_prompt_if_none(self, mock_completion):
|
|
"""Test that json_mode=True creates system prompt if none provided."""
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = '{"key": "value"}'
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
mock_response.model = "gpt-4o-mini"
|
|
mock_response.usage.prompt_tokens = 10
|
|
mock_response.usage.completion_tokens = 5
|
|
mock_completion.return_value = mock_response
|
|
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
provider.complete(messages=[{"role": "user", "content": "Return JSON"}], json_mode=True)
|
|
|
|
call_kwargs = mock_completion.call_args[1]
|
|
messages = call_kwargs["messages"]
|
|
# Should insert a system message with JSON instruction
|
|
assert messages[0]["role"] == "system"
|
|
assert "Please respond with a valid JSON object" in messages[0]["content"]
|
|
|
|
@patch("litellm.completion")
|
|
def test_json_mode_false_no_instruction(self, mock_completion):
|
|
"""Test that json_mode=False does not add JSON instruction."""
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "Hello"
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
mock_response.model = "gpt-4o-mini"
|
|
mock_response.usage.prompt_tokens = 10
|
|
mock_response.usage.completion_tokens = 5
|
|
mock_completion.return_value = mock_response
|
|
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
provider.complete(
|
|
messages=[{"role": "user", "content": "Hello"}],
|
|
system="You are helpful.",
|
|
json_mode=False,
|
|
)
|
|
|
|
call_kwargs = mock_completion.call_args[1]
|
|
assert "response_format" not in call_kwargs
|
|
messages = call_kwargs["messages"]
|
|
assert messages[0]["role"] == "system"
|
|
assert "Please respond with a valid JSON object" not in messages[0]["content"]
|
|
|
|
@patch("litellm.completion")
|
|
def test_json_mode_default_is_false(self, mock_completion):
|
|
"""Test that json_mode defaults to False (no JSON instruction)."""
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "Hello"
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
mock_response.model = "gpt-4o-mini"
|
|
mock_response.usage.prompt_tokens = 10
|
|
mock_response.usage.completion_tokens = 5
|
|
mock_completion.return_value = mock_response
|
|
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
provider.complete(
|
|
messages=[{"role": "user", "content": "Hello"}], system="You are helpful."
|
|
)
|
|
|
|
call_kwargs = mock_completion.call_args[1]
|
|
assert "response_format" not in call_kwargs
|
|
messages = call_kwargs["messages"]
|
|
# System prompt should be unchanged
|
|
assert messages[0]["content"] == "You are helpful."
|
|
|
|
@patch("litellm.completion")
|
|
def test_anthropic_provider_passes_json_mode(self, mock_completion):
|
|
"""Test that AnthropicProvider passes json_mode through (prompt engineering)."""
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = '{"result": "ok"}'
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
mock_response.model = "claude-haiku-4-5-20251001"
|
|
mock_response.usage.prompt_tokens = 10
|
|
mock_response.usage.completion_tokens = 5
|
|
mock_completion.return_value = mock_response
|
|
|
|
provider = AnthropicProvider(api_key="test-key")
|
|
provider.complete(
|
|
messages=[{"role": "user", "content": "Return JSON"}],
|
|
system="You are helpful.",
|
|
json_mode=True,
|
|
)
|
|
|
|
call_kwargs = mock_completion.call_args[1]
|
|
# Should NOT use response_format
|
|
assert "response_format" not in call_kwargs
|
|
# Should have JSON instruction in system prompt
|
|
messages = call_kwargs["messages"]
|
|
assert messages[0]["role"] == "system"
|
|
assert "Please respond with a valid JSON object" in messages[0]["content"]
|
|
|
|
|
|
class TestComputeRetryDelay:
|
|
"""Test _compute_retry_delay() header parsing and fallback logic."""
|
|
|
|
def test_fallback_exponential_backoff(self):
|
|
"""No exception -> exponential backoff."""
|
|
assert _compute_retry_delay(0) == 2 # 2 * 2^0
|
|
assert _compute_retry_delay(1) == 4 # 2 * 2^1
|
|
assert _compute_retry_delay(2) == 8 # 2 * 2^2
|
|
assert _compute_retry_delay(3) == 16 # 2 * 2^3
|
|
|
|
def test_max_delay_cap(self):
|
|
"""Backoff should be capped at RATE_LIMIT_MAX_DELAY."""
|
|
# 2 * 2^10 = 2048, should be capped at 120
|
|
assert _compute_retry_delay(10) == 120
|
|
|
|
def test_custom_max_delay(self):
|
|
"""Custom max_delay should be respected."""
|
|
assert _compute_retry_delay(5, max_delay=10) == 10
|
|
|
|
def test_retry_after_ms_header(self):
|
|
"""retry-after-ms header should be parsed as milliseconds."""
|
|
exc = _make_exception_with_headers({"retry-after-ms": "5000"})
|
|
assert _compute_retry_delay(0, exception=exc) == 5.0
|
|
|
|
def test_retry_after_ms_fractional(self):
|
|
"""retry-after-ms should handle fractional values."""
|
|
exc = _make_exception_with_headers({"retry-after-ms": "1500"})
|
|
assert _compute_retry_delay(0, exception=exc) == 1.5
|
|
|
|
def test_retry_after_seconds_header(self):
|
|
"""retry-after header as seconds should be parsed."""
|
|
exc = _make_exception_with_headers({"retry-after": "3"})
|
|
assert _compute_retry_delay(0, exception=exc) == 3.0
|
|
|
|
def test_retry_after_seconds_fractional(self):
|
|
"""retry-after header should handle fractional seconds."""
|
|
exc = _make_exception_with_headers({"retry-after": "2.5"})
|
|
assert _compute_retry_delay(0, exception=exc) == 2.5
|
|
|
|
def test_retry_after_ms_takes_priority(self):
|
|
"""retry-after-ms should take priority over retry-after."""
|
|
exc = _make_exception_with_headers(
|
|
{
|
|
"retry-after-ms": "2000",
|
|
"retry-after": "10",
|
|
}
|
|
)
|
|
assert _compute_retry_delay(0, exception=exc) == 2.0
|
|
|
|
def test_retry_after_http_date(self):
|
|
"""retry-after as HTTP-date should be parsed."""
|
|
from email.utils import format_datetime
|
|
|
|
future = datetime.now(UTC) + timedelta(seconds=5)
|
|
date_str = format_datetime(future, usegmt=True)
|
|
exc = _make_exception_with_headers({"retry-after": date_str})
|
|
delay = _compute_retry_delay(0, exception=exc)
|
|
assert 3.0 <= delay <= 6.0 # within tolerance
|
|
|
|
def test_exception_without_response(self):
|
|
"""Exception with response=None should fall back to exponential."""
|
|
exc = Exception("test")
|
|
exc.response = None # type: ignore[attr-defined]
|
|
assert _compute_retry_delay(0, exception=exc) == 2 # exponential fallback
|
|
|
|
def test_exception_without_response_attr(self):
|
|
"""Exception without .response attr should fall back to exponential."""
|
|
exc = ValueError("no response attr")
|
|
assert _compute_retry_delay(0, exception=exc) == 2
|
|
|
|
def test_negative_retry_after_clamped_to_zero(self):
|
|
"""Negative retry-after should be clamped to 0."""
|
|
exc = _make_exception_with_headers({"retry-after": "-5"})
|
|
assert _compute_retry_delay(0, exception=exc) == 0
|
|
|
|
def test_negative_retry_after_ms_clamped_to_zero(self):
|
|
"""Negative retry-after-ms should be clamped to 0."""
|
|
exc = _make_exception_with_headers({"retry-after-ms": "-1000"})
|
|
assert _compute_retry_delay(0, exception=exc) == 0
|
|
|
|
def test_invalid_retry_after_falls_back(self):
|
|
"""Non-numeric, non-date retry-after should fall back to exponential."""
|
|
exc = _make_exception_with_headers({"retry-after": "not-a-number-or-date"})
|
|
assert _compute_retry_delay(0, exception=exc) == 2 # exponential fallback
|
|
|
|
def test_invalid_retry_after_ms_falls_back_to_retry_after(self):
|
|
"""Invalid retry-after-ms should fall through to retry-after."""
|
|
exc = _make_exception_with_headers(
|
|
{
|
|
"retry-after-ms": "garbage",
|
|
"retry-after": "7",
|
|
}
|
|
)
|
|
assert _compute_retry_delay(0, exception=exc) == 7.0
|
|
|
|
def test_retry_after_capped_at_max_delay(self):
|
|
"""Server-provided delay should be capped at max_delay."""
|
|
exc = _make_exception_with_headers({"retry-after": "3600"})
|
|
assert _compute_retry_delay(0, exception=exc) == 120 # capped
|
|
|
|
def test_retry_after_ms_capped_at_max_delay(self):
|
|
"""Server-provided ms delay should be capped at max_delay."""
|
|
exc = _make_exception_with_headers({"retry-after-ms": "300000"}) # 300s
|
|
assert _compute_retry_delay(0, exception=exc) == 120 # capped
|
|
|
|
|
|
def _make_exception_with_headers(headers: dict[str, str]) -> BaseException:
|
|
"""Create a mock exception with response headers for testing."""
|
|
exc = Exception("rate limited")
|
|
response = MagicMock()
|
|
response.headers = headers
|
|
exc.response = response # type: ignore[attr-defined]
|
|
return exc
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Async LLM methods — non-blocking event loop tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAsyncComplete:
|
|
"""Test that acomplete/acomplete_with_tools don't block the event loop."""
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("litellm.acompletion")
|
|
async def test_acomplete_uses_acompletion(self, mock_acompletion):
|
|
"""acomplete() should call litellm.acompletion (async), not litellm.completion."""
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "async hello"
|
|
mock_response.choices[0].message.tool_calls = None
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
mock_response.model = "gpt-4o-mini"
|
|
mock_response.usage.prompt_tokens = 10
|
|
mock_response.usage.completion_tokens = 5
|
|
|
|
# acompletion is async, so mock must return a coroutine
|
|
async def async_return(*args, **kwargs):
|
|
return mock_response
|
|
|
|
mock_acompletion.side_effect = async_return
|
|
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
result = await provider.acomplete(
|
|
messages=[{"role": "user", "content": "Hello"}],
|
|
system="You are helpful.",
|
|
)
|
|
|
|
assert result.content == "async hello"
|
|
assert result.model == "gpt-4o-mini"
|
|
assert result.input_tokens == 10
|
|
assert result.output_tokens == 5
|
|
mock_acompletion.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("litellm.acompletion")
|
|
async def test_acomplete_does_not_block_event_loop(self, mock_acompletion):
|
|
"""Verify event loop stays responsive during acomplete()."""
|
|
heartbeat_ticks = []
|
|
|
|
async def heartbeat():
|
|
start = time.monotonic()
|
|
for _ in range(10):
|
|
heartbeat_ticks.append(time.monotonic() - start)
|
|
await asyncio.sleep(0.05)
|
|
|
|
async def slow_acompletion(*args, **kwargs):
|
|
# Simulate a 300ms LLM call — async, so event loop should stay free
|
|
await asyncio.sleep(0.3)
|
|
resp = MagicMock()
|
|
resp.choices = [MagicMock()]
|
|
resp.choices[0].message.content = "done"
|
|
resp.choices[0].message.tool_calls = None
|
|
resp.choices[0].finish_reason = "stop"
|
|
resp.model = "gpt-4o-mini"
|
|
resp.usage.prompt_tokens = 5
|
|
resp.usage.completion_tokens = 3
|
|
return resp
|
|
|
|
mock_acompletion.side_effect = slow_acompletion
|
|
|
|
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
|
|
|
# Run heartbeat + acomplete concurrently
|
|
_, result = await asyncio.gather(
|
|
heartbeat(),
|
|
provider.acomplete(
|
|
messages=[{"role": "user", "content": "hi"}],
|
|
),
|
|
)
|
|
|
|
assert result.content == "done"
|
|
# Heartbeat should have ticked multiple times during the 300ms LLM call
|
|
# (if the event loop were blocked, we'd see 0-1 ticks)
|
|
assert len(heartbeat_ticks) >= 3, (
|
|
f"Event loop was blocked — only {len(heartbeat_ticks)} heartbeat ticks"
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mock_provider_acomplete(self):
|
|
"""MockLLMProvider.acomplete() should work without blocking."""
|
|
from framework.llm.mock import MockLLMProvider
|
|
|
|
provider = MockLLMProvider()
|
|
result = await provider.acomplete(
|
|
messages=[{"role": "user", "content": "test"}],
|
|
system="Be helpful.",
|
|
)
|
|
|
|
assert result.content # Should have some mock content
|
|
assert result.model == "mock-model"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_base_provider_acomplete_offloads_to_executor(self):
|
|
"""Base LLMProvider.acomplete() should offload sync complete() to thread pool."""
|
|
call_thread_ids = []
|
|
|
|
class SlowSyncProvider(LLMProvider):
|
|
def complete(
|
|
self,
|
|
messages,
|
|
system="",
|
|
tools=None,
|
|
max_tokens=1024,
|
|
response_format=None,
|
|
json_mode=False,
|
|
max_retries=None,
|
|
):
|
|
call_thread_ids.append(threading.current_thread().ident)
|
|
time.sleep(0.1) # Sync blocking
|
|
return LLMResponse(content="sync done", model="slow")
|
|
|
|
provider = SlowSyncProvider()
|
|
main_thread_id = threading.current_thread().ident
|
|
|
|
result = await provider.acomplete(
|
|
messages=[{"role": "user", "content": "hi"}],
|
|
)
|
|
|
|
assert result.content == "sync done"
|
|
# The sync complete() should have run on a different thread
|
|
assert call_thread_ids[0] != main_thread_id, (
|
|
"Base acomplete() should offload sync complete() to a thread pool"
|
|
)
|
|
|
|
|
|
class TestMiniMaxStreamFallback:
|
|
"""MiniMax models should use non-stream fallback due to parser incompatibility."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_stream_uses_nonstream_fallback_for_minimax(self):
|
|
"""stream() should call acomplete() and synthesize stream events for MiniMax."""
|
|
from framework.llm.stream_events import FinishEvent, TextDeltaEvent
|
|
|
|
provider = LiteLLMProvider(model="minimax-text-01", api_key="test-key")
|
|
|
|
mock_response = LLMResponse(
|
|
content="hello from minimax",
|
|
model="minimax-text-01",
|
|
input_tokens=7,
|
|
output_tokens=4,
|
|
stop_reason="stop",
|
|
raw_response=None,
|
|
)
|
|
provider.acomplete = AsyncMock(return_value=mock_response)
|
|
|
|
events = []
|
|
async for event in provider.stream(messages=[{"role": "user", "content": "hi"}]):
|
|
events.append(event)
|
|
|
|
assert provider.acomplete.await_count == 1
|
|
assert any(isinstance(e, TextDeltaEvent) for e in events)
|
|
finish = [e for e in events if isinstance(e, FinishEvent)]
|
|
assert len(finish) == 1
|
|
assert finish[0].model == "minimax-text-01"
|
|
|
|
def test_is_minimax_model_variants(self):
|
|
"""Recognize both prefixed and plain MiniMax model names."""
|
|
assert LiteLLMProvider(model="minimax-text-01", api_key="x")._is_minimax_model()
|
|
assert LiteLLMProvider(model="minimax/minimax-text-01", api_key="x")._is_minimax_model()
|
|
assert not LiteLLMProvider(model="gpt-4o-mini", api_key="x")._is_minimax_model()
|
|
|
|
|
|
class TestCodexEmptyStreamRecovery:
|
|
"""Codex empty streams should fall back before surfacing ghost-stream retries."""
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("litellm.acompletion")
|
|
async def test_stream_recovers_empty_codex_stream_via_nonstream_completion(
|
|
self,
|
|
mock_acompletion,
|
|
):
|
|
"""An empty Codex stream should be salvaged with a non-stream completion."""
|
|
from framework.llm.stream_events import FinishEvent, TextDeltaEvent
|
|
|
|
provider = LiteLLMProvider(
|
|
model="openai/gpt-5.4",
|
|
api_key="test-key",
|
|
api_base="https://chatgpt.com/backend-api/codex",
|
|
)
|
|
|
|
class EmptyStreamResponse:
|
|
chunks: list = []
|
|
|
|
def __aiter__(self):
|
|
return self
|
|
|
|
async def __anext__(self):
|
|
raise StopAsyncIteration
|
|
|
|
recovered = MagicMock()
|
|
recovered.choices = [MagicMock()]
|
|
recovered.choices[0].message.content = "Recovered via fallback"
|
|
recovered.choices[0].message.tool_calls = []
|
|
recovered.choices[0].finish_reason = "stop"
|
|
recovered.model = provider.model
|
|
recovered.usage.prompt_tokens = 12
|
|
recovered.usage.completion_tokens = 4
|
|
|
|
async def side_effect(*args, **kwargs):
|
|
if kwargs.get("stream"):
|
|
return EmptyStreamResponse()
|
|
return recovered
|
|
|
|
mock_acompletion.side_effect = side_effect
|
|
|
|
events = []
|
|
async for event in provider.stream(messages=[{"role": "user", "content": "hi"}]):
|
|
events.append(event)
|
|
|
|
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
|
|
assert len(text_events) == 1
|
|
assert text_events[0].snapshot == "Recovered via fallback"
|
|
|
|
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
|
assert len(finish_events) == 1
|
|
assert finish_events[0].stop_reason == "stop"
|
|
assert finish_events[0].input_tokens == 12
|
|
assert finish_events[0].output_tokens == 4
|
|
|
|
assert mock_acompletion.call_count == 2
|
|
assert mock_acompletion.call_args_list[0].kwargs["stream"] is True
|
|
assert "stream" not in mock_acompletion.call_args_list[1].kwargs
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("litellm.acompletion")
|
|
async def test_stream_recovers_empty_codex_stream_with_tool_calls(
|
|
self,
|
|
mock_acompletion,
|
|
):
|
|
"""Non-stream fallback should preserve tool calls, not just text."""
|
|
from framework.llm.stream_events import FinishEvent, ToolCallEvent
|
|
|
|
provider = LiteLLMProvider(
|
|
model="openai/gpt-5.4",
|
|
api_key="test-key",
|
|
api_base="https://chatgpt.com/backend-api/codex/responses",
|
|
)
|
|
|
|
class EmptyStreamResponse:
|
|
chunks: list = []
|
|
|
|
def __aiter__(self):
|
|
return self
|
|
|
|
async def __anext__(self):
|
|
raise StopAsyncIteration
|
|
|
|
tc = MagicMock()
|
|
tc.id = "tool_1"
|
|
tc.function.name = "ask_user"
|
|
tc.function.arguments = '{"question":"Continue?","options":["Yes","No"]}'
|
|
|
|
recovered = MagicMock()
|
|
recovered.choices = [MagicMock()]
|
|
recovered.choices[0].message.content = ""
|
|
recovered.choices[0].message.tool_calls = [tc]
|
|
recovered.choices[0].finish_reason = "tool_calls"
|
|
recovered.model = provider.model
|
|
recovered.usage.prompt_tokens = 14
|
|
recovered.usage.completion_tokens = 5
|
|
|
|
async def side_effect(*args, **kwargs):
|
|
if kwargs.get("stream"):
|
|
return EmptyStreamResponse()
|
|
return recovered
|
|
|
|
mock_acompletion.side_effect = side_effect
|
|
|
|
events = []
|
|
async for event in provider.stream(
|
|
messages=[{"role": "user", "content": "Should we continue?"}],
|
|
tools=[
|
|
Tool(
|
|
name="ask_user",
|
|
description="Ask the user",
|
|
parameters={"properties": {"question": {"type": "string"}}},
|
|
)
|
|
],
|
|
):
|
|
events.append(event)
|
|
|
|
tool_events = [event for event in events if isinstance(event, ToolCallEvent)]
|
|
assert len(tool_events) == 1
|
|
assert tool_events[0].tool_name == "ask_user"
|
|
assert tool_events[0].tool_input == {
|
|
"question": "Continue?",
|
|
"options": ["Yes", "No"],
|
|
}
|
|
|
|
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
|
assert len(finish_events) == 1
|
|
assert finish_events[0].stop_reason == "tool_calls"
|
|
|
|
|
|
class TestCodexRequestHardening:
|
|
def test_codex_backend_forces_responses_mode_per_request(self):
|
|
"""Codex backend should scope the Responses override to the active request."""
|
|
import litellm
|
|
|
|
original = litellm.model_cost.get("gpt-5.4")
|
|
litellm.model_cost["gpt-5.4"] = {"mode": "chat"}
|
|
|
|
try:
|
|
provider = LiteLLMProvider(
|
|
model="openai/gpt-5.4",
|
|
api_key="test-key",
|
|
api_base="https://chatgpt.com/backend-api/codex",
|
|
)
|
|
assert provider._codex_backend is True
|
|
|
|
async def _fake_acompletion(**kwargs):
|
|
assert kwargs["model"] == "openai/gpt-5.4"
|
|
assert litellm.model_cost["gpt-5.4"]["mode"] == "responses"
|
|
response = MagicMock()
|
|
response.choices = [MagicMock()]
|
|
response.choices[0].message.content = "ok"
|
|
response.choices[0].message.tool_calls = []
|
|
response.choices[0].finish_reason = "stop"
|
|
return response
|
|
|
|
with patch("litellm.acompletion", new=_fake_acompletion):
|
|
response = asyncio.run(
|
|
provider._acompletion_with_rate_limit_retry(
|
|
model=provider.model,
|
|
messages=[{"role": "user", "content": "hi"}],
|
|
)
|
|
)
|
|
|
|
assert response.choices[0].message.content == "ok"
|
|
assert litellm.model_cost["gpt-5.4"]["mode"] == "chat"
|
|
finally:
|
|
if original is None:
|
|
litellm.model_cost.pop("gpt-5.4", None)
|
|
else:
|
|
litellm.model_cost["gpt-5.4"] = original
|
|
|
|
def test_codex_build_completion_kwargs_splits_prompt_and_forces_tool_choice(self):
|
|
"""Codex requests should chunk large system prompts and require tools when needed."""
|
|
provider = LiteLLMProvider(
|
|
model="openai/gpt-5.4",
|
|
api_key="test-key",
|
|
api_base="https://chatgpt.com/backend-api/codex/responses",
|
|
store=True,
|
|
)
|
|
kwargs = provider._build_completion_kwargs(
|
|
messages=[{"role": "user", "content": "hi"}],
|
|
system="# Identity\n" + ("rule\n" * 2000),
|
|
tools=[
|
|
Tool(
|
|
name="ask_user",
|
|
description="Ask the user",
|
|
parameters={"properties": {"question": {"type": "string"}}},
|
|
)
|
|
],
|
|
max_tokens=256,
|
|
response_format=None,
|
|
json_mode=False,
|
|
stream=True,
|
|
)
|
|
|
|
system_messages = [m for m in kwargs["messages"] if m["role"] == "system"]
|
|
assert len(system_messages) >= 2
|
|
assert system_messages[0]["content"].startswith("# Codex Execution Contract")
|
|
assert kwargs["tool_choice"] == "required"
|
|
assert kwargs["store"] is False
|
|
assert "max_tokens" not in kwargs
|
|
assert "stream_options" not in kwargs
|
|
assert kwargs["api_base"] == "https://chatgpt.com/backend-api/codex"
|
|
assert "store" in kwargs["allowed_openai_params"]
|
|
|
|
def test_codex_build_system_messages_handles_whitespace_only_prompt(self):
|
|
"""Whitespace-only system prompts should fall back to the default prompt."""
|
|
provider = LiteLLMProvider(
|
|
model="openai/gpt-5.4",
|
|
api_key="test-key",
|
|
api_base="https://chatgpt.com/backend-api/codex",
|
|
)
|
|
|
|
messages = provider._codex_adapter.build_system_messages("\n \n", json_mode=False)
|
|
|
|
assert messages == [{"role": "system", "content": "You are a helpful assistant."}]
|
|
|
|
def test_codex_merge_tool_call_chunk_handles_parallel_calls_with_broken_indexes(self):
|
|
"""Codex chunk merging should survive index=0 for multiple parallel tool calls."""
|
|
from types import SimpleNamespace
|
|
|
|
provider = LiteLLMProvider(
|
|
model="openai/gpt-5.4",
|
|
api_key="test-key",
|
|
api_base="https://chatgpt.com/backend-api/codex",
|
|
)
|
|
acc: dict[int, dict[str, str]] = {}
|
|
last_idx = 0
|
|
|
|
chunks = [
|
|
SimpleNamespace(
|
|
id="tool_1",
|
|
index=0,
|
|
function=SimpleNamespace(name="web_search", arguments='{"query":"alpha'),
|
|
),
|
|
SimpleNamespace(
|
|
id="tool_2",
|
|
index=0,
|
|
function=SimpleNamespace(name="read_file", arguments='{"path":"beta'),
|
|
),
|
|
SimpleNamespace(
|
|
id=None,
|
|
index=0,
|
|
function=SimpleNamespace(name=None, arguments='"}'),
|
|
),
|
|
SimpleNamespace(
|
|
id=None,
|
|
index=0,
|
|
function=SimpleNamespace(name=None, arguments='"}'),
|
|
),
|
|
]
|
|
|
|
for chunk in chunks:
|
|
last_idx = provider._merge_tool_call_chunk(acc, chunk, last_idx)
|
|
|
|
assert len(acc) == 2
|
|
parsed = [
|
|
provider._parse_tool_call_arguments(slot["arguments"], slot["name"])
|
|
for _, slot in sorted(acc.items())
|
|
]
|
|
assert parsed == [
|
|
{"query": "alpha"},
|
|
{"path": "beta"},
|
|
]
|
|
|
|
|
|
class TestOpenRouterToolCompatFallback:
|
|
"""OpenRouter models should fall back when native tool use is unavailable."""
|
|
|
|
def teardown_method(self):
|
|
OPENROUTER_TOOL_COMPAT_MODEL_CACHE.clear()
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("litellm.acompletion")
|
|
async def test_stream_falls_back_to_json_tool_emulation(self, mock_acompletion):
|
|
"""OpenRouter tool-use 404s should emit synthetic ToolCallEvents instead of errors."""
|
|
from framework.llm.stream_events import FinishEvent, ToolCallEvent
|
|
|
|
provider = LiteLLMProvider(
|
|
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
|
api_key="test-key",
|
|
)
|
|
tools = [
|
|
Tool(
|
|
name="web_search",
|
|
description="Search the web",
|
|
parameters={
|
|
"properties": {
|
|
"query": {"type": "string"},
|
|
"num_results": {"type": "integer"},
|
|
},
|
|
"required": ["query"],
|
|
},
|
|
)
|
|
]
|
|
|
|
compat_response = MagicMock()
|
|
compat_response.choices = [MagicMock()]
|
|
compat_response.choices[0].message.content = (
|
|
'{"assistant_response":"","tool_calls":['
|
|
'{"name":"web_search","arguments":'
|
|
'{"query":"Python 3.13 release notes","num_results":3}}'
|
|
"]}"
|
|
)
|
|
compat_response.choices[0].finish_reason = "stop"
|
|
compat_response.model = provider.model
|
|
compat_response.usage.prompt_tokens = 18
|
|
compat_response.usage.completion_tokens = 9
|
|
|
|
async def side_effect(*args, **kwargs):
|
|
if kwargs.get("stream"):
|
|
raise RuntimeError(
|
|
'OpenrouterException - {"error":{"message":"No endpoints found '
|
|
"that support tool use. To learn more about provider routing, "
|
|
'visit: https://openrouter.ai/docs/guides/routing/provider-selection",'
|
|
'"code":404}}'
|
|
)
|
|
return compat_response
|
|
|
|
mock_acompletion.side_effect = side_effect
|
|
|
|
events = []
|
|
async for event in provider.stream(
|
|
messages=[{"role": "user", "content": "Search for the Python 3.13 release notes."}],
|
|
system="Use tools when needed.",
|
|
tools=tools,
|
|
max_tokens=256,
|
|
):
|
|
events.append(event)
|
|
|
|
tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
|
|
assert len(tool_calls) == 1
|
|
assert tool_calls[0].tool_name == "web_search"
|
|
assert tool_calls[0].tool_input == {
|
|
"query": "Python 3.13 release notes",
|
|
"num_results": 3,
|
|
}
|
|
assert tool_calls[0].tool_use_id.startswith("openrouter_compat_")
|
|
|
|
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
|
assert len(finish_events) == 1
|
|
assert finish_events[0].stop_reason == "tool_calls"
|
|
assert finish_events[0].input_tokens == 18
|
|
assert finish_events[0].output_tokens == 9
|
|
|
|
assert mock_acompletion.call_count == 2
|
|
first_call = mock_acompletion.call_args_list[0].kwargs
|
|
assert first_call["stream"] is True
|
|
assert "tools" in first_call
|
|
|
|
second_call = mock_acompletion.call_args_list[1].kwargs
|
|
assert "tools" not in second_call
|
|
assert "Tool compatibility mode is active" in second_call["messages"][0]["content"]
|
|
assert provider.model in OPENROUTER_TOOL_COMPAT_MODEL_CACHE
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("litellm.acompletion")
|
|
async def test_stream_tool_compat_parses_textual_tool_calls_and_uses_cache(
|
|
self,
|
|
mock_acompletion,
|
|
):
|
|
"""Textual tool-call markers should become ToolCallEvents and skip repeat probing."""
|
|
from framework.llm.stream_events import ToolCallEvent
|
|
|
|
provider = LiteLLMProvider(
|
|
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
|
api_key="test-key",
|
|
)
|
|
tools = [
|
|
Tool(
|
|
name="ask_user_multiple",
|
|
description="Ask the user a multiple-choice question",
|
|
parameters={
|
|
"properties": {
|
|
"options": {"type": "array"},
|
|
"question": {"type": "string"},
|
|
"prompt": {"type": "string"},
|
|
},
|
|
"required": ["options", "question", "prompt"],
|
|
},
|
|
)
|
|
]
|
|
|
|
compat_response = MagicMock()
|
|
compat_response.choices = [MagicMock()]
|
|
compat_response.choices[0].message.content = (
|
|
"<|tool_call_start|>"
|
|
"[ask_user_multiple(options=['Quartet Collaborator', 'Project Advisor'], "
|
|
"question='Who are you?', prompt='Who are you?')]"
|
|
"<|tool_call_end|>"
|
|
)
|
|
compat_response.choices[0].finish_reason = "stop"
|
|
compat_response.model = provider.model
|
|
compat_response.usage.prompt_tokens = 10
|
|
compat_response.usage.completion_tokens = 5
|
|
|
|
call_state = {"count": 0}
|
|
|
|
async def side_effect(*args, **kwargs):
|
|
call_state["count"] += 1
|
|
if kwargs.get("stream"):
|
|
raise RuntimeError(
|
|
'OpenrouterException - {"error":{"message":"No endpoints found '
|
|
'that support tool use.","code":404}}'
|
|
)
|
|
return compat_response
|
|
|
|
mock_acompletion.side_effect = side_effect
|
|
|
|
first_events = []
|
|
async for event in provider.stream(
|
|
messages=[{"role": "user", "content": "Who are you?"}],
|
|
system="Use tools when needed.",
|
|
tools=tools,
|
|
max_tokens=128,
|
|
):
|
|
first_events.append(event)
|
|
|
|
tool_calls = [event for event in first_events if isinstance(event, ToolCallEvent)]
|
|
assert len(tool_calls) == 1
|
|
assert tool_calls[0].tool_name == "ask_user_multiple"
|
|
assert tool_calls[0].tool_input == {
|
|
"options": ["Quartet Collaborator", "Project Advisor"],
|
|
"question": "Who are you?",
|
|
"prompt": "Who are you?",
|
|
}
|
|
|
|
second_events = []
|
|
async for event in provider.stream(
|
|
messages=[{"role": "user", "content": "Who are you?"}],
|
|
system="Use tools when needed.",
|
|
tools=tools,
|
|
max_tokens=128,
|
|
):
|
|
second_events.append(event)
|
|
|
|
second_tool_calls = [event for event in second_events if isinstance(event, ToolCallEvent)]
|
|
assert len(second_tool_calls) == 1
|
|
assert mock_acompletion.call_count == 3
|
|
assert mock_acompletion.call_args_list[0].kwargs["stream"] is True
|
|
assert "stream" not in mock_acompletion.call_args_list[1].kwargs
|
|
assert "stream" not in mock_acompletion.call_args_list[2].kwargs
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("litellm.acompletion")
|
|
async def test_stream_tool_compat_parses_plain_text_tool_call_lines(
|
|
self,
|
|
mock_acompletion,
|
|
):
|
|
"""Plain textual tool-call lines should execute as tools, not user-visible text."""
|
|
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
|
|
|
|
provider = LiteLLMProvider(
|
|
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
|
api_key="test-key",
|
|
)
|
|
tools = [
|
|
Tool(
|
|
name="ask_user",
|
|
description="Ask the user a single multiple-choice question",
|
|
parameters={
|
|
"properties": {
|
|
"question": {"type": "string"},
|
|
"options": {"type": "array"},
|
|
},
|
|
"required": ["question", "options"],
|
|
},
|
|
)
|
|
]
|
|
|
|
compat_response = MagicMock()
|
|
compat_response.choices = [MagicMock()]
|
|
compat_response.choices[0].message.content = (
|
|
"Queen has been loaded. It's ready to assist with your planning needs.\n\n"
|
|
"ask_user('What would you like to do?', ['Define a new agent', "
|
|
"'Diagnose an existing agent', 'Explore tools'])"
|
|
)
|
|
compat_response.choices[0].finish_reason = "stop"
|
|
compat_response.model = provider.model
|
|
compat_response.usage.prompt_tokens = 11
|
|
compat_response.usage.completion_tokens = 7
|
|
|
|
async def side_effect(*args, **kwargs):
|
|
if kwargs.get("stream"):
|
|
raise RuntimeError(
|
|
'OpenrouterException - {"error":{"message":"No endpoints found '
|
|
'that support tool use.","code":404}}'
|
|
)
|
|
return compat_response
|
|
|
|
mock_acompletion.side_effect = side_effect
|
|
|
|
events = []
|
|
async for event in provider.stream(
|
|
messages=[{"role": "user", "content": "hello"}],
|
|
system="Use tools when needed.",
|
|
tools=tools,
|
|
max_tokens=128,
|
|
):
|
|
events.append(event)
|
|
|
|
tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
|
|
assert len(tool_calls) == 1
|
|
assert tool_calls[0].tool_name == "ask_user"
|
|
assert tool_calls[0].tool_input == {
|
|
"question": "What would you like to do?",
|
|
"options": ["Define a new agent", "Diagnose an existing agent", "Explore tools"],
|
|
}
|
|
|
|
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
|
|
assert len(text_events) == 1
|
|
assert "ask_user(" not in text_events[0].snapshot
|
|
assert text_events[0].snapshot == (
|
|
"Queen has been loaded. It's ready to assist with your planning needs."
|
|
)
|
|
|
|
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
|
assert len(finish_events) == 1
|
|
assert finish_events[0].stop_reason == "tool_calls"
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("litellm.acompletion")
|
|
async def test_stream_tool_compat_treats_non_json_as_plain_text(self, mock_acompletion):
|
|
"""If fallback output is not valid JSON, preserve it as assistant text."""
|
|
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
|
|
|
|
provider = LiteLLMProvider(
|
|
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
|
api_key="test-key",
|
|
)
|
|
tools = [
|
|
Tool(
|
|
name="web_search",
|
|
description="Search the web",
|
|
parameters={"properties": {"query": {"type": "string"}}, "required": ["query"]},
|
|
)
|
|
]
|
|
|
|
compat_response = MagicMock()
|
|
compat_response.choices = [MagicMock()]
|
|
compat_response.choices[0].message.content = "I can answer directly without tools."
|
|
compat_response.choices[0].finish_reason = "stop"
|
|
compat_response.model = provider.model
|
|
compat_response.usage.prompt_tokens = 12
|
|
compat_response.usage.completion_tokens = 6
|
|
|
|
async def side_effect(*args, **kwargs):
|
|
if kwargs.get("stream"):
|
|
raise RuntimeError(
|
|
'OpenrouterException - {"error":{"message":"No endpoints found '
|
|
'that support tool use.","code":404}}'
|
|
)
|
|
return compat_response
|
|
|
|
mock_acompletion.side_effect = side_effect
|
|
|
|
events = []
|
|
async for event in provider.stream(
|
|
messages=[{"role": "user", "content": "Say hello."}],
|
|
system="Be concise.",
|
|
tools=tools,
|
|
max_tokens=128,
|
|
):
|
|
events.append(event)
|
|
|
|
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
|
|
assert len(text_events) == 1
|
|
assert text_events[0].snapshot == "I can answer directly without tools."
|
|
assert not any(isinstance(event, ToolCallEvent) for event in events)
|
|
|
|
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
|
assert len(finish_events) == 1
|
|
assert finish_events[0].stop_reason == "stop"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# AgentRunner._is_local_model — parameterized tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestIsLocalModel:
|
|
"""Parameterized tests for AgentRunner._is_local_model()."""
|
|
|
|
@pytest.mark.parametrize(
|
|
"model",
|
|
[
|
|
"ollama/llama3",
|
|
"ollama/mistral",
|
|
"ollama_chat/llama3",
|
|
"vllm/mistral",
|
|
"lm_studio/phi3",
|
|
"llamacpp/llama-7b",
|
|
"Ollama/Llama3", # case-insensitive
|
|
"VLLM/Mistral",
|
|
],
|
|
)
|
|
def test_local_models_return_true(self, model):
|
|
"""Local model prefixes should be recognized."""
|
|
from framework.runner.runner import AgentRunner
|
|
|
|
assert AgentRunner._is_local_model(model) is True
|
|
|
|
@pytest.mark.parametrize(
|
|
"model",
|
|
[
|
|
"anthropic/claude-3-haiku",
|
|
"openai/gpt-4o",
|
|
"gpt-4o-mini",
|
|
"claude-3-haiku-20240307",
|
|
"gemini/gemini-1.5-flash",
|
|
"groq/llama3-70b",
|
|
"mistral/mistral-large",
|
|
"azure/gpt-4",
|
|
"cohere/command-r",
|
|
"together/llama3-70b",
|
|
],
|
|
)
|
|
def test_cloud_models_return_false(self, model):
|
|
"""Cloud model prefixes should not be treated as local."""
|
|
from framework.runner.runner import AgentRunner
|
|
|
|
assert AgentRunner._is_local_model(model) is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Ollama helper functions
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestIsOllamaModel:
|
|
"""Tests for _is_ollama_model()."""
|
|
|
|
@pytest.mark.parametrize(
|
|
"model",
|
|
[
|
|
"ollama/llama3",
|
|
"ollama/mistral:7b",
|
|
"ollama_chat/llama3",
|
|
"ollama_chat/qwen2.5:72b",
|
|
],
|
|
)
|
|
def test_ollama_models_return_true(self, model):
|
|
assert _is_ollama_model(model) is True
|
|
|
|
@pytest.mark.parametrize(
|
|
"model",
|
|
[
|
|
"gpt-4o-mini",
|
|
"anthropic/claude-3-haiku",
|
|
"openai/gpt-4o",
|
|
"gemini/gemini-1.5-flash",
|
|
"llama3",
|
|
"",
|
|
],
|
|
)
|
|
def test_non_ollama_models_return_false(self, model):
|
|
assert _is_ollama_model(model) is False
|
|
|
|
|
|
class TestEnsureOllamaChatPrefix:
|
|
"""Tests for _ensure_ollama_chat_prefix()."""
|
|
|
|
@pytest.mark.parametrize(
|
|
("input_model", "expected"),
|
|
[
|
|
("ollama/llama3", "ollama_chat/llama3"),
|
|
("ollama/mistral:7b", "ollama_chat/mistral:7b"),
|
|
("ollama/qwen2.5:72b-instruct", "ollama_chat/qwen2.5:72b-instruct"),
|
|
],
|
|
)
|
|
def test_rewrites_ollama_to_ollama_chat(self, input_model, expected):
|
|
assert _ensure_ollama_chat_prefix(input_model) == expected
|
|
|
|
@pytest.mark.parametrize(
|
|
"model",
|
|
[
|
|
"ollama_chat/llama3",
|
|
"gpt-4o-mini",
|
|
"anthropic/claude-3-haiku",
|
|
"gemini/gemini-1.5-flash",
|
|
"",
|
|
],
|
|
)
|
|
def test_leaves_non_ollama_prefix_unchanged(self, model):
|
|
assert _ensure_ollama_chat_prefix(model) == model
|
|
|
|
|
|
class TestGetLlmExtraKwargsOllama:
|
|
"""Tests for num_ctx injection via get_llm_extra_kwargs() for Ollama."""
|
|
|
|
def test_ollama_provider_returns_num_ctx(self):
|
|
"""Ollama config should inject num_ctx with default 16384."""
|
|
config = {
|
|
"llm": {"provider": "ollama", "model": "ollama/llama3"},
|
|
}
|
|
with patch("framework.config.get_hive_config", return_value=config):
|
|
result = get_llm_extra_kwargs()
|
|
assert result == {"num_ctx": 16384}
|
|
|
|
def test_ollama_provider_respects_custom_num_ctx(self):
|
|
"""User-specified num_ctx in config should take precedence."""
|
|
config = {
|
|
"llm": {"provider": "ollama", "model": "ollama/llama3", "num_ctx": 32768},
|
|
}
|
|
with patch("framework.config.get_hive_config", return_value=config):
|
|
result = get_llm_extra_kwargs()
|
|
assert result == {"num_ctx": 32768}
|
|
|
|
def test_non_ollama_provider_returns_empty(self):
|
|
"""Non-Ollama provider without subscriptions should return empty dict."""
|
|
config = {
|
|
"llm": {"provider": "anthropic", "model": "claude-3-haiku"},
|
|
}
|
|
with patch("framework.config.get_hive_config", return_value=config):
|
|
result = get_llm_extra_kwargs()
|
|
assert result == {}
|
|
|
|
def test_empty_config_returns_empty(self):
|
|
"""Missing config should return empty dict."""
|
|
with patch("framework.config.get_hive_config", return_value={}):
|
|
result = get_llm_extra_kwargs()
|
|
assert result == {}
|