core: introduce litellm provider
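
LiteLLM exposes an OpenAI-compatible completion API across many providers, so a single provider class can drive OpenAI, Anthropic, Gemini, Mistral, Groq, and local Ollama models. Switching providers reduces to a model string; a minimal sketch of the intended usage (see the class docstring below for details):

    provider = LiteLLMProvider(model="gpt-4o-mini")    # OpenAI
    provider = LiteLLMProvider(model="ollama/llama3")  # local Ollama, no API key needed
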
@@ -2,5 +2,6 @@
 from framework.llm.provider import LLMProvider, LLMResponse
 from framework.llm.anthropic import AnthropicProvider
+from framework.llm.litellm import LiteLLMProvider
 
-__all__ = ["LLMProvider", "LLMResponse", "AnthropicProvider"]
+__all__ = ["LLMProvider", "LLMResponse", "AnthropicProvider", "LiteLLMProvider"]
 
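
With the export in place, callers can import the provider from the package root:

    from framework.llm import LiteLLMProvider
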
@@ -0,0 +1,248 @@
"""LiteLLM provider for pluggable multi-provider LLM support.

LiteLLM provides a unified, OpenAI-compatible interface that supports
multiple LLM providers, including OpenAI, Anthropic, Gemini, Mistral,
Groq, and local models.

See: https://docs.litellm.ai/docs/providers
"""

import json
from collections.abc import Callable
from typing import Any

import litellm

from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolUse, ToolResult


class LiteLLMProvider(LLMProvider):
    """
    LiteLLM-based LLM provider for multi-provider support.

    Supports any model that LiteLLM supports, including:
    - OpenAI: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo
    - Anthropic: claude-3-opus, claude-3-sonnet, claude-3-haiku
    - Google: gemini-pro, gemini-1.5-pro, gemini-1.5-flash
    - Mistral: mistral-large, mistral-medium, mistral-small
    - Groq: llama3-70b, mixtral-8x7b
    - Local: ollama/llama3, ollama/mistral
    - And many more...

    Usage:
        # OpenAI
        provider = LiteLLMProvider(model="gpt-4o-mini")

        # Anthropic
        provider = LiteLLMProvider(model="claude-3-haiku-20240307")

        # Google Gemini
        provider = LiteLLMProvider(model="gemini/gemini-1.5-flash")

        # Local Ollama
        provider = LiteLLMProvider(model="ollama/llama3")

        # With custom API base
        provider = LiteLLMProvider(
            model="gpt-4o-mini",
            api_base="https://my-proxy.com/v1",
        )
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: str | None = None,
        api_base: str | None = None,
        **kwargs: Any,
    ):
        """
        Initialize the LiteLLM provider.

        Args:
            model: Model identifier (e.g., "gpt-4o-mini", "claude-3-haiku-20240307").
                LiteLLM auto-detects the provider from the model name.
            api_key: API key for the provider. If not provided, LiteLLM will
                look for the appropriate env var (OPENAI_API_KEY,
                ANTHROPIC_API_KEY, etc.).
            api_base: Custom API base URL (for proxies or local deployments).
            **kwargs: Additional arguments passed to litellm.completion().
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base
        self.extra_kwargs = kwargs

    def complete(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 1024,
    ) -> LLMResponse:
        """Generate a completion using LiteLLM."""
        # Prepare messages with system prompt
        full_messages = []
        if system:
            full_messages.append({"role": "system", "content": system})
        full_messages.extend(messages)

        # Build kwargs
        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": full_messages,
            "max_tokens": max_tokens,
            **self.extra_kwargs,
        }

        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.api_base:
            kwargs["api_base"] = self.api_base

        # Add tools if provided
        if tools:
            kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]

        # Make the call
        response = litellm.completion(**kwargs)

        # Extract content
        content = response.choices[0].message.content or ""

        # Get usage info
        usage = response.usage
        input_tokens = usage.prompt_tokens if usage else 0
        output_tokens = usage.completion_tokens if usage else 0

        return LLMResponse(
            content=content,
            model=response.model or self.model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            stop_reason=response.choices[0].finish_reason or "",
            raw_response=response,
        )

    def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable[[ToolUse], ToolResult],
        max_iterations: int = 10,
    ) -> LLMResponse:
        """Run a tool-use loop until the LLM produces a final response."""
        # Prepare messages with system prompt
        current_messages = []
        if system:
            current_messages.append({"role": "system", "content": system})
        current_messages.extend(messages)

        total_input_tokens = 0
        total_output_tokens = 0

        # Convert tools to OpenAI format
        openai_tools = [self._tool_to_openai_format(t) for t in tools]

        for _ in range(max_iterations):
            # Build kwargs
            kwargs: dict[str, Any] = {
                "model": self.model,
                "messages": current_messages,
                "max_tokens": 1024,
                "tools": openai_tools,
                **self.extra_kwargs,
            }

            if self.api_key:
                kwargs["api_key"] = self.api_key
            if self.api_base:
                kwargs["api_base"] = self.api_base

            response = litellm.completion(**kwargs)

            # Track tokens
            usage = response.usage
            if usage:
                total_input_tokens += usage.prompt_tokens
                total_output_tokens += usage.completion_tokens

            choice = response.choices[0]
            message = choice.message

            # Check if we're done (no tool calls)
            if choice.finish_reason == "stop" or not message.tool_calls:
                return LLMResponse(
                    content=message.content or "",
                    model=response.model or self.model,
                    input_tokens=total_input_tokens,
                    output_tokens=total_output_tokens,
                    stop_reason=choice.finish_reason or "stop",
                    raw_response=response,
                )

            # Process tool calls.
            # Add assistant message with tool calls.
            current_messages.append({
                "role": "assistant",
                "content": message.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in message.tool_calls
                ],
            })

            # Execute tools and add results.
            for tool_call in message.tool_calls:
                # Arguments arrive as a JSON string; fall back to {} if malformed
                try:
                    args = json.loads(tool_call.function.arguments)
                except json.JSONDecodeError:
                    args = {}

                tool_use = ToolUse(
                    id=tool_call.id,
                    name=tool_call.function.name,
                    input=args,
                )

                result = tool_executor(tool_use)

                # Add tool result message
                current_messages.append({
                    "role": "tool",
                    "tool_call_id": result.tool_use_id,
                    "content": result.content,
                })

        # Max iterations reached
        return LLMResponse(
            content="Max tool iterations reached",
            model=self.model,
            input_tokens=total_input_tokens,
            output_tokens=total_output_tokens,
            stop_reason="max_iterations",
            raw_response=None,
        )

    def _tool_to_openai_format(self, tool: Tool) -> dict[str, Any]:
        """Convert Tool to OpenAI function calling format."""
        return {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": tool.parameters.get("properties", {}),
                    "required": tool.parameters.get("required", []),
                },
            },
        }
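
For reference, a minimal driver for the tool loop above, mirroring the pattern used in the tests below; the `echo` tool and its executor are illustrative, not part of the framework:

    from framework.llm.litellm import LiteLLMProvider
    from framework.llm.provider import Tool, ToolResult, ToolUse

    provider = LiteLLMProvider(model="gpt-4o-mini")

    echo = Tool(
        name="echo",
        description="Echo the input text back",
        parameters={"properties": {"text": {"type": "string"}}, "required": ["text"]},
    )

    def executor(tool_use: ToolUse) -> ToolResult:
        # A real executor would dispatch on tool_use.name
        return ToolResult(tool_use_id=tool_use.id, content=tool_use.input["text"], is_error=False)

    result = provider.complete_with_tools(
        messages=[{"role": "user", "content": "Echo 'hi' using the tool."}],
        system="Use the echo tool to answer.",
        tools=[echo],
        tool_executor=executor,
    )
    print(result.content, result.stop_reason)
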
@@ -12,6 +12,7 @@ from framework.graph.edge import GraphSpec, EdgeSpec, EdgeCondition
 from framework.graph.node import NodeSpec
 from framework.graph.executor import GraphExecutor, ExecutionResult
 from framework.llm.provider import LLMProvider, Tool, ToolResult, ToolUse
+from framework.llm.litellm import LiteLLMProvider
 from framework.runner.tool_registry import ToolRegistry
 from framework.runtime.core import Runtime
 
@@ -182,7 +183,8 @@ class AgentRunner:
             goal: Loaded Goal object
             mock_mode: If True, use mock LLM responses
             storage_path: Path for runtime storage (defaults to temp)
-            model: Anthropic model to use
+            model: Model to use - any LiteLLM-compatible model name
+                (e.g., "claude-sonnet-4-20250514", "gpt-4o-mini", "gemini/gemini-pro")
         """
         self.agent_path = agent_path
         self.graph = graph
 
@@ -366,11 +368,11 @@ class AgentRunner:
         # Create runtime
         self._runtime = Runtime(storage_path=self._storage_path)
 
-        # Create LLM provider (if not mock mode and API key available)
-        if not self.mock_mode and os.environ.get("ANTHROPIC_API_KEY"):
-            from framework.llm.anthropic import AnthropicProvider
-
-            self._llm = AnthropicProvider(model=self.model)
+        # Create LLM provider (if not mock mode)
+        # Use LiteLLM as the unified backend for all providers
+        if not self.mock_mode:
+            # LiteLLM auto-detects the provider from the model name and finds the right API key
+            self._llm = LiteLLMProvider(model=self.model)
 
         # Create executor
         self._executor = GraphExecutor(
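
With this change, swapping providers at the runner level is just the model argument plus the matching environment variable; a sketch (other AgentRunner arguments elided):

    # OPENAI_API_KEY set in the environment
    runner = AgentRunner(..., model="gpt-4o-mini")

    # ANTHROPIC_API_KEY set in the environment
    runner = AgentRunner(..., model="claude-sonnet-4-20250514")
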
@@ -7,6 +7,7 @@ requires-python = ">=3.11"
 dependencies = [
     "pydantic>=2.0",
     "anthropic>=0.40.0",
+    "litellm>=1.81.0",
 ]
 
 [project.optional-dependencies]
 
@@ -2,6 +2,7 @@
 pydantic>=2.0
 anthropic>=0.40.0
 httpx>=0.27.0
+litellm>=1.81.0
 
 # MCP server dependencies
 mcp
 
@@ -0,0 +1,239 @@
"""Tests for LiteLLM provider.

Run with:
    cd core
    pip install litellm pytest
    pytest tests/test_litellm_provider.py -v

For live tests (requires API keys):
    OPENAI_API_KEY=sk-... pytest tests/test_litellm_provider.py -v -m live
"""

import os
import pytest
from unittest.mock import Mock, patch, MagicMock

from framework.llm.litellm import LiteLLMProvider
from framework.llm.anthropic import AnthropicProvider
from framework.llm.provider import LLMProvider, Tool, ToolUse, ToolResult


class TestLiteLLMProviderInit:
    """Test LiteLLMProvider initialization."""

    def test_init_with_defaults(self):
        """Test initialization with default parameters."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            provider = LiteLLMProvider()
            assert provider.model == "gpt-4o-mini"
            assert provider.api_key is None
            assert provider.api_base is None

    def test_init_with_custom_model(self):
        """Test initialization with custom model."""
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            provider = LiteLLMProvider(model="claude-3-haiku-20240307")
            assert provider.model == "claude-3-haiku-20240307"

    def test_init_with_api_key(self):
        """Test initialization with explicit API key."""
        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="my-api-key")
        assert provider.api_key == "my-api-key"

    def test_init_with_api_base(self):
        """Test initialization with custom API base."""
        provider = LiteLLMProvider(
            model="gpt-4o-mini",
            api_key="my-key",
            api_base="https://my-proxy.com/v1"
        )
        assert provider.api_base == "https://my-proxy.com/v1"

    def test_init_ollama_no_key_needed(self):
        """Test that Ollama models don't require an API key."""
        with patch.dict(os.environ, {}, clear=True):
            # Should not raise.
            provider = LiteLLMProvider(model="ollama/llama3")
            assert provider.model == "ollama/llama3"


class TestLiteLLMProviderComplete:
    """Test LiteLLMProvider.complete() method."""

    @patch("litellm.completion")
    def test_complete_basic(self, mock_completion):
        """Test basic completion call."""
        # Mock response
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Hello! I'm an AI assistant."
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 20
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        result = provider.complete(
            messages=[{"role": "user", "content": "Hello"}]
        )

        assert result.content == "Hello! I'm an AI assistant."
        assert result.model == "gpt-4o-mini"
        assert result.input_tokens == 10
        assert result.output_tokens == 20
        assert result.stop_reason == "stop"

        # Verify litellm.completion was called correctly
        mock_completion.assert_called_once()
        call_kwargs = mock_completion.call_args[1]
        assert call_kwargs["model"] == "gpt-4o-mini"
        assert call_kwargs["api_key"] == "test-key"

    @patch("litellm.completion")
    def test_complete_with_system_prompt(self, mock_completion):
        """Test completion with system prompt."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 15
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        provider.complete(
            messages=[{"role": "user", "content": "Hello"}],
            system="You are a helpful assistant."
        )

        call_kwargs = mock_completion.call_args[1]
        messages = call_kwargs["messages"]
        assert messages[0]["role"] == "system"
        assert messages[0]["content"] == "You are a helpful assistant."

    @patch("litellm.completion")
    def test_complete_with_tools(self, mock_completion):
        """Test completion with tools."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 20
        mock_response.usage.completion_tokens = 10
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")

        tools = [
            Tool(
                name="get_weather",
                description="Get the weather for a location",
                parameters={
                    "properties": {
                        "location": {"type": "string", "description": "City name"}
                    },
                    "required": ["location"]
                }
            )
        ]

        provider.complete(
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=tools
        )

        call_kwargs = mock_completion.call_args[1]
        assert "tools" in call_kwargs
        assert call_kwargs["tools"][0]["type"] == "function"
        assert call_kwargs["tools"][0]["function"]["name"] == "get_weather"


class TestLiteLLMProviderToolUse:
    """Test LiteLLMProvider.complete_with_tools() method."""

    @patch("litellm.completion")
    def test_complete_with_tools_single_iteration(self, mock_completion):
        """Test tool use with a single iteration."""
        # First response: tool call
        tool_call_response = MagicMock()
        tool_call_response.choices = [MagicMock()]
        tool_call_response.choices[0].message.content = None
        tool_call_response.choices[0].message.tool_calls = [MagicMock()]
        tool_call_response.choices[0].message.tool_calls[0].id = "call_123"
        tool_call_response.choices[0].message.tool_calls[0].function.name = "get_weather"
        tool_call_response.choices[0].message.tool_calls[0].function.arguments = '{"location": "London"}'
        tool_call_response.choices[0].finish_reason = "tool_calls"
        tool_call_response.model = "gpt-4o-mini"
        tool_call_response.usage.prompt_tokens = 20
        tool_call_response.usage.completion_tokens = 15

        # Second response: final answer
        final_response = MagicMock()
        final_response.choices = [MagicMock()]
        final_response.choices[0].message.content = "The weather in London is sunny."
        final_response.choices[0].message.tool_calls = None
        final_response.choices[0].finish_reason = "stop"
        final_response.model = "gpt-4o-mini"
        final_response.usage.prompt_tokens = 30
        final_response.usage.completion_tokens = 10

        mock_completion.side_effect = [tool_call_response, final_response]

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")

        tools = [
            Tool(
                name="get_weather",
                description="Get the weather",
                parameters={"properties": {"location": {"type": "string"}}, "required": ["location"]}
            )
        ]

        def tool_executor(tool_use: ToolUse) -> ToolResult:
            return ToolResult(
                tool_use_id=tool_use.id,
                content="Sunny, 22C",
                is_error=False
            )

        result = provider.complete_with_tools(
            messages=[{"role": "user", "content": "What's the weather in London?"}],
            system="You are a weather assistant.",
            tools=tools,
            tool_executor=tool_executor
        )

        assert result.content == "The weather in London is sunny."
        assert result.input_tokens == 50  # 20 + 30
        assert result.output_tokens == 25  # 15 + 10
        assert mock_completion.call_count == 2


class TestToolConversion:
    """Test tool format conversion."""

    def test_tool_to_openai_format(self):
        """Test converting a Tool to OpenAI format."""
        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")

        tool = Tool(
            name="search",
            description="Search the web",
            parameters={
                "properties": {
                    "query": {"type": "string", "description": "Search query"}
                },
                "required": ["query"]
            }
        )

        result = provider._tool_to_openai_format(tool)

        assert result["type"] == "function"
        assert result["function"]["name"] == "search"
        assert result["function"]["description"] == "Search the web"
        assert result["function"]["parameters"]["properties"]["query"]["type"] == "string"
        assert result["function"]["parameters"]["required"] == ["query"]
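
The module docstring mentions a `live` marker, not shown in this hunk; a live test would follow this shape (a sketch, assuming the `live` marker is registered in the pytest config):

    @pytest.mark.live
    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="requires OPENAI_API_KEY")
    def test_complete_live():
        provider = LiteLLMProvider(model="gpt-4o-mini")
        result = provider.complete(messages=[{"role": "user", "content": "Reply with the word ok."}])
        assert result.content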