fix(wip): codex tool use bug fixes
This commit is contained in:
@@ -70,8 +70,49 @@ def _patch_litellm_anthropic_oauth() -> None:
|
||||
AnthropicModelInfo.validate_environment = _patched_validate_environment
|
||||
|
||||
|
||||
def _patch_litellm_metadata_nonetype() -> None:
    """Patch litellm entry points to prevent metadata=None TypeError.

    litellm bug: the @client decorator in utils.py has four places that do
        "model_group" in kwargs.get("metadata", {})
    but kwargs["metadata"] can be explicitly None (set internally by
    litellm_params), causing:
        TypeError: argument of type 'NoneType' is not iterable
    This masks the real API error with a confusing APIConnectionError.

    Fix: wrap the four litellm entry points (completion, acompletion,
    responses, aresponses) to pop metadata=None before the @client
    decorator's error handler can crash on it.
    """
    import functools

    def _build_wrapper(target):
        # One factory for both flavors: pick an async or sync shim depending
        # on the wrapped callable, binding it via a default arg so each
        # wrapper closes over its own target (avoids late-binding bugs).
        if asyncio.iscoroutinefunction(target):

            @functools.wraps(target)
            async def _async_shim(*args, _orig=target, **kwargs):
                if kwargs.get("metadata") is None:
                    kwargs.pop("metadata", None)
                return await _orig(*args, **kwargs)

            return _async_shim

        @functools.wraps(target)
        def _sync_shim(*args, _orig=target, **kwargs):
            if kwargs.get("metadata") is None:
                kwargs.pop("metadata", None)
            return _orig(*args, **kwargs)

        return _sync_shim

    for entry_point in ("completion", "acompletion", "responses", "aresponses"):
        current = getattr(litellm, entry_point, None)
        # Skip entry points this litellm version doesn't expose.
        if current is not None:
            setattr(litellm, entry_point, _build_wrapper(current))
|
||||
|
||||
|
||||
# Apply the monkeypatches once at import time. `litellm` is None when the
# optional dependency is not installed, in which case there is nothing to patch.
if litellm is not None:
    _patch_litellm_anthropic_oauth()
    _patch_litellm_metadata_nonetype()

# Retry policy for rate-limited requests (used by the provider's retry helper).
RATE_LIMIT_MAX_RETRIES = 10
RATE_LIMIT_BACKOFF_BASE = 2  # seconds; backoff grows exponentially from this base
|
||||
@@ -284,6 +325,12 @@ class LiteLLMProvider(LLMProvider):
|
||||
"LiteLLM is not installed. Please install it with: uv pip install litellm"
|
||||
)
|
||||
|
||||
# Note: The Codex ChatGPT backend is a Responses API endpoint at
|
||||
# chatgpt.com/backend-api/codex/responses. LiteLLM's model registry
|
||||
# correctly marks codex models with mode="responses", so we do NOT
|
||||
# override the mode. The responses_api_bridge in litellm handles
|
||||
# converting Chat Completions requests to Responses API format.
|
||||
|
||||
def _completion_with_rate_limit_retry(
|
||||
self, max_retries: int | None = None, **kwargs: Any
|
||||
) -> Any:
|
||||
@@ -708,6 +755,11 @@ class LiteLLMProvider(LLMProvider):
|
||||
full_messages.append({"role": "system", "content": system})
|
||||
full_messages.extend(messages)
|
||||
|
||||
# Codex Responses API requires an `instructions` field (system prompt).
|
||||
# Inject a minimal one when callers don't provide a system message.
|
||||
if self._codex_backend and not any(m["role"] == "system" for m in full_messages):
|
||||
full_messages.insert(0, {"role": "system", "content": "You are a helpful assistant."})
|
||||
|
||||
# Add JSON mode via prompt engineering (works across all providers)
|
||||
if json_mode:
|
||||
json_instruction = "\n\nPlease respond with a valid JSON object."
|
||||
@@ -732,7 +784,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
# The Codex ChatGPT backend rejects max_output_tokens and stream_options.
|
||||
# The Codex ChatGPT backend (Responses API) rejects several params.
|
||||
if self._codex_backend:
|
||||
kwargs.pop("max_tokens", None)
|
||||
kwargs.pop("stream_options", None)
|
||||
@@ -744,6 +796,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
tail_events: list[StreamEvent] = []
|
||||
accumulated_text = ""
|
||||
tool_calls_acc: dict[int, dict[str, str]] = {}
|
||||
_last_tool_idx = 0 # tracks most recently opened tool call slot
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
stream_finish_reason: str | None = None
|
||||
@@ -767,9 +820,33 @@ class LiteLLMProvider(LLMProvider):
|
||||
)
|
||||
|
||||
# --- Tool calls (accumulate across chunks) ---
|
||||
# The Codex/Responses API bridge (litellm bug) hardcodes
|
||||
# index=0 on every ChatCompletionToolCallChunk, even for
|
||||
# parallel tool calls. We work around this by using tc.id
|
||||
# (set on output_item.added events) as a "new tool call"
|
||||
# signal and tracking the most recently opened slot for
|
||||
# argument deltas that arrive with id=None.
|
||||
if delta and delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
|
||||
|
||||
if tc.id:
|
||||
# New tool call announced (or done event re-sent).
|
||||
# Check if this id already has a slot.
|
||||
existing_idx = next(
|
||||
(k for k, v in tool_calls_acc.items() if v["id"] == tc.id),
|
||||
None,
|
||||
)
|
||||
if existing_idx is not None:
|
||||
idx = existing_idx
|
||||
elif idx in tool_calls_acc and tool_calls_acc[idx]["id"] not in ("", tc.id):
|
||||
# Slot taken by a different call — assign new index
|
||||
idx = max(tool_calls_acc.keys()) + 1
|
||||
_last_tool_idx = idx
|
||||
else:
|
||||
# Argument delta with no id — route to last opened slot
|
||||
idx = _last_tool_idx
|
||||
|
||||
if idx not in tool_calls_acc:
|
||||
tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
|
||||
if tc.id:
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
"""Diagnostic script to reproduce and trace Codex streaming errors.
|
||||
|
||||
Run: .venv/bin/python core/tests/debug_codex_stream.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
sys.path.insert(0, "core")
|
||||
|
||||
import litellm # noqa: E402
|
||||
|
||||
# Enable litellm debug logging to see the raw HTTP exchange
|
||||
litellm._turn_on_debug()
|
||||
|
||||
|
||||
async def test_codex_stream():
    """Minimal Codex streaming call via LiteLLMProvider (Responses API path).

    Runs five diagnostics against the live Codex backend, printing results:
      1. provider.stream() basic text
      2. provider.stream() with tools
      3. provider.acomplete() round-trip
      4. direct litellm.acompletion with explicit metadata={}
      5. three rapid-fire stream() calls

    Requires a Codex subscription configured in ~/.hive/configuration.json;
    prints an error and returns early otherwise.
    """
    from framework.config import get_api_base, get_api_key, get_llm_extra_kwargs
    from framework.llm.litellm import LiteLLMProvider

    api_key = get_api_key()
    api_base = get_api_base()
    extra_kwargs = get_llm_extra_kwargs()

    if not api_key or not api_base:
        print("ERROR: No Codex subscription configured in ~/.hive/configuration.json")
        return

    print(f"api_base: {api_base}")
    print(f"extra_kwargs keys: {list(extra_kwargs.keys())}")
    print(f"extra_headers: {list(extra_kwargs.get('extra_headers', {}).keys())}")

    model = "openai/gpt-5.3-codex"

    # Create the provider
    provider = LiteLLMProvider(
        model=model,
        api_key=api_key,
        api_base=api_base,
        **extra_kwargs,
    )
    print(f"_codex_backend: {provider._codex_backend}")

    # Verify mode is "responses" (the correct routing for Codex backend)
    _strip = model.removeprefix("openai/")
    mode = litellm.model_cost.get(_strip, {}).get("mode", "NOT SET")
    print(f"litellm.model_cost['{_strip}']['mode']: {mode}")
    if mode != "responses":
        print(" WARNING: Expected mode='responses' for Codex backend!")
    print()

    # -----------------------------------------------------------
    # Test 1: Stream via LiteLLMProvider.stream() (the real code path)
    # -----------------------------------------------------------
    print("=" * 60)
    print("TEST 1: LiteLLMProvider.stream() — basic text")
    print("=" * 60)
    try:
        from framework.llm.stream_events import (
            FinishEvent,
            StreamErrorEvent,
            TextDeltaEvent,
            TextEndEvent,
            ToolCallEvent,
        )

        messages = [{"role": "user", "content": "Say hello in exactly 3 words."}]
        chunk_count = 0
        text = ""
        async for event in provider.stream(messages=messages):
            chunk_count += 1
            # snapshot carries the full accumulated text, so assignment
            # (not +=) is correct here.
            if isinstance(event, TextDeltaEvent):
                text = event.snapshot
            elif isinstance(event, TextEndEvent):
                print(f" TextEnd: {event.full_text!r}")
            elif isinstance(event, ToolCallEvent):
                print(f" ToolCall: {event.tool_name}({event.tool_input})")
            elif isinstance(event, FinishEvent):
                print(
                    f" Finish: stop={event.stop_reason} "
                    f"in={event.input_tokens} out={event.output_tokens}"
                )
            elif isinstance(event, StreamErrorEvent):
                print(f" StreamError: {event.error} (recoverable={event.recoverable})")
        print(f" Text: {text!r}")
        print(f" Total events: {chunk_count}")
        print(" RESULT: OK" if text else " RESULT: EMPTY")
    except Exception as e:
        print(f" ERROR: {type(e).__name__}: {e}")
        traceback.print_exc()
    print()

    # -----------------------------------------------------------
    # Test 2: Stream via LiteLLMProvider.stream() with tools
    # -----------------------------------------------------------
    print("=" * 60)
    print("TEST 2: LiteLLMProvider.stream() — with tools")
    print("=" * 60)
    try:
        from framework.llm.provider import Tool

        tools = [
            Tool(
                name="get_weather",
                description="Get weather for a city",
                parameters={
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            )
        ]
        messages = [{"role": "user", "content": "What is the weather in SF?"}]
        chunk_count = 0
        text = ""
        tool_calls = []
        async for event in provider.stream(messages=messages, tools=tools):
            chunk_count += 1
            if isinstance(event, TextDeltaEvent):
                text = event.snapshot
            elif isinstance(event, ToolCallEvent):
                tool_calls.append(
                    {"name": event.tool_name, "input": event.tool_input}
                )
                print(f" ToolCall: {event.tool_name}({json.dumps(event.tool_input)})")
            elif isinstance(event, FinishEvent):
                print(
                    f" Finish: stop={event.stop_reason} "
                    f"in={event.input_tokens} out={event.output_tokens}"
                )
            elif isinstance(event, StreamErrorEvent):
                print(f" StreamError: {event.error} (recoverable={event.recoverable})")
        print(f" Text: {text!r}")
        print(f" Tool calls: {json.dumps(tool_calls, indent=2)}")
        print(f" Total events: {chunk_count}")
        # Either text or a tool call counts as a successful run here.
        status = "OK" if (text or tool_calls) else "EMPTY"
        print(f" RESULT: {status}")
    except Exception as e:
        print(f" ERROR: {type(e).__name__}: {e}")
        traceback.print_exc()
    print()

    # -----------------------------------------------------------
    # Test 3: acomplete() via provider (uses stream + collect)
    # -----------------------------------------------------------
    print("=" * 60)
    print("TEST 3: LiteLLMProvider.acomplete() — round-trip")
    print("=" * 60)
    try:
        messages = [{"role": "user", "content": "What is 2+2? Reply with just the number."}]
        response = await provider.acomplete(messages=messages)
        print(f" Content: {response.content!r}")
        print(f" Model: {response.model}")
        print(f" Tokens: in={response.input_tokens} out={response.output_tokens}")
        print(f" Stop: {response.stop_reason}")
        print(" RESULT: OK" if response.content else " RESULT: EMPTY")
    except Exception as e:
        print(f" ERROR: {type(e).__name__}: {e}")
        traceback.print_exc()
    print()

    # -----------------------------------------------------------
    # Test 4: Direct litellm.acompletion with metadata fix
    # -----------------------------------------------------------
    print("=" * 60)
    print("TEST 4: Direct litellm.acompletion (with metadata={})")
    print("=" * 60)
    try:
        direct_kwargs = {
            "model": model,
            "messages": [{"role": "user", "content": "Say hello in exactly 3 words."}],
            "stream": True,
            "api_key": api_key,
            "api_base": api_base,
            "metadata": {},  # Prevent NoneType masking in error handler
            **extra_kwargs,
        }
        response = await litellm.acompletion(**direct_kwargs)
        chunk_count = 0
        text = ""
        async for chunk in response:
            chunk_count += 1
            choices = chunk.choices if chunk.choices else []
            delta = choices[0].delta if choices else None
            content = delta.content if delta and delta.content else ""
            if content:
                text += content
            finish = choices[0].finish_reason if choices else None
            if finish:
                print(f" finish_reason: {finish}")
        print(f" Text: {text!r}")
        print(f" Total chunks: {chunk_count}")
        print(" RESULT: OK" if text else " RESULT: EMPTY")
    except Exception as e:
        print(f" ERROR: {type(e).__name__}: {e}")
        traceback.print_exc()
    print()

    # -----------------------------------------------------------
    # Test 5: Rapid-fire 3 calls via provider.stream()
    # -----------------------------------------------------------
    print("=" * 60)
    print("TEST 5: Rapid-fire 3 calls via provider.stream()")
    print("=" * 60)
    for i in range(3):
        try:
            messages = [{"role": "user", "content": f"Say the number {i + 1}."}]
            text = ""
            async for event in provider.stream(messages=messages):
                if isinstance(event, TextDeltaEvent):
                    text = event.snapshot
                elif isinstance(event, StreamErrorEvent):
                    print(f" Call {i + 1}: StreamError: {event.error}")
                    break
            status = f"OK ({len(text)} chars: {text!r})" if text else "EMPTY"
            print(f" Call {i + 1}: {status}")
        except Exception as e:
            print(f" Call {i + 1}: ERROR {type(e).__name__}: {e}")
    print()
|
||||
|
||||
|
||||
# Entry point: run the full diagnostic suite against the live backend.
if __name__ == "__main__":
    asyncio.run(test_codex_stream())
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Run Codex stream with litellm debug logging enabled.
|
||||
|
||||
Run: .venv/bin/python core/tests/debug_codex_verbose.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "core")
|
||||
|
||||
import litellm # noqa: E402
|
||||
|
||||
litellm._turn_on_debug()
|
||||
|
||||
from framework.config import get_api_base, get_api_key, get_llm_extra_kwargs # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.stream_events import ( # noqa: E402
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
|
||||
async def main():
    """Run one Codex streaming call with litellm debug logging enabled.

    Prints every stream event and the final accumulated text; "OK"/"EMPTY"
    on the last line indicates whether any text was produced. Requires a
    Codex config in ~/.hive/configuration.json.
    """
    api_key = get_api_key()
    api_base = get_api_base()
    extra_kwargs = get_llm_extra_kwargs()

    if not api_key or not api_base:
        print("ERROR: No Codex config in ~/.hive/configuration.json")
        return

    provider = LiteLLMProvider(
        model="openai/gpt-5.3-codex",
        api_key=api_key,
        api_base=api_base,
        **extra_kwargs,
    )

    print(f"_codex_backend={provider._codex_backend}")
    print()

    text = ""
    async for event in provider.stream(
        messages=[{"role": "user", "content": "What is 2+2? Reply with just the number."}],
        system="You are a helpful assistant.",
    ):
        # snapshot is the full accumulated text so far, so assign, don't append.
        if isinstance(event, TextDeltaEvent):
            text = event.snapshot
        elif isinstance(event, TextEndEvent):
            print(f"TextEnd: {event.full_text!r}")
        elif isinstance(event, ToolCallEvent):
            print(f"ToolCall: {event.tool_name}({event.tool_input})")
        elif isinstance(event, FinishEvent):
            print(
                f"Finish: stop={event.stop_reason} "
                f"in={event.input_tokens} out={event.output_tokens}"
            )
        elif isinstance(event, StreamErrorEvent):
            print(f"StreamError: {event.error} (recoverable={event.recoverable})")

    print(f"Text: {text!r}")
    print("OK" if text else "EMPTY")
|
||||
|
||||
|
||||
# Entry point: single verbose streaming call.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -0,0 +1,159 @@
|
||||
"""Integration test: Run a real EventLoopNode against the Codex backend.
|
||||
|
||||
Run: .venv/bin/python core/tests/test_codex_eventloop.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
sys.path.insert(0, "core")
|
||||
|
||||
# Root logger stays at WARNING to keep third-party noise down.
logging.basicConfig(level=logging.WARNING, format="%(levelname)s %(name)s: %(message)s")
# Show our provider's retry/stream logs
logging.getLogger("framework.llm.litellm").setLevel(logging.DEBUG)
|
||||
|
||||
from framework.config import RuntimeConfig # noqa: E402
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeResult, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
|
||||
|
||||
def make_provider() -> LiteLLMProvider:
    """Build a LiteLLMProvider from the user's runtime configuration.

    Exits the process with status 1 (after printing an error) when no API
    key is configured — this is a standalone integration-test script, so a
    hard exit is acceptable here.
    """
    cfg = RuntimeConfig()
    if not cfg.api_key:
        print("ERROR: No API key configured in ~/.hive/configuration.json")
        sys.exit(1)
    print(f"Model : {cfg.model}")
    print(f"Base : {cfg.api_base}")
    # True when the configured base URL is the Codex ChatGPT backend.
    print(f"Codex : {'chatgpt.com/backend-api/codex' in (cfg.api_base or '')}")
    return LiteLLMProvider(
        model=cfg.model,
        api_key=cfg.api_key,
        api_base=cfg.api_base,
        **cfg.extra_kwargs,
    )
|
||||
|
||||
|
||||
def make_context(
    llm: LiteLLMProvider,
    *,
    node_id: str = "test",
    system_prompt: str = "You are a helpful assistant.",
    output_keys: list[str] | None = None,
) -> NodeContext:
    """Build a NodeContext suitable for running an EventLoopNode in a test.

    The runtime is a MagicMock with canned run/decision ids so the node's
    bookkeeping calls succeed without a real runtime backend; only the LLM
    provider is real.
    """
    if output_keys is None:
        output_keys = ["answer"]

    spec = NodeSpec(
        id=node_id,
        name="Test Node",
        description="Integration test node",
        node_type="event_loop",
        output_keys=output_keys,
        system_prompt=system_prompt,
    )

    # Stub out the runtime: the node records runs/decisions/outcomes on it,
    # but this test only cares about the LLM round-trip.
    runtime = MagicMock()
    runtime.start_run = MagicMock(return_value="run-1")
    runtime.decide = MagicMock(return_value="dec-1")
    runtime.record_outcome = MagicMock()
    runtime.end_run = MagicMock()

    memory = SharedMemory()

    return NodeContext(
        runtime=runtime,
        node_id=node_id,
        node_spec=spec,
        memory=memory,
        input_data={},
        llm=llm,
        available_tools=[],
        max_tokens=4096,
    )
|
||||
|
||||
|
||||
async def run_test(name: str, llm: LiteLLMProvider, system: str, output_keys: list[str]) -> NodeResult:
    """Execute one EventLoopNode run and print its outcome.

    Never raises: exceptions from the node are caught, printed with a
    traceback, and converted into a failed NodeResult so the caller can
    keep going.
    """
    print(f"\n{'=' * 60}")
    print(f"TEST: {name}")
    print(f"{'=' * 60}")

    ctx = make_context(llm, system_prompt=system, output_keys=output_keys)
    # Cap iterations so a misbehaving model can't loop forever.
    node = EventLoopNode(config=LoopConfig(max_iterations=3))

    try:
        result = await node.execute(ctx)
        print(f" Success : {result.success}")
        print(f" Output : {result.output}")
        if result.error:
            print(f" Error : {result.error}")
        return result
    except Exception as e:
        print(f" EXCEPTION: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return NodeResult(success=False, error=str(e))
|
||||
|
||||
|
||||
async def main():
    """Run the EventLoopNode integration tests against the configured backend.

    Test 1 runs the real node flow (expects the model to call set_output).
    If it fails, a fallback test exercises provider.stream() directly to
    isolate whether the problem is in the node or in the provider/stream.
    """
    llm = make_provider()
    print()

    # Test 1: Simple text output — the node should call set_output to fill "answer"
    r1 = await run_test(
        name="Simple text generation",
        llm=llm,
        system=(
            "You are a helpful assistant. When asked a question, use the "
            "set_output tool to store your answer in the 'answer' key. "
            "Keep answers short (1-2 sentences)."
        ),
        output_keys=["answer"],
    )

    # Test 2: If test 1 failed, try bare stream() to isolate the issue
    if not r1.success:
        print(f"\n{'=' * 60}")
        print("FALLBACK: Testing bare provider.stream() directly")
        print(f"{'=' * 60}")
        try:
            from framework.llm.stream_events import (
                FinishEvent,
                StreamErrorEvent,
                TextDeltaEvent,
                ToolCallEvent,
            )

            text = ""
            events = []  # event type names in arrival order, for diagnosis
            async for event in llm.stream(
                messages=[{"role": "user", "content": "Say hello in 3 words."}],
            ):
                events.append(type(event).__name__)
                if isinstance(event, TextDeltaEvent):
                    text = event.snapshot
                elif isinstance(event, FinishEvent):
                    print(f" Finish: stop={event.stop_reason} in={event.input_tokens} out={event.output_tokens}")
                elif isinstance(event, StreamErrorEvent):
                    print(f" StreamError: {event.error} (recoverable={event.recoverable})")
                elif isinstance(event, ToolCallEvent):
                    print(f" ToolCall: {event.tool_name}")
            print(f" Text : {text!r}")
            print(f" Events : {events}")
            print(f" RESULT : {'OK' if text else 'EMPTY'}")
        except Exception as e:
            print(f" EXCEPTION: {type(e).__name__}: {e}")
            import traceback
            traceback.print_exc()

    print(f"\n{'=' * 60}")
    print("DONE")
    print(f"{'=' * 60}")
|
||||
|
||||
|
||||
# Entry point: run the EventLoopNode integration tests.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -0,0 +1,377 @@
|
||||
"""Test script: Codex vs OpenAI — tool call argument truncation repro.
|
||||
|
||||
Run: uv run python core/tests/test_two_llm_calls.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "core")
|
||||
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
from framework.llm.provider import Tool
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
# NOTE(review): placeholder secret. Never commit a real key here — read it
# from the environment (e.g. os.environ["OPENAI_API_KEY"]) instead.
OPENAI_API_KEY = "sk-*****"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool definitions — mimic the real vulnerability_assessment agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Tool schemas mirroring the real vulnerability_assessment agent: six scan
# tools plus the set_output sink whose single argument is a large JSON string
# (the case that exposed argument truncation).
SCAN_TOOLS = [
    Tool(
        name="ssl_tls_scan",
        description="Scan SSL/TLS configuration for a hostname",
        parameters={
            "type": "object",
            "properties": {
                "hostname": {"type": "string", "description": "Domain name to scan"},
                "port": {"type": "integer", "description": "Port to connect to", "default": 443},
            },
            "required": ["hostname"],
        },
    ),
    Tool(
        name="http_headers_scan",
        description="Scan HTTP security headers for a URL",
        parameters={
            "type": "object",
            "properties": {
                "url": {"type": "string", "description": "Full URL to scan"},
                "follow_redirects": {"type": "boolean", "default": True},
            },
            "required": ["url"],
        },
    ),
    Tool(
        name="dns_security_scan",
        description="Scan DNS security configuration for a domain",
        parameters={
            "type": "object",
            "properties": {
                "domain": {"type": "string", "description": "Domain name to scan"},
            },
            "required": ["domain"],
        },
    ),
    Tool(
        name="port_scan",
        description="Scan open ports for a hostname",
        parameters={
            "type": "object",
            "properties": {
                "hostname": {"type": "string", "description": "Domain or IP to scan"},
                "ports": {"type": "string", "default": "top20"},
                "timeout": {"type": "number", "default": 3.0},
            },
            "required": ["hostname"],
        },
    ),
    Tool(
        name="tech_stack_detect",
        description="Detect technology stack for a URL",
        parameters={
            "type": "object",
            "properties": {
                "url": {"type": "string", "description": "URL to analyze"},
            },
            "required": ["url"],
        },
    ),
    Tool(
        name="subdomain_enumerate",
        description="Enumerate subdomains for a domain",
        parameters={
            "type": "object",
            "properties": {
                "domain": {"type": "string", "description": "Base domain"},
                "max_results": {"type": "integer", "default": 50},
            },
            "required": ["domain"],
        },
    ),
    # The big one — takes 6 JSON-string params (whole scan results)
    Tool(
        name="set_output",
        description="Set the output for this node. Call this when you are done. scan_results must be a JSON string containing the full consolidated results from all scans.",
        parameters={
            "type": "object",
            "properties": {
                "scan_results": {
                    "type": "string",
                    "description": "JSON string with consolidated scan results including ssl, headers, dns, ports, tech, and subdomain data.",
                },
            },
            "required": ["scan_results"],
        },
    ),
]
|
||||
|
||||
# Fake scan results — realistic size to stress-test argument streaming.
# These stand in for live tool output so the LLM must fold a large payload
# into one set_output call.

# Plausible ssl_tls_scan output.
FAKE_SSL_RESULT = {
    "hostname": "example.com", "port": 443, "tls_version": "TLSv1.3",
    "cipher": "TLS_AES_256_GCM_SHA384", "cipher_bits": 256,
    "certificate": {
        "subject": "CN=example.com", "issuer": "CN=Let's Encrypt Authority X3",
        "not_before": "2025-01-01T00:00:00Z", "not_after": "2026-01-01T00:00:00Z",
        "days_until_expiry": 310, "san": ["example.com", "www.example.com"],
        "self_signed": False, "sha256_fingerprint": "AB:CD:EF:12:34:56:78:90",
    },
    "issues": [
        {"severity": "low", "finding": "Certificate expiring in 310 days", "remediation": "Monitor expiry"},
    ],
    "grade_input": {"tls_version_ok": True, "cert_valid": True, "cert_expiring_soon": False, "strong_cipher": True, "self_signed": False},
}

# Plausible http_headers_scan output.
FAKE_HEADERS_RESULT = {
    "url": "https://example.com", "status_code": 200,
    "headers_present": ["Strict-Transport-Security", "X-Content-Type-Options"],
    "headers_missing": [
        {"header": "Content-Security-Policy", "severity": "high", "description": "No CSP header", "remediation": "Add CSP header"},
        {"header": "X-Frame-Options", "severity": "medium", "description": "No X-Frame-Options", "remediation": "Add DENY or SAMEORIGIN"},
        {"header": "Permissions-Policy", "severity": "low", "description": "No Permissions-Policy", "remediation": "Add Permissions-Policy"},
    ],
    "leaky_headers": [
        {"header": "Server", "value": "nginx/1.21.0", "severity": "low", "remediation": "Remove server version"},
    ],
    "grade_input": {"hsts": True, "csp": False, "x_frame_options": False, "x_content_type_options": True, "referrer_policy": False, "permissions_policy": False, "no_leaky_headers": False},
}

# Plausible dns_security_scan output.
FAKE_DNS_RESULT = {
    "domain": "example.com", "source": "crt.sh",
    "spf": {"present": True, "record": "v=spf1 include:_spf.google.com ~all", "policy": "softfail", "issues": []},
    "dmarc": {"present": True, "record": "v=DMARC1; p=reject; rua=mailto:dmarc@example.com", "policy": "reject", "issues": []},
    "dkim": {"selectors_found": ["google", "default"], "selectors_missing": []},
    "dnssec": {"enabled": False, "issues": [{"severity": "medium", "finding": "DNSSEC not enabled"}]},
    "mx_records": ["10 mail.example.com"],
    "caa_records": ["0 issue letsencrypt.org"],
    "zone_transfer": {"vulnerable": False},
    "grade_input": {"spf_present": True, "spf_strict": False, "dmarc_present": True, "dmarc_enforcing": True, "dkim_found": True, "dnssec_enabled": False, "zone_transfer_blocked": True},
}

# Plausible port_scan output.
FAKE_PORTS_RESULT = {
    "hostname": "example.com", "ip": "93.184.216.34", "ports_scanned": 20,
    "open_ports": [
        {"port": 80, "service": "http", "banner": "nginx/1.21.0"},
        {"port": 443, "service": "https", "banner": "nginx/1.21.0"},
        {"port": 22, "service": "ssh", "banner": "OpenSSH_8.9", "severity": "medium", "finding": "SSH port open", "remediation": "Restrict SSH access"},
    ],
    "closed_ports": [21, 23, 25, 53, 110, 143, 993, 995, 3306, 5432, 6379, 8080, 8443, 27017],
    "grade_input": {"no_database_ports_exposed": True, "no_admin_ports_exposed": False, "no_legacy_ports_exposed": True, "only_web_ports": False},
}

# Plausible tech_stack_detect output.
FAKE_TECH_RESULT = {
    "url": "https://example.com",
    "server": {"name": "nginx", "version": "1.21.0", "raw": "nginx/1.21.0"},
    "framework": "React", "language": "JavaScript", "cms": None,
    "javascript_libraries": ["react-18.2.0", "lodash-4.17.21", "axios-1.6.0"],
    "cdn": "Cloudflare", "analytics": ["Google Analytics"],
    "security_txt": True, "robots_txt": True,
    "interesting_paths": ["/admin", "/.env", "/api/docs"],
    "cookies": [
        {"name": "session", "secure": True, "httponly": True, "samesite": "Strict"},
        {"name": "_ga", "secure": False, "httponly": False, "samesite": "None"},
    ],
    "grade_input": {"server_version_hidden": False, "framework_version_hidden": True, "security_txt_present": True, "cookies_secure": False, "cookies_httponly": False},
}

# Plausible subdomain_enumerate output.
FAKE_SUBDOMAIN_RESULT = {
    "domain": "example.com", "source": "crt.sh", "total_found": 8,
    "subdomains": ["www.example.com", "mail.example.com", "api.example.com", "staging.example.com", "dev.example.com", "admin.example.com", "cdn.example.com", "blog.example.com"],
    "interesting": [
        {"subdomain": "staging.example.com", "reason": "staging environment exposed", "severity": "high", "remediation": "Restrict access"},
        {"subdomain": "dev.example.com", "reason": "development environment exposed", "severity": "high", "remediation": "Restrict access"},
        {"subdomain": "admin.example.com", "reason": "admin panel exposed", "severity": "medium", "remediation": "Add IP restriction"},
    ],
    "grade_input": {"no_dev_staging_exposed": False, "no_admin_exposed": False, "reasonable_surface_area": True},
}
|
||||
|
||||
|
||||
def _make_codex_provider():
    """Return a LiteLLMProvider targeting the Codex backend, or None when
    no Codex subscription is configured."""
    from framework.config import get_api_base, get_api_key, get_llm_extra_kwargs

    key = get_api_key()
    base = get_api_base()
    extras = get_llm_extra_kwargs()
    # Both a key and a base URL are required for a usable provider.
    if not (key and base):
        return None
    return LiteLLMProvider(
        model="openai/gpt-5.3-codex",
        api_key=key,
        api_base=base,
        **extras,
    )
|
||||
|
||||
|
||||
async def _stream_and_collect(provider, messages, system, tools):
    """Stream a call, collect text + tool calls, print events. Returns (text, tool_calls).

    On a StreamErrorEvent the stream is abandoned and whatever was collected
    so far is returned. (Previously this used a mid-loop ``return`` that
    duplicated the trailing return statement; a ``break`` gives a single
    exit path with identical behavior.)
    """
    text = ""
    tool_calls: list[ToolCallEvent] = []
    async for event in provider.stream(messages=messages, system=system, tools=tools):
        if isinstance(event, TextDeltaEvent):
            # snapshot is the full accumulated text, so assign rather than append.
            text = event.snapshot
        elif isinstance(event, ToolCallEvent):
            tool_calls.append(event)
        elif isinstance(event, FinishEvent):
            print(f" finish: stop={event.stop_reason} in={event.input_tokens} out={event.output_tokens}")
        elif isinstance(event, StreamErrorEvent):
            print(f" STREAM ERROR: {event.error}")
            break  # abort the stream; return partial results below
    return text, tool_calls
|
||||
|
||||
|
||||
def _validate_tool_args(tool_calls: list[ToolCallEvent]) -> bool:
|
||||
"""Check that every tool call has valid, non-truncated JSON arguments."""
|
||||
ok = True
|
||||
for tc in tool_calls:
|
||||
print(f" ToolCall: {tc.tool_name} id={tc.tool_use_id}")
|
||||
args = tc.tool_input
|
||||
|
||||
# Check for the _raw fallback (means JSON parse failed → truncated)
|
||||
if "_raw" in args:
|
||||
print(f" TRUNCATED — raw args: {args['_raw'][:200]}...")
|
||||
ok = False
|
||||
continue
|
||||
|
||||
# For set_output, validate the nested JSON string
|
||||
if tc.tool_name == "set_output" and "scan_results" in args:
|
||||
raw_json = args["scan_results"]
|
||||
print(f" scan_results length: {len(raw_json)} chars")
|
||||
try:
|
||||
parsed = json.loads(raw_json)
|
||||
keys = list(parsed.keys()) if isinstance(parsed, dict) else "not-a-dict"
|
||||
print(f" parsed OK — keys: {keys}")
|
||||
except json.JSONDecodeError as e:
|
||||
print(f" INVALID JSON in scan_results: {e}")
|
||||
print(f" tail: ...{raw_json[-200:]}")
|
||||
ok = False
|
||||
else:
|
||||
print(f" args: {json.dumps(args)}")
|
||||
return ok
|
||||
|
||||
|
||||
async def test_codex_multi_tool_scan():
    """Reproduce the real agent flow: LLM calls 6 scan tools, then set_output with big JSON."""
    provider = _make_codex_provider()
    if not provider:
        print("[scan] SKIP — no Codex subscription")
        return

    system = (
        "You are a security scanning agent. You have access to scanning tools.\n"
        "The user will give you scan results. Your job is to consolidate them and "
        "call set_output with a JSON string containing ALL the scan results.\n"
        "The scan_results value MUST be a valid JSON string containing every scan result provided.\n"
        "Do NOT summarize — include the complete data from each scan."
    )

    # Feed every fake scan result in a single user turn so the model is
    # forced to fold them all into one large set_output call.
    consolidated = {
        "ssl": FAKE_SSL_RESULT,
        "headers": FAKE_HEADERS_RESULT,
        "dns": FAKE_DNS_RESULT,
        "ports": FAKE_PORTS_RESULT,
        "tech": FAKE_TECH_RESULT,
        "subdomains": FAKE_SUBDOMAIN_RESULT,
    }
    payload = json.dumps(consolidated, indent=2)
    print(f" Input scan data size: {len(payload)} chars")

    user_content = (
        "Here are the completed scan results for example.com. "
        "Consolidate ALL of them into a single set_output call. "
        "The scan_results argument must be a JSON string containing the complete data.\n\n"
        f"```json\n{payload}\n```"
    )
    messages = [{"role": "user", "content": user_content}]

    # --- Turn 1: expect set_output tool call with big JSON ---
    reply_text, calls = await _stream_and_collect(provider, messages, system, SCAN_TOOLS)

    if reply_text:
        print(f" text: {reply_text[:200]}{'...' if len(reply_text) > 200 else ''}")

    if not calls:
        print(" NO TOOL CALLS — expected set_output")
        print(f" full text: {reply_text}")
        return

    args_ok = _validate_tool_args(calls)
    print(f" RESULT: {'OK' if args_ok else 'TRUNCATED/MALFORMED'}")
async def test_codex_parallel_tool_calls():
    """Ask the LLM to call multiple scan tools at once — tests parallel tool call streaming."""
    provider = _make_codex_provider()
    if not provider:
        print("[parallel] SKIP — no Codex subscription")
        return

    system = (
        "You are a security scanning agent. When asked to scan a target, "
        "call ALL relevant scanning tools in parallel in a single response. "
        "Always call: ssl_tls_scan, http_headers_scan, dns_security_scan, "
        "port_scan, tech_stack_detect, and subdomain_enumerate."
    )
    request = {"role": "user", "content": "Run a full security scan on example.com"}

    reply_text, calls = await _stream_and_collect(provider, [request], system, SCAN_TOOLS)

    if reply_text:
        print(f" text: {reply_text[:200]}{'...' if len(reply_text) > 200 else ''}")

    print(f" Total tool calls: {len(calls)}")
    args_ok = _validate_tool_args(calls)
    print(f" RESULT: {'OK' if args_ok else 'TRUNCATED/MALFORMED'}")
async def test_openai_baseline():
    """OpenAI direct — baseline to compare against."""
    key = OPENAI_API_KEY or os.environ.get("OPENAI_API_KEY")
    if not key:
        print("[openai] SKIP — OPENAI_API_KEY not set")
        return

    provider = LiteLLMProvider(model="openai/gpt-4o-mini", api_key=key)
    prompt = [{"role": "user", "content": "What is 3+3? Reply with just the number."}]
    resp = await provider.acomplete(messages=prompt, max_tokens=64)
    print(f" Response: {resp.content!r}")
    print(f" tokens: in={resp.input_tokens} out={resp.output_tokens}")
    print(f" RESULT: {'OK' if resp.content else 'EMPTY'}")
async def main():
    """Run the three diagnostics in sequence, each under a banner header."""
    suite = [
        ("Test 1: Codex — parallel tool calls (6 scan tools)", test_codex_parallel_tool_calls),
        ("Test 2: Codex — big set_output call (~4KB JSON arg)", test_codex_multi_tool_scan),
        ("Test 3: OpenAI direct — baseline", test_openai_baseline),
    ]
    banner = "=" * 60
    for title, runner in suite:
        print(banner)
        print(title)
        print(banner)
        await runner()
        print()
# Script entry point: run the full diagnostic suite on the default event loop.
if __name__ == "__main__":
    asyncio.run(main())
# end of file