refactor: remove the unused LLM tool-use functions
@@ -1,11 +1,10 @@
 """Anthropic Claude LLM provider - backward compatible wrapper around LiteLLM."""
 
 import os
-from collections.abc import Callable
 from typing import Any
 
 from framework.llm.litellm import LiteLLMProvider
-from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
+from framework.llm.provider import LLMProvider, LLMResponse, Tool
 
 
 def _get_api_key_from_credential_store() -> str | None:
@@ -83,23 +82,6 @@ class AnthropicProvider(LLMProvider):
             max_retries=max_retries,
         )
 
-    def complete_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        system: str,
-        tools: list[Tool],
-        tool_executor: Callable[[ToolUse], ToolResult],
-        max_iterations: int = 10,
-    ) -> LLMResponse:
-        """Run a tool-use loop until Claude produces a final response (via LiteLLM)."""
-        return self._provider.complete_with_tools(
-            messages=messages,
-            system=system,
-            tools=tools,
-            tool_executor=tool_executor,
-            max_iterations=max_iterations,
-        )
-
     async def acomplete(
         self,
         messages: list[dict[str, Any]],
@@ -121,19 +103,3 @@ class AnthropicProvider(LLMProvider):
             max_retries=max_retries,
         )
 
-    async def acomplete_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        system: str,
-        tools: list[Tool],
-        tool_executor: Callable[[ToolUse], ToolResult],
-        max_iterations: int = 10,
-    ) -> LLMResponse:
-        """Async tool-use loop via LiteLLM."""
-        return await self._provider.acomplete_with_tools(
-            messages=messages,
-            system=system,
-            tools=tools,
-            tool_executor=tool_executor,
-            max_iterations=max_iterations,
-        )
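
Note: both removed methods above were thin delegations to the wrapped LiteLLMProvider, so no behavior lived in this file. For reference, a pre-refactor call site would have looked roughly like the sketch below; the Tool instance and executor body are illustrative (values borrowed from the tests later in this diff), not a current API.

    from framework.llm.anthropic import AnthropicProvider
    from framework.llm.provider import Tool, ToolResult, ToolUse

    provider = AnthropicProvider(api_key="test-key", model="claude-3-haiku-20240307")

    # Illustrative tool; the schema shape matches the tests further down.
    clock_tool = Tool(
        name="get_time",
        description="Get current time",
        parameters={"properties": {}, "required": []},
    )

    def tool_executor(tool_use: ToolUse) -> ToolResult:
        # Hypothetical executor: answers every call with a fixed string.
        return ToolResult(tool_use_id=tool_use.id, content="3:00 PM", is_error=False)

    response = provider.complete_with_tools(
        messages=[{"role": "user", "content": "What time is it?"}],
        system="You are a time assistant.",
        tools=[clock_tool],
        tool_executor=tool_executor,
    )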
@@ -11,7 +11,7 @@ import asyncio
 import json
 import logging
 import time
-from collections.abc import AsyncIterator, Callable
+from collections.abc import AsyncIterator
 from datetime import datetime
 from pathlib import Path
 from typing import Any
@@ -23,7 +23,7 @@ except ImportError:
     litellm = None  # type: ignore[assignment]
     RateLimitError = Exception  # type: ignore[assignment, misc]
 
-from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
+from framework.llm.provider import LLMProvider, LLMResponse, Tool
 from framework.llm.stream_events import StreamEvent
 
 logger = logging.getLogger(__name__)
@@ -511,127 +511,6 @@ class LiteLLMProvider(LLMProvider):
             raw_response=response,
         )
 
-    def complete_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        system: str,
-        tools: list[Tool],
-        tool_executor: Callable[[ToolUse], ToolResult],
-        max_iterations: int = 10,
-        max_tokens: int = 4096,
-    ) -> LLMResponse:
-        """Run a tool-use loop until the LLM produces a final response."""
-        # Prepare messages with system prompt
-        current_messages = []
-        if system:
-            current_messages.append({"role": "system", "content": system})
-        current_messages.extend(messages)
-
-        total_input_tokens = 0
-        total_output_tokens = 0
-
-        # Convert tools to OpenAI format
-        openai_tools = [self._tool_to_openai_format(t) for t in tools]
-
-        for _ in range(max_iterations):
-            # Build kwargs
-            kwargs: dict[str, Any] = {
-                "model": self.model,
-                "messages": current_messages,
-                "max_tokens": max_tokens,
-                "tools": openai_tools,
-                **self.extra_kwargs,
-            }
-
-            if self.api_key:
-                kwargs["api_key"] = self.api_key
-            if self.api_base:
-                kwargs["api_base"] = self.api_base
-
-            response = self._completion_with_rate_limit_retry(**kwargs)
-
-            # Track tokens
-            usage = response.usage
-            if usage:
-                total_input_tokens += usage.prompt_tokens
-                total_output_tokens += usage.completion_tokens
-
-            choice = response.choices[0]
-            message = choice.message
-
-            # Check if we're done (no tool calls)
-            if choice.finish_reason == "stop" or not message.tool_calls:
-                return LLMResponse(
-                    content=message.content or "",
-                    model=response.model or self.model,
-                    input_tokens=total_input_tokens,
-                    output_tokens=total_output_tokens,
-                    stop_reason=choice.finish_reason or "stop",
-                    raw_response=response,
-                )
-
-            # Process tool calls.
-            # Add assistant message with tool calls.
-            current_messages.append(
-                {
-                    "role": "assistant",
-                    "content": message.content,
-                    "tool_calls": [
-                        {
-                            "id": tc.id,
-                            "type": "function",
-                            "function": {
-                                "name": tc.function.name,
-                                "arguments": tc.function.arguments,
-                            },
-                        }
-                        for tc in message.tool_calls
-                    ],
-                }
-            )
-
-            # Execute tools and add results.
-            for tool_call in message.tool_calls:
-                try:
-                    args = json.loads(tool_call.function.arguments)
-                except json.JSONDecodeError:
-                    # Surface error to LLM and skip tool execution
-                    current_messages.append(
-                        {
-                            "role": "tool",
-                            "tool_call_id": tool_call.id,
-                            "content": "Invalid JSON arguments provided to tool.",
-                        }
-                    )
-                    continue
-
-                tool_use = ToolUse(
-                    id=tool_call.id,
-                    name=tool_call.function.name,
-                    input=args,
-                )
-
-                result = tool_executor(tool_use)
-
-                # Add tool result message
-                current_messages.append(
-                    {
-                        "role": "tool",
-                        "tool_call_id": result.tool_use_id,
-                        "content": result.content,
-                    }
-                )
-
-        # Max iterations reached
-        return LLMResponse(
-            content="Max tool iterations reached",
-            model=self.model,
-            input_tokens=total_input_tokens,
-            output_tokens=total_output_tokens,
-            stop_reason="max_iterations",
-            raw_response=None,
-        )
-
     # ------------------------------------------------------------------
     # Async variants — non-blocking on the event loop
     # ------------------------------------------------------------------
@@ -835,115 +714,6 @@ class LiteLLMProvider(LLMProvider):
             raw_response=None,
         )
 
-    async def acomplete_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        system: str,
-        tools: list[Tool],
-        tool_executor: Callable[[ToolUse], ToolResult],
-        max_iterations: int = 10,
-        max_tokens: int = 4096,
-    ) -> LLMResponse:
-        """Async version of complete_with_tools(). Uses litellm.acompletion — non-blocking."""
-        current_messages: list[dict[str, Any]] = []
-        if system:
-            current_messages.append({"role": "system", "content": system})
-        current_messages.extend(messages)
-
-        total_input_tokens = 0
-        total_output_tokens = 0
-        openai_tools = [self._tool_to_openai_format(t) for t in tools]
-
-        for _ in range(max_iterations):
-            kwargs: dict[str, Any] = {
-                "model": self.model,
-                "messages": current_messages,
-                "max_tokens": max_tokens,
-                "tools": openai_tools,
-                **self.extra_kwargs,
-            }
-
-            if self.api_key:
-                kwargs["api_key"] = self.api_key
-            if self.api_base:
-                kwargs["api_base"] = self.api_base
-
-            response = await self._acompletion_with_rate_limit_retry(**kwargs)
-
-            usage = response.usage
-            if usage:
-                total_input_tokens += usage.prompt_tokens
-                total_output_tokens += usage.completion_tokens
-
-            choice = response.choices[0]
-            message = choice.message
-
-            if choice.finish_reason == "stop" or not message.tool_calls:
-                return LLMResponse(
-                    content=message.content or "",
-                    model=response.model or self.model,
-                    input_tokens=total_input_tokens,
-                    output_tokens=total_output_tokens,
-                    stop_reason=choice.finish_reason or "stop",
-                    raw_response=response,
-                )
-
-            current_messages.append(
-                {
-                    "role": "assistant",
-                    "content": message.content,
-                    "tool_calls": [
-                        {
-                            "id": tc.id,
-                            "type": "function",
-                            "function": {
-                                "name": tc.function.name,
-                                "arguments": tc.function.arguments,
-                            },
-                        }
-                        for tc in message.tool_calls
-                    ],
-                }
-            )
-
-            for tool_call in message.tool_calls:
-                try:
-                    args = json.loads(tool_call.function.arguments)
-                except json.JSONDecodeError:
-                    current_messages.append(
-                        {
-                            "role": "tool",
-                            "tool_call_id": tool_call.id,
-                            "content": "Invalid JSON arguments provided to tool.",
-                        }
-                    )
-                    continue
-
-                tool_use = ToolUse(
-                    id=tool_call.id,
-                    name=tool_call.function.name,
-                    input=args,
-                )
-
-                result = tool_executor(tool_use)
-
-                current_messages.append(
-                    {
-                        "role": "tool",
-                        "tool_call_id": result.tool_use_id,
-                        "content": result.content,
-                    }
-                )
-
-        return LLMResponse(
-            content="Max tool iterations reached",
-            model=self.model,
-            input_tokens=total_input_tokens,
-            output_tokens=total_output_tokens,
-            stop_reason="max_iterations",
-            raw_response=None,
-        )
-
     def _tool_to_openai_format(self, tool: Tool) -> dict[str, Any]:
         """Convert Tool to OpenAI function calling format."""
         return {
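
Note: the deleted loop spoke the OpenAI function-calling message format: each iteration appended the assistant's tool_calls message, then one role "tool" message per call, before re-invoking the model. A minimal sketch of the transcript after one round, with illustrative IDs and values taken from the tests later in this diff:

    current_messages = [
        {"role": "system", "content": "You are a weather assistant."},
        {"role": "user", "content": "What's the weather in London?"},
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "call_123",  # illustrative ID
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        # Note: arguments is a JSON string, not a dict.
                        "arguments": '{"location": "London"}',
                    },
                }
            ],
        },
        {"role": "tool", "tool_call_id": "call_123", "content": "Sunny, 22C"},
    ]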
@@ -2,10 +2,10 @@
 
 import json
 import re
-from collections.abc import AsyncIterator, Callable
+from collections.abc import AsyncIterator
 from typing import Any
 
-from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
+from framework.llm.provider import LLMProvider, LLMResponse, Tool
 from framework.llm.stream_events import (
     FinishEvent,
     StreamEvent,
@@ -146,43 +146,6 @@ class MockLLMProvider(LLMProvider):
             stop_reason="mock_complete",
         )
 
-    def complete_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        system: str,
-        tools: list[Tool],
-        tool_executor: Callable[[ToolUse], ToolResult],
-        max_iterations: int = 10,
-    ) -> LLMResponse:
-        """
-        Generate a mock completion without tool use.
-
-        In mock mode, we skip tool execution and return a final response immediately.
-
-        Args:
-            messages: Initial conversation (ignored in mock mode)
-            system: System prompt (used to extract expected output keys)
-            tools: Available tools (ignored in mock mode)
-            tool_executor: Tool executor function (ignored in mock mode)
-            max_iterations: Max iterations (ignored in mock mode)
-
-        Returns:
-            LLMResponse with mock content
-        """
-        # In mock mode, we don't execute tools - just return a final response
-        # Try to generate JSON if the system prompt suggests structured output
-        json_mode = "json" in system.lower() or "output_keys" in system.lower()
-
-        content = self._generate_mock_response(system=system, json_mode=json_mode)
-
-        return LLMResponse(
-            content=content,
-            model=self.model,
-            input_tokens=0,
-            output_tokens=0,
-            stop_reason="mock_complete",
-        )
-
     async def acomplete(
         self,
         messages: list[dict[str, Any]],
@@ -204,23 +167,6 @@ class MockLLMProvider(LLMProvider):
             max_retries=max_retries,
         )
 
-    async def acomplete_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        system: str,
-        tools: list[Tool],
-        tool_executor: Callable[[ToolUse], ToolResult],
-        max_iterations: int = 10,
-    ) -> LLMResponse:
-        """Async mock tool-use completion (no I/O, returns immediately)."""
-        return self.complete_with_tools(
-            messages=messages,
-            system=system,
-            tools=tools,
-            tool_executor=tool_executor,
-            max_iterations=max_iterations,
-        )
-
     async def stream(
         self,
         messages: list[dict[str, Any]],
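
Note: the removed mock methods never executed tools; they answered immediately, inferring structured output from the system prompt. The heuristic, as a self-contained sketch:

    def detect_json_mode(system: str) -> bool:
        # Same check the removed mock used: assume structured output whenever
        # the system prompt mentions "json" or "output_keys".
        return "json" in system.lower() or "output_keys" in system.lower()

    assert detect_json_mode("Return JSON with keys a, b") is True
    assert detect_json_mode("You are a poet.") is False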
@@ -2,7 +2,7 @@
 
 import asyncio
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Callable
+from collections.abc import AsyncIterator
 from dataclasses import dataclass, field
 from functools import partial
 from typing import Any
@@ -90,30 +90,6 @@ class LLMProvider(ABC):
         """
         pass
 
-    @abstractmethod
-    def complete_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        system: str,
-        tools: list[Tool],
-        tool_executor: Callable[["ToolUse"], "ToolResult"],
-        max_iterations: int = 10,
-    ) -> LLMResponse:
-        """
-        Run a tool-use loop until the LLM produces a final response.
-
-        Args:
-            messages: Initial conversation
-            system: System prompt
-            tools: Available tools
-            tool_executor: Function to execute tools: (ToolUse) -> ToolResult
-            max_iterations: Max tool calls before stopping
-
-        Returns:
-            Final LLMResponse after tool use completes
-        """
-        pass
-
     async def acomplete(
         self,
         messages: list[dict[str, Any]],
@@ -144,32 +120,6 @@ class LLMProvider(ABC):
             ),
         )
 
-    async def acomplete_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        system: str,
-        tools: list["Tool"],
-        tool_executor: Callable[["ToolUse"], "ToolResult"],
-        max_iterations: int = 10,
-    ) -> "LLMResponse":
-        """Async version of complete_with_tools(). Non-blocking on the event loop.
-
-        Default implementation offloads the sync complete_with_tools() to a thread pool.
-        Subclasses SHOULD override for native async I/O.
-        """
-        loop = asyncio.get_running_loop()
-        return await loop.run_in_executor(
-            None,
-            partial(
-                self.complete_with_tools,
-                messages=messages,
-                system=system,
-                tools=tools,
-                tool_executor=tool_executor,
-                max_iterations=max_iterations,
-            ),
-        )
-
     async def stream(
         self,
         messages: list[dict[str, Any]],
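
Note: the deleted base-class fallback used the standard thread-pool offload idiom (the same one the surviving acomplete() default uses). A self-contained sketch of the pattern, with a stand-in blocking function:

    import asyncio
    from functools import partial

    def slow_sync_work(x: int, y: int) -> int:
        # Stand-in for a blocking call such as a sync complete_with_tools().
        return x + y

    async def offload() -> int:
        # Run the blocking function on the default executor so the event loop
        # stays responsive; partial() binds keyword arguments up front.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, partial(slow_sync_work, x=1, y=2))

    print(asyncio.run(offload()))  # 3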
@@ -36,15 +36,6 @@ class FailingLLMProvider(LLMProvider):
     def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
         raise RuntimeError("LLM unavailable")
 
-    def complete_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        system: str,
-        tools: list,
-        tool_executor: Any,
-        max_iterations: int = 10,
-    ) -> LLMResponse:
-        raise RuntimeError("LLM unavailable")
 
 
 async def _build_conversation(*pairs: tuple[str, str]) -> NodeConversation:
@@ -62,8 +62,6 @@ class MockStreamingLLM(LLMProvider):
     def complete(self, messages, system="", **kwargs) -> LLMResponse:
         return LLMResponse(content="Summary.", model="mock", stop_reason="stop")
 
-    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
-        return LLMResponse(content="", model="mock", stop_reason="stop")
 
 
 # ---------------------------------------------------------------------------
@@ -64,8 +64,6 @@ class MockStreamingLLM(LLMProvider):
         self.complete_calls.append({"messages": messages, "system": system})
         return LLMResponse(content=self.complete_response, model="mock", stop_reason="stop")
 
-    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
-        return LLMResponse(content="", model="mock", stop_reason="stop")
 
 
 # ---------------------------------------------------------------------------
@@ -95,17 +95,6 @@ class ScriptableMockLLMProvider(LLMProvider):
             output_tokens=10,
         )
 
-    def complete_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        system: str = "",
-        tools: list[Tool] | None = None,
-        tool_executor: Callable[[ToolUse], ToolResult] | None = None,
-        max_iterations: int = 10,
-        max_tokens: int = 1024,
-    ) -> LLMResponse:
-        return self.complete(messages, system, tools, max_tokens)
-
     async def stream(
         self,
         messages: list[dict[str, Any]],
@@ -68,8 +68,6 @@ class MockStreamingLLM(LLMProvider):
     def complete(self, messages, system="", **kwargs) -> LLMResponse:
         return LLMResponse(content="Summary of conversation.", model="mock", stop_reason="stop")
 
-    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
-        return LLMResponse(content="", model="mock", stop_reason="stop")
 
 
 # ---------------------------------------------------------------------------
@@ -1026,8 +1024,6 @@ class ErrorThenSuccessLLM(LLMProvider):
     def complete(self, messages, system="", **kwargs) -> LLMResponse:
         return LLMResponse(content="ok", model="mock", stop_reason="stop")
 
-    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
-        return LLMResponse(content="", model="mock", stop_reason="stop")
 
 
 class TestTransientErrorRetry:
@@ -1131,20 +1127,6 @@ class TestTransientErrorRetry:
                 stop_reason="stop",
             )
 
-        def complete_with_tools(
-            self,
-            messages,
-            system,
-            tools,
-            tool_executor,
-            **kwargs,
-        ):
-            return LLMResponse(
-                content="",
-                model="mock",
-                stop_reason="stop",
-            )
-
         llm = StreamErrorThenSuccessLLM()
         ctx = build_ctx(runtime, node_spec, memory, llm)
         node = EventLoopNode(
@@ -1227,9 +1209,6 @@ class TestTransientErrorRetry:
         def complete(self, messages, system="", **kwargs):
             return LLMResponse(content="ok", model="mock", stop_reason="stop")
 
-        def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs):
-            return LLMResponse(content="", model="mock", stop_reason="stop")
-
         llm = RecoverableErrorThenSuccessLLM()
         ctx = build_ctx(runtime, node_spec, memory, llm)
         node = EventLoopNode(
@@ -1412,19 +1391,6 @@ class ToolRepeatLLM(LLMProvider):
             stop_reason="stop",
         )
 
-    def complete_with_tools(
-        self,
-        messages,
-        system,
-        tools,
-        tool_executor,
-        **kwargs,
-    ) -> LLMResponse:
-        return LLMResponse(
-            content="",
-            model="mock",
-            stop_reason="stop",
-        )
 
 
 class TestToolDoomLoopIntegration:
@@ -1650,20 +1616,6 @@ class TestToolDoomLoopIntegration:
                 stop_reason="stop",
             )
 
-        def complete_with_tools(
-            self,
-            messages,
-            system,
-            tools,
-            tool_executor,
-            **kw,
-        ):
-            return LLMResponse(
-                content="",
-                model="mock",
-                stop_reason="stop",
-            )
-
         llm = DiffArgsLLM()
 
         def tool_exec(tool_use: ToolUse) -> ToolResult:
@@ -1,7 +1,7 @@
 """Tests for ExecutionStream retention behavior."""
 
 import json
-from collections.abc import AsyncIterator, Callable
+from collections.abc import AsyncIterator
 from typing import Any
 
 import pytest
@@ -38,16 +38,6 @@ class DummyLLMProvider(LLMProvider):
     ) -> LLMResponse:
         return LLMResponse(content="Summary for compaction.", model="dummy")
 
-    def complete_with_tools(
-        self,
-        messages: list[dict[str, object]],
-        system: str,
-        tools: list[Tool],
-        tool_executor: Callable,
-        max_iterations: int = 10,
-    ) -> LLMResponse:
-        return LLMResponse(content="Summary for compaction.", model="dummy")
-
     async def stream(
         self,
         messages: list[dict[str, Any]],
@@ -20,7 +20,7 @@ import pytest
 
 from framework.llm.anthropic import AnthropicProvider
 from framework.llm.litellm import LiteLLMProvider, _compute_retry_delay
-from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
+from framework.llm.provider import LLMProvider, LLMResponse, Tool
 
 
 class TestLiteLLMProviderInit:
@@ -154,124 +154,6 @@ class TestLiteLLMProviderComplete:
         assert call_kwargs["tools"][0]["function"]["name"] == "get_weather"
 
 
-class TestLiteLLMProviderToolUse:
-    """Test LiteLLMProvider.complete_with_tools() method."""
-
-    @patch("litellm.completion")
-    def test_complete_with_tools_single_iteration(self, mock_completion):
-        """Test tool use with single iteration."""
-        # First response: tool call
-        tool_call_response = MagicMock()
-        tool_call_response.choices = [MagicMock()]
-        tool_call_response.choices[0].message.content = None
-        tool_call_response.choices[0].message.tool_calls = [MagicMock()]
-        tool_call_response.choices[0].message.tool_calls[0].id = "call_123"
-        tool_call_response.choices[0].message.tool_calls[0].function.name = "get_weather"
-        tool_call_response.choices[0].message.tool_calls[
-            0
-        ].function.arguments = '{"location": "London"}'
-        tool_call_response.choices[0].finish_reason = "tool_calls"
-        tool_call_response.model = "gpt-4o-mini"
-        tool_call_response.usage.prompt_tokens = 20
-        tool_call_response.usage.completion_tokens = 15
-
-        # Second response: final answer
-        final_response = MagicMock()
-        final_response.choices = [MagicMock()]
-        final_response.choices[0].message.content = "The weather in London is sunny."
-        final_response.choices[0].message.tool_calls = None
-        final_response.choices[0].finish_reason = "stop"
-        final_response.model = "gpt-4o-mini"
-        final_response.usage.prompt_tokens = 30
-        final_response.usage.completion_tokens = 10
-
-        mock_completion.side_effect = [tool_call_response, final_response]
-
-        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
-
-        tools = [
-            Tool(
-                name="get_weather",
-                description="Get the weather",
-                parameters={
-                    "properties": {"location": {"type": "string"}},
-                    "required": ["location"],
-                },
-            )
-        ]
-
-        def tool_executor(tool_use: ToolUse) -> ToolResult:
-            return ToolResult(tool_use_id=tool_use.id, content="Sunny, 22C", is_error=False)
-
-        result = provider.complete_with_tools(
-            messages=[{"role": "user", "content": "What's the weather in London?"}],
-            system="You are a weather assistant.",
-            tools=tools,
-            tool_executor=tool_executor,
-        )
-
-        assert result.content == "The weather in London is sunny."
-        assert result.input_tokens == 50  # 20 + 30
-        assert result.output_tokens == 25  # 15 + 10
-        assert mock_completion.call_count == 2
-
-    @patch("litellm.completion")
-    def test_complete_with_tools_invalid_json_arguments_are_handled(self, mock_completion):
-        """Test that invalid JSON tool arguments do not execute the tool."""
-        # Mock response with invalid JSON arguments
-        tool_call_response = MagicMock()
-        tool_call_response.choices = [MagicMock()]
-        tool_call_response.choices[0].message.content = None
-        tool_call_response.choices[0].message.tool_calls = [MagicMock()]
-        tool_call_response.choices[0].message.tool_calls[0].id = "call_123"
-        tool_call_response.choices[0].message.tool_calls[0].function.name = "test_tool"
-        tool_call_response.choices[0].message.tool_calls[0].function.arguments = "{invalid json"
-        tool_call_response.choices[0].finish_reason = "tool_calls"
-        tool_call_response.model = "gpt-4o-mini"
-        tool_call_response.usage.prompt_tokens = 10
-        tool_call_response.usage.completion_tokens = 5
-
-        # Final response (LLM continues after tool error)
-        final_response = MagicMock()
-        final_response.choices = [MagicMock()]
-        final_response.choices[0].message.content = "Handled error"
-        final_response.choices[0].message.tool_calls = None
-        final_response.choices[0].finish_reason = "stop"
-        final_response.model = "gpt-4o-mini"
-        final_response.usage.prompt_tokens = 5
-        final_response.usage.completion_tokens = 5
-
-        mock_completion.side_effect = [tool_call_response, final_response]
-
-        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
-
-        tools = [
-            Tool(
-                name="test_tool",
-                description="Test tool",
-                parameters={"properties": {}, "required": []},
-            )
-        ]
-
-        called = {"value": False}
-
-        def tool_executor(tool_use: ToolUse) -> ToolResult:
-            called["value"] = True
-            return ToolResult(
-                tool_use_id=tool_use.id, content="should not be called", is_error=False
-            )
-
-        result = provider.complete_with_tools(
-            messages=[{"role": "user", "content": "Run tool"}],
-            system="You are a test assistant.",
-            tools=tools,
-            tool_executor=tool_executor,
-        )
-
-        assert called["value"] is False
-        assert result.content == "Handled error"
-
-
 class TestToolConversion:
     """Test tool format conversion."""
 
@@ -352,43 +234,6 @@ class TestAnthropicProviderBackwardCompatibility:
         assert call_kwargs["model"] == "claude-3-haiku-20240307"
         assert call_kwargs["api_key"] == "test-key"
 
-    @patch("litellm.completion")
-    def test_anthropic_provider_complete_with_tools(self, mock_completion):
-        """Test AnthropicProvider.complete_with_tools() delegates to LiteLLM."""
-        # Mock a simple response (no tool calls)
-        mock_response = MagicMock()
-        mock_response.choices = [MagicMock()]
-        mock_response.choices[0].message.content = "The time is 3:00 PM."
-        mock_response.choices[0].message.tool_calls = None
-        mock_response.choices[0].finish_reason = "stop"
-        mock_response.model = "claude-3-haiku-20240307"
-        mock_response.usage.prompt_tokens = 20
-        mock_response.usage.completion_tokens = 10
-        mock_completion.return_value = mock_response
-
-        provider = AnthropicProvider(api_key="test-key", model="claude-3-haiku-20240307")
-
-        tools = [
-            Tool(
-                name="get_time",
-                description="Get current time",
-                parameters={"properties": {}, "required": []},
-            )
-        ]
-
-        def tool_executor(tool_use: ToolUse) -> ToolResult:
-            return ToolResult(tool_use_id=tool_use.id, content="3:00 PM", is_error=False)
-
-        result = provider.complete_with_tools(
-            messages=[{"role": "user", "content": "What time is it?"}],
-            system="You are a time assistant.",
-            tools=tools,
-            tool_executor=tool_executor,
-        )
-
-        assert result.content == "The time is 3:00 PM."
-        mock_completion.assert_called_once()
-
     @patch("litellm.completion")
     def test_anthropic_provider_passes_response_format(self, mock_completion):
         """Test that AnthropicProvider accepts and forwards response_format."""
@@ -738,43 +583,6 @@ class TestAsyncComplete:
             f"Event loop was blocked — only {len(heartbeat_ticks)} heartbeat ticks"
         )
 
-    @pytest.mark.asyncio
-    @patch("litellm.acompletion")
-    async def test_acomplete_with_tools_uses_acompletion(self, mock_acompletion):
-        """acomplete_with_tools() should use litellm.acompletion."""
-        mock_response = MagicMock()
-        mock_response.choices = [MagicMock()]
-        mock_response.choices[0].message.content = "tool result"
-        mock_response.choices[0].message.tool_calls = None
-        mock_response.choices[0].finish_reason = "stop"
-        mock_response.model = "gpt-4o-mini"
-        mock_response.usage.prompt_tokens = 10
-        mock_response.usage.completion_tokens = 5
-
-        async def async_return(*args, **kwargs):
-            return mock_response
-
-        mock_acompletion.side_effect = async_return
-
-        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
-        tools = [
-            Tool(
-                name="search",
-                description="Search the web",
-                parameters={"properties": {"q": {"type": "string"}}, "required": ["q"]},
-            )
-        ]
-
-        result = await provider.acomplete_with_tools(
-            messages=[{"role": "user", "content": "Search for cats"}],
-            system="You are helpful.",
-            tools=tools,
-            tool_executor=lambda tu: ToolResult(tool_use_id=tu.id, content="cats"),
-        )
-
-        assert result.content == "tool result"
-        mock_acompletion.assert_called_once()
-
     @pytest.mark.asyncio
     async def test_mock_provider_acomplete(self):
         """MockLLMProvider.acomplete() should work without blocking."""
@@ -809,11 +617,6 @@ class TestAsyncComplete:
                 time.sleep(0.1)  # Sync blocking
                 return LLMResponse(content="sync done", model="slow")
 
-            def complete_with_tools(
-                self, messages, system, tools, tool_executor, max_iterations=10
-            ):
-                return LLMResponse(content="sync tools done", model="slow")
-
         provider = SlowSyncProvider()
         main_thread_id = threading.current_thread().ident
 
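
Note: the surviving async tests assert on heartbeat ticks to prove the event loop was never blocked. A self-contained sketch of that technique (illustrative code, not from this repo):

    import asyncio
    import time

    async def heartbeat(ticks: list[float], interval: float = 0.02) -> None:
        # Record a tick every interval; a blocked loop produces few ticks.
        while True:
            ticks.append(time.monotonic())
            await asyncio.sleep(interval)

    async def main() -> None:
        ticks: list[float] = []
        hb = asyncio.create_task(heartbeat(ticks))
        await asyncio.sleep(0)    # let the heartbeat task start
        time.sleep(0.2)           # BAD: sync sleep blocks the loop, no ticks land
        await asyncio.sleep(0.2)  # GOOD: async sleep yields, ~10 ticks land
        hb.cancel()
        print(len(ticks))

    asyncio.run(main())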
@@ -52,8 +52,6 @@ class MockLLMProvider(LLMProvider):
             output_tokens=50,
         )
 
-    def complete_with_tools(self, messages, system, tools, tool_executor, max_iterations=10):
-        raise NotImplementedError("Tool use not needed for judge tests")
 
 
 # ============================================================================
@@ -102,4 +102,3 @@ class TestOrchestratorLLMProviderType:
 
         assert isinstance(orchestrator._llm, LLMProvider)
         assert hasattr(orchestrator._llm, "complete")
-        assert hasattr(orchestrator._llm, "complete_with_tools")