fix(wip): codex tool use bug fixes

Richard Tang
2026-02-25 20:09:49 -08:00
parent 8294cd3dd9
commit 0b83f6ea99
5 changed files with 909 additions and 1 deletion
core/framework/llm/litellm.py (+78 -1)
@@ -70,8 +70,49 @@ def _patch_litellm_anthropic_oauth() -> None:
AnthropicModelInfo.validate_environment = _patched_validate_environment
def _patch_litellm_metadata_nonetype() -> None:
"""Patch litellm entry points to prevent metadata=None TypeError.
litellm bug: the @client decorator in utils.py has four places that do
"model_group" in kwargs.get("metadata", {})
but kwargs["metadata"] can be explicitly None (set internally by
litellm_params), causing:
TypeError: argument of type 'NoneType' is not iterable
This masks the real API error with a confusing APIConnectionError.
Fix: wrap the four litellm entry points (completion, acompletion,
responses, aresponses) to pop metadata=None before the @client
decorator's error handler can crash on it.
"""
import functools
for fn_name in ("completion", "acompletion", "responses", "aresponses"):
original = getattr(litellm, fn_name, None)
if original is None:
continue
if asyncio.iscoroutinefunction(original):
@functools.wraps(original)
async def _async_wrapper(*args, _orig=original, **kwargs):
if kwargs.get("metadata") is None:
kwargs.pop("metadata", None)
return await _orig(*args, **kwargs)
setattr(litellm, fn_name, _async_wrapper)
else:
@functools.wraps(original)
def _sync_wrapper(*args, _orig=original, **kwargs):
if kwargs.get("metadata") is None:
kwargs.pop("metadata", None)
return _orig(*args, **kwargs)
setattr(litellm, fn_name, _sync_wrapper)
if litellm is not None:
_patch_litellm_anthropic_oauth()
_patch_litellm_metadata_nonetype()
RATE_LIMIT_MAX_RETRIES = 10
RATE_LIMIT_BACKOFF_BASE = 2 # seconds
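# --- Illustrative repro of the failure mode patched above (a simplified
# sketch, not litellm's exact code path):
kwargs = {"metadata": None}  # set internally by litellm_params
try:
    # dict.get's default only applies when the key is *absent*; here it
    # returns None, and `"model_group" in None` is the crash in question.
    "model_group" in kwargs.get("metadata", {})
except TypeError as e:
    print(e)  # argument of type 'NoneType' is not iterable
# The wrappers pop the None entry first, so .get() falls back to {} and the
# membership test is safely False.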
@@ -284,6 +325,12 @@ class LiteLLMProvider(LLMProvider):
"LiteLLM is not installed. Please install it with: uv pip install litellm"
)
# Note: The Codex ChatGPT backend is a Responses API endpoint at
# chatgpt.com/backend-api/codex/responses. LiteLLM's model registry
# correctly marks codex models with mode="responses", so we do NOT
# override the mode. The responses_api_bridge in litellm handles
# converting Chat Completions requests to Responses API format.
def _completion_with_rate_limit_retry(
self, max_retries: int | None = None, **kwargs: Any
) -> Any:
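# --- Quick sanity check of that routing, mirroring the probe in
# core/tests/debug_codex_stream.py (assumes the codex model is present in
# litellm's model registry):
import litellm
mode = litellm.model_cost.get("gpt-5.3-codex", {}).get("mode", "NOT SET")
assert mode == "responses", f"expected mode='responses', got {mode!r}"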
@@ -708,6 +755,11 @@ class LiteLLMProvider(LLMProvider):
full_messages.append({"role": "system", "content": system})
full_messages.extend(messages)
# Codex Responses API requires an `instructions` field (system prompt).
# Inject a minimal one when callers don't provide a system message.
if self._codex_backend and not any(m["role"] == "system" for m in full_messages):
full_messages.insert(0, {"role": "system", "content": "You are a helpful assistant."})
# Add JSON mode via prompt engineering (works across all providers)
if json_mode:
json_instruction = "\n\nPlease respond with a valid JSON object."
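# --- Minimal illustration of the injection above (hypothetical input, same
# logic): a Codex-backend call with no system message gets one prepended so
# the Responses API's required `instructions` field is populated.
msgs = [{"role": "user", "content": "What is 2+2?"}]
if not any(m["role"] == "system" for m in msgs):
    msgs.insert(0, {"role": "system", "content": "You are a helpful assistant."})
assert msgs[0]["role"] == "system"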
@@ -732,7 +784,7 @@ class LiteLLMProvider(LLMProvider):
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
if response_format:
kwargs["response_format"] = response_format
# The Codex ChatGPT backend (Responses API) rejects several params,
# including max_output_tokens and stream_options.
if self._codex_backend:
kwargs.pop("max_tokens", None)
kwargs.pop("stream_options", None)
@@ -744,6 +796,7 @@ class LiteLLMProvider(LLMProvider):
tail_events: list[StreamEvent] = []
accumulated_text = ""
tool_calls_acc: dict[int, dict[str, str]] = {}
_last_tool_idx = 0 # tracks most recently opened tool call slot
input_tokens = 0
output_tokens = 0
stream_finish_reason: str | None = None
@@ -767,9 +820,33 @@ class LiteLLMProvider(LLMProvider):
)
# --- Tool calls (accumulate across chunks) ---
# The Codex/Responses API bridge (litellm bug) hardcodes
# index=0 on every ChatCompletionToolCallChunk, even for
# parallel tool calls. We work around this by using tc.id
# (set on output_item.added events) as a "new tool call"
# signal and tracking the most recently opened slot for
# argument deltas that arrive with id=None.
if delta and delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
if tc.id:
# New tool call announced (or done event re-sent).
# Check if this id already has a slot.
existing_idx = next(
(k for k, v in tool_calls_acc.items() if v["id"] == tc.id),
None,
)
if existing_idx is not None:
idx = existing_idx
elif idx in tool_calls_acc and tool_calls_acc[idx]["id"] not in ("", tc.id):
# Slot taken by a different call — assign new index
idx = max(tool_calls_acc.keys()) + 1
_last_tool_idx = idx
else:
# Argument delta with no id — route to last opened slot
idx = _last_tool_idx
if idx not in tool_calls_acc:
tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
if tc.id:
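# --- Standalone sketch of the slot routing above (assumption: simplified
# chunk objects; litellm's real chunks carry more fields). Two parallel tool
# calls arrive with index hardcoded to 0; the id/_last_tool_idx bookkeeping
# still separates them.
from types import SimpleNamespace as Chunk
acc: dict[int, dict[str, str]] = {}
last_idx = 0
chunks = [
    Chunk(index=0, id="call_a", function=Chunk(name="ssl_tls_scan", arguments="")),
    Chunk(index=0, id=None, function=Chunk(name=None, arguments='{"hostname": "a.com"}')),
    Chunk(index=0, id="call_b", function=Chunk(name="port_scan", arguments="")),
    Chunk(index=0, id=None, function=Chunk(name=None, arguments='{"hostname": "a.com"}')),
]
for tc in chunks:
    idx = tc.index
    if tc.id:
        existing = next((k for k, v in acc.items() if v["id"] == tc.id), None)
        if existing is not None:
            idx = existing  # done event re-sent for a known call
        elif idx in acc and acc[idx]["id"] not in ("", tc.id):
            idx = max(acc) + 1  # slot taken by a different call
        last_idx = idx
    else:
        idx = last_idx  # argument delta with no id
    slot = acc.setdefault(idx, {"id": "", "name": "", "arguments": ""})
    if tc.id:
        slot["id"] = tc.id
    if tc.function.name:
        slot["name"] = tc.function.name
    slot["arguments"] += tc.function.arguments or ""
# Without the id-based routing, call_b's name and arguments would have been
# appended onto call_a's slot 0, corrupting both calls.
assert [v["name"] for v in acc.values()] == ["ssl_tls_scan", "port_scan"]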
core/tests/debug_codex_stream.py (+226)
@@ -0,0 +1,226 @@
"""Diagnostic script to reproduce and trace Codex streaming errors.
Run: .venv/bin/python core/tests/debug_codex_stream.py
"""
import asyncio
import json
import sys
import traceback
sys.path.insert(0, "core")
import litellm # noqa: E402
# Enable litellm debug logging to see the raw HTTP exchange
litellm._turn_on_debug()
async def test_codex_stream():
"""Minimal Codex streaming call via LiteLLMProvider (Responses API path)."""
from framework.config import get_api_base, get_api_key, get_llm_extra_kwargs
from framework.llm.litellm import LiteLLMProvider
api_key = get_api_key()
api_base = get_api_base()
extra_kwargs = get_llm_extra_kwargs()
if not api_key or not api_base:
print("ERROR: No Codex subscription configured in ~/.hive/configuration.json")
return
print(f"api_base: {api_base}")
print(f"extra_kwargs keys: {list(extra_kwargs.keys())}")
print(f"extra_headers: {list(extra_kwargs.get('extra_headers', {}).keys())}")
model = "openai/gpt-5.3-codex"
# Create the provider
provider = LiteLLMProvider(
model=model,
api_key=api_key,
api_base=api_base,
**extra_kwargs,
)
print(f"_codex_backend: {provider._codex_backend}")
# Verify mode is "responses" (the correct routing for Codex backend)
_strip = model.removeprefix("openai/")
mode = litellm.model_cost.get(_strip, {}).get("mode", "NOT SET")
print(f"litellm.model_cost['{_strip}']['mode']: {mode}")
if mode != "responses":
print(" WARNING: Expected mode='responses' for Codex backend!")
print()
# -----------------------------------------------------------
# Test 1: Stream via LiteLLMProvider.stream() (the real code path)
# -----------------------------------------------------------
print("=" * 60)
print("TEST 1: LiteLLMProvider.stream() — basic text")
print("=" * 60)
try:
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
)
messages = [{"role": "user", "content": "Say hello in exactly 3 words."}]
chunk_count = 0
text = ""
async for event in provider.stream(messages=messages):
chunk_count += 1
if isinstance(event, TextDeltaEvent):
text = event.snapshot
elif isinstance(event, TextEndEvent):
print(f" TextEnd: {event.full_text!r}")
elif isinstance(event, ToolCallEvent):
print(f" ToolCall: {event.tool_name}({event.tool_input})")
elif isinstance(event, FinishEvent):
print(
f" Finish: stop={event.stop_reason} "
f"in={event.input_tokens} out={event.output_tokens}"
)
elif isinstance(event, StreamErrorEvent):
print(f" StreamError: {event.error} (recoverable={event.recoverable})")
print(f" Text: {text!r}")
print(f" Total events: {chunk_count}")
print(" RESULT: OK" if text else " RESULT: EMPTY")
except Exception as e:
print(f" ERROR: {type(e).__name__}: {e}")
traceback.print_exc()
print()
# -----------------------------------------------------------
# Test 2: Stream via LiteLLMProvider.stream() with tools
# -----------------------------------------------------------
print("=" * 60)
print("TEST 2: LiteLLMProvider.stream() — with tools")
print("=" * 60)
try:
from framework.llm.provider import Tool
tools = [
Tool(
name="get_weather",
description="Get weather for a city",
parameters={
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
)
]
messages = [{"role": "user", "content": "What is the weather in SF?"}]
chunk_count = 0
text = ""
tool_calls = []
async for event in provider.stream(messages=messages, tools=tools):
chunk_count += 1
if isinstance(event, TextDeltaEvent):
text = event.snapshot
elif isinstance(event, ToolCallEvent):
tool_calls.append(
{"name": event.tool_name, "input": event.tool_input}
)
print(f" ToolCall: {event.tool_name}({json.dumps(event.tool_input)})")
elif isinstance(event, FinishEvent):
print(
f" Finish: stop={event.stop_reason} "
f"in={event.input_tokens} out={event.output_tokens}"
)
elif isinstance(event, StreamErrorEvent):
print(f" StreamError: {event.error} (recoverable={event.recoverable})")
print(f" Text: {text!r}")
print(f" Tool calls: {json.dumps(tool_calls, indent=2)}")
print(f" Total events: {chunk_count}")
status = "OK" if (text or tool_calls) else "EMPTY"
print(f" RESULT: {status}")
except Exception as e:
print(f" ERROR: {type(e).__name__}: {e}")
traceback.print_exc()
print()
# -----------------------------------------------------------
# Test 3: acomplete() via provider (uses stream + collect)
# -----------------------------------------------------------
print("=" * 60)
print("TEST 3: LiteLLMProvider.acomplete() — round-trip")
print("=" * 60)
try:
messages = [{"role": "user", "content": "What is 2+2? Reply with just the number."}]
response = await provider.acomplete(messages=messages)
print(f" Content: {response.content!r}")
print(f" Model: {response.model}")
print(f" Tokens: in={response.input_tokens} out={response.output_tokens}")
print(f" Stop: {response.stop_reason}")
print(" RESULT: OK" if response.content else " RESULT: EMPTY")
except Exception as e:
print(f" ERROR: {type(e).__name__}: {e}")
traceback.print_exc()
print()
# -----------------------------------------------------------
# Test 4: Direct litellm.acompletion with metadata fix
# -----------------------------------------------------------
print("=" * 60)
print("TEST 4: Direct litellm.acompletion (with metadata={})")
print("=" * 60)
try:
direct_kwargs = {
"model": model,
"messages": [{"role": "user", "content": "Say hello in exactly 3 words."}],
"stream": True,
"api_key": api_key,
"api_base": api_base,
"metadata": {}, # Prevent NoneType masking in error handler
**extra_kwargs,
}
response = await litellm.acompletion(**direct_kwargs)
chunk_count = 0
text = ""
async for chunk in response:
chunk_count += 1
choices = chunk.choices if chunk.choices else []
delta = choices[0].delta if choices else None
content = delta.content if delta and delta.content else ""
if content:
text += content
finish = choices[0].finish_reason if choices else None
if finish:
print(f" finish_reason: {finish}")
print(f" Text: {text!r}")
print(f" Total chunks: {chunk_count}")
print(" RESULT: OK" if text else " RESULT: EMPTY")
except Exception as e:
print(f" ERROR: {type(e).__name__}: {e}")
traceback.print_exc()
print()
# -----------------------------------------------------------
# Test 5: Rapid-fire 3 calls via provider.stream()
# -----------------------------------------------------------
print("=" * 60)
print("TEST 5: Rapid-fire 3 calls via provider.stream()")
print("=" * 60)
for i in range(3):
try:
messages = [{"role": "user", "content": f"Say the number {i + 1}."}]
text = ""
async for event in provider.stream(messages=messages):
if isinstance(event, TextDeltaEvent):
text = event.snapshot
elif isinstance(event, StreamErrorEvent):
print(f" Call {i + 1}: StreamError: {event.error}")
break
status = f"OK ({len(text)} chars: {text!r})" if text else "EMPTY"
print(f" Call {i + 1}: {status}")
except Exception as e:
print(f" Call {i + 1}: ERROR {type(e).__name__}: {e}")
print()
if __name__ == "__main__":
asyncio.run(test_codex_stream())
core/tests/debug_codex_verbose.py (+69)
@@ -0,0 +1,69 @@
"""Run Codex stream with litellm debug logging enabled.
Run: .venv/bin/python core/tests/debug_codex_verbose.py
"""
import asyncio
import sys
sys.path.insert(0, "core")
import litellm # noqa: E402
litellm._turn_on_debug()
from framework.config import get_api_base, get_api_key, get_llm_extra_kwargs # noqa: E402
from framework.llm.litellm import LiteLLMProvider # noqa: E402
from framework.llm.stream_events import ( # noqa: E402
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
)
async def main():
api_key = get_api_key()
api_base = get_api_base()
extra_kwargs = get_llm_extra_kwargs()
if not api_key or not api_base:
print("ERROR: No Codex config in ~/.hive/configuration.json")
return
provider = LiteLLMProvider(
model="openai/gpt-5.3-codex",
api_key=api_key,
api_base=api_base,
**extra_kwargs,
)
print(f"_codex_backend={provider._codex_backend}")
print()
text = ""
async for event in provider.stream(
messages=[{"role": "user", "content": "What is 2+2? Reply with just the number."}],
system="You are a helpful assistant.",
):
if isinstance(event, TextDeltaEvent):
text = event.snapshot
elif isinstance(event, TextEndEvent):
print(f"TextEnd: {event.full_text!r}")
elif isinstance(event, ToolCallEvent):
print(f"ToolCall: {event.tool_name}({event.tool_input})")
elif isinstance(event, FinishEvent):
print(
f"Finish: stop={event.stop_reason} "
f"in={event.input_tokens} out={event.output_tokens}"
)
elif isinstance(event, StreamErrorEvent):
print(f"StreamError: {event.error} (recoverable={event.recoverable})")
print(f"Text: {text!r}")
print("OK" if text else "EMPTY")
if __name__ == "__main__":
asyncio.run(main())
core/tests/test_codex_eventloop.py (+159)
@@ -0,0 +1,159 @@
"""Integration test: Run a real EventLoopNode against the Codex backend.
Run: .venv/bin/python core/tests/test_codex_eventloop.py
"""
import asyncio
import logging
import sys
from dataclasses import dataclass, field
from typing import Any
from unittest.mock import MagicMock
sys.path.insert(0, "core")
logging.basicConfig(level=logging.WARNING, format="%(levelname)s %(name)s: %(message)s")
# Show our provider's retry/stream logs
logging.getLogger("framework.llm.litellm").setLevel(logging.DEBUG)
from framework.config import RuntimeConfig # noqa: E402
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
from framework.graph.node import NodeContext, NodeResult, NodeSpec, SharedMemory # noqa: E402
from framework.llm.litellm import LiteLLMProvider # noqa: E402
def make_provider() -> LiteLLMProvider:
cfg = RuntimeConfig()
if not cfg.api_key:
print("ERROR: No API key configured in ~/.hive/configuration.json")
sys.exit(1)
print(f"Model : {cfg.model}")
print(f"Base : {cfg.api_base}")
print(f"Codex : {'chatgpt.com/backend-api/codex' in (cfg.api_base or '')}")
return LiteLLMProvider(
model=cfg.model,
api_key=cfg.api_key,
api_base=cfg.api_base,
**cfg.extra_kwargs,
)
def make_context(
llm: LiteLLMProvider,
*,
node_id: str = "test",
system_prompt: str = "You are a helpful assistant.",
output_keys: list[str] | None = None,
) -> NodeContext:
if output_keys is None:
output_keys = ["answer"]
spec = NodeSpec(
id=node_id,
name="Test Node",
description="Integration test node",
node_type="event_loop",
output_keys=output_keys,
system_prompt=system_prompt,
)
runtime = MagicMock()
runtime.start_run = MagicMock(return_value="run-1")
runtime.decide = MagicMock(return_value="dec-1")
runtime.record_outcome = MagicMock()
runtime.end_run = MagicMock()
memory = SharedMemory()
return NodeContext(
runtime=runtime,
node_id=node_id,
node_spec=spec,
memory=memory,
input_data={},
llm=llm,
available_tools=[],
max_tokens=4096,
)
async def run_test(name: str, llm: LiteLLMProvider, system: str, output_keys: list[str]) -> NodeResult:
print(f"\n{'=' * 60}")
print(f"TEST: {name}")
print(f"{'=' * 60}")
ctx = make_context(llm, system_prompt=system, output_keys=output_keys)
node = EventLoopNode(config=LoopConfig(max_iterations=3))
try:
result = await node.execute(ctx)
print(f" Success : {result.success}")
print(f" Output : {result.output}")
if result.error:
print(f" Error : {result.error}")
return result
except Exception as e:
print(f" EXCEPTION: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
return NodeResult(success=False, error=str(e))
async def main():
llm = make_provider()
print()
# Test 1: Simple text output — the node should call set_output to fill "answer"
r1 = await run_test(
name="Simple text generation",
llm=llm,
system=(
"You are a helpful assistant. When asked a question, use the "
"set_output tool to store your answer in the 'answer' key. "
"Keep answers short (1-2 sentences)."
),
output_keys=["answer"],
)
# Test 2: If test 1 failed, try bare stream() to isolate the issue
if not r1.success:
print(f"\n{'=' * 60}")
print("FALLBACK: Testing bare provider.stream() directly")
print(f"{'=' * 60}")
try:
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
ToolCallEvent,
)
text = ""
events = []
async for event in llm.stream(
messages=[{"role": "user", "content": "Say hello in 3 words."}],
):
events.append(type(event).__name__)
if isinstance(event, TextDeltaEvent):
text = event.snapshot
elif isinstance(event, FinishEvent):
print(f" Finish: stop={event.stop_reason} in={event.input_tokens} out={event.output_tokens}")
elif isinstance(event, StreamErrorEvent):
print(f" StreamError: {event.error} (recoverable={event.recoverable})")
elif isinstance(event, ToolCallEvent):
print(f" ToolCall: {event.tool_name}")
print(f" Text : {text!r}")
print(f" Events : {events}")
print(f" RESULT : {'OK' if text else 'EMPTY'}")
except Exception as e:
print(f" EXCEPTION: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
print(f"\n{'=' * 60}")
print("DONE")
print(f"{'=' * 60}")
if __name__ == "__main__":
asyncio.run(main())
core/tests/test_two_llm_calls.py (+377)
@@ -0,0 +1,377 @@
"""Test script: Codex vs OpenAI — tool call argument truncation repro.
Run: uv run python core/tests/test_two_llm_calls.py
"""
import asyncio
import json
import os
import sys
sys.path.insert(0, "core")
from framework.llm.litellm import LiteLLMProvider
from framework.llm.provider import Tool
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
ToolCallEvent,
)
OPENAI_API_KEY = "sk-*****"
# ---------------------------------------------------------------------------
# Tool definitions — mimic the real vulnerability_assessment agent
# ---------------------------------------------------------------------------
SCAN_TOOLS = [
Tool(
name="ssl_tls_scan",
description="Scan SSL/TLS configuration for a hostname",
parameters={
"type": "object",
"properties": {
"hostname": {"type": "string", "description": "Domain name to scan"},
"port": {"type": "integer", "description": "Port to connect to", "default": 443},
},
"required": ["hostname"],
},
),
Tool(
name="http_headers_scan",
description="Scan HTTP security headers for a URL",
parameters={
"type": "object",
"properties": {
"url": {"type": "string", "description": "Full URL to scan"},
"follow_redirects": {"type": "boolean", "default": True},
},
"required": ["url"],
},
),
Tool(
name="dns_security_scan",
description="Scan DNS security configuration for a domain",
parameters={
"type": "object",
"properties": {
"domain": {"type": "string", "description": "Domain name to scan"},
},
"required": ["domain"],
},
),
Tool(
name="port_scan",
description="Scan open ports for a hostname",
parameters={
"type": "object",
"properties": {
"hostname": {"type": "string", "description": "Domain or IP to scan"},
"ports": {"type": "string", "default": "top20"},
"timeout": {"type": "number", "default": 3.0},
},
"required": ["hostname"],
},
),
Tool(
name="tech_stack_detect",
description="Detect technology stack for a URL",
parameters={
"type": "object",
"properties": {
"url": {"type": "string", "description": "URL to analyze"},
},
"required": ["url"],
},
),
Tool(
name="subdomain_enumerate",
description="Enumerate subdomains for a domain",
parameters={
"type": "object",
"properties": {
"domain": {"type": "string", "description": "Base domain"},
"max_results": {"type": "integer", "default": 50},
},
"required": ["domain"],
},
),
# The big one: takes a single JSON-string param consolidating all six scan results
Tool(
name="set_output",
description="Set the output for this node. Call this when you are done. scan_results must be a JSON string containing the full consolidated results from all scans.",
parameters={
"type": "object",
"properties": {
"scan_results": {
"type": "string",
"description": "JSON string with consolidated scan results including ssl, headers, dns, ports, tech, and subdomain data.",
},
},
"required": ["scan_results"],
},
),
]
# Fake scan results — realistic size to stress-test argument streaming
FAKE_SSL_RESULT = {
"hostname": "example.com", "port": 443, "tls_version": "TLSv1.3",
"cipher": "TLS_AES_256_GCM_SHA384", "cipher_bits": 256,
"certificate": {
"subject": "CN=example.com", "issuer": "CN=Let's Encrypt Authority X3",
"not_before": "2025-01-01T00:00:00Z", "not_after": "2026-01-01T00:00:00Z",
"days_until_expiry": 310, "san": ["example.com", "www.example.com"],
"self_signed": False, "sha256_fingerprint": "AB:CD:EF:12:34:56:78:90",
},
"issues": [
{"severity": "low", "finding": "Certificate expiring in 310 days", "remediation": "Monitor expiry"},
],
"grade_input": {"tls_version_ok": True, "cert_valid": True, "cert_expiring_soon": False, "strong_cipher": True, "self_signed": False},
}
FAKE_HEADERS_RESULT = {
"url": "https://example.com", "status_code": 200,
"headers_present": ["Strict-Transport-Security", "X-Content-Type-Options"],
"headers_missing": [
{"header": "Content-Security-Policy", "severity": "high", "description": "No CSP header", "remediation": "Add CSP header"},
{"header": "X-Frame-Options", "severity": "medium", "description": "No X-Frame-Options", "remediation": "Add DENY or SAMEORIGIN"},
{"header": "Permissions-Policy", "severity": "low", "description": "No Permissions-Policy", "remediation": "Add Permissions-Policy"},
],
"leaky_headers": [
{"header": "Server", "value": "nginx/1.21.0", "severity": "low", "remediation": "Remove server version"},
],
"grade_input": {"hsts": True, "csp": False, "x_frame_options": False, "x_content_type_options": True, "referrer_policy": False, "permissions_policy": False, "no_leaky_headers": False},
}
FAKE_DNS_RESULT = {
"domain": "example.com", "source": "crt.sh",
"spf": {"present": True, "record": "v=spf1 include:_spf.google.com ~all", "policy": "softfail", "issues": []},
"dmarc": {"present": True, "record": "v=DMARC1; p=reject; rua=mailto:dmarc@example.com", "policy": "reject", "issues": []},
"dkim": {"selectors_found": ["google", "default"], "selectors_missing": []},
"dnssec": {"enabled": False, "issues": [{"severity": "medium", "finding": "DNSSEC not enabled"}]},
"mx_records": ["10 mail.example.com"],
"caa_records": ["0 issue letsencrypt.org"],
"zone_transfer": {"vulnerable": False},
"grade_input": {"spf_present": True, "spf_strict": False, "dmarc_present": True, "dmarc_enforcing": True, "dkim_found": True, "dnssec_enabled": False, "zone_transfer_blocked": True},
}
FAKE_PORTS_RESULT = {
"hostname": "example.com", "ip": "93.184.216.34", "ports_scanned": 20,
"open_ports": [
{"port": 80, "service": "http", "banner": "nginx/1.21.0"},
{"port": 443, "service": "https", "banner": "nginx/1.21.0"},
{"port": 22, "service": "ssh", "banner": "OpenSSH_8.9", "severity": "medium", "finding": "SSH port open", "remediation": "Restrict SSH access"},
],
"closed_ports": [21, 23, 25, 53, 110, 143, 993, 995, 3306, 5432, 6379, 8080, 8443, 27017],
"grade_input": {"no_database_ports_exposed": True, "no_admin_ports_exposed": False, "no_legacy_ports_exposed": True, "only_web_ports": False},
}
FAKE_TECH_RESULT = {
"url": "https://example.com",
"server": {"name": "nginx", "version": "1.21.0", "raw": "nginx/1.21.0"},
"framework": "React", "language": "JavaScript", "cms": None,
"javascript_libraries": ["react-18.2.0", "lodash-4.17.21", "axios-1.6.0"],
"cdn": "Cloudflare", "analytics": ["Google Analytics"],
"security_txt": True, "robots_txt": True,
"interesting_paths": ["/admin", "/.env", "/api/docs"],
"cookies": [
{"name": "session", "secure": True, "httponly": True, "samesite": "Strict"},
{"name": "_ga", "secure": False, "httponly": False, "samesite": "None"},
],
"grade_input": {"server_version_hidden": False, "framework_version_hidden": True, "security_txt_present": True, "cookies_secure": False, "cookies_httponly": False},
}
FAKE_SUBDOMAIN_RESULT = {
"domain": "example.com", "source": "crt.sh", "total_found": 8,
"subdomains": ["www.example.com", "mail.example.com", "api.example.com", "staging.example.com", "dev.example.com", "admin.example.com", "cdn.example.com", "blog.example.com"],
"interesting": [
{"subdomain": "staging.example.com", "reason": "staging environment exposed", "severity": "high", "remediation": "Restrict access"},
{"subdomain": "dev.example.com", "reason": "development environment exposed", "severity": "high", "remediation": "Restrict access"},
{"subdomain": "admin.example.com", "reason": "admin panel exposed", "severity": "medium", "remediation": "Add IP restriction"},
],
"grade_input": {"no_dev_staging_exposed": False, "no_admin_exposed": False, "reasonable_surface_area": True},
}
def _make_codex_provider():
from framework.config import get_api_base, get_api_key, get_llm_extra_kwargs
api_key = get_api_key()
api_base = get_api_base()
extra_kwargs = get_llm_extra_kwargs()
if not api_key or not api_base:
return None
return LiteLLMProvider(
model="openai/gpt-5.3-codex",
api_key=api_key,
api_base=api_base,
**extra_kwargs,
)
async def _stream_and_collect(provider, messages, system, tools):
"""Stream a call, collect text + tool calls, print events. Returns (text, tool_calls)."""
text = ""
tool_calls: list[ToolCallEvent] = []
async for event in provider.stream(messages=messages, system=system, tools=tools):
if isinstance(event, TextDeltaEvent):
text = event.snapshot
elif isinstance(event, ToolCallEvent):
tool_calls.append(event)
elif isinstance(event, FinishEvent):
print(f" finish: stop={event.stop_reason} in={event.input_tokens} out={event.output_tokens}")
elif isinstance(event, StreamErrorEvent):
print(f" STREAM ERROR: {event.error}")
return text, tool_calls
def _validate_tool_args(tool_calls: list[ToolCallEvent]) -> bool:
"""Check that every tool call has valid, non-truncated JSON arguments."""
ok = True
for tc in tool_calls:
print(f" ToolCall: {tc.tool_name} id={tc.tool_use_id}")
args = tc.tool_input
# Check for the _raw fallback (means JSON parse failed → truncated)
if "_raw" in args:
print(f" TRUNCATED — raw args: {args['_raw'][:200]}...")
ok = False
continue
# For set_output, validate the nested JSON string
if tc.tool_name == "set_output" and "scan_results" in args:
raw_json = args["scan_results"]
print(f" scan_results length: {len(raw_json)} chars")
try:
parsed = json.loads(raw_json)
keys = list(parsed.keys()) if isinstance(parsed, dict) else "not-a-dict"
print(f" parsed OK — keys: {keys}")
except json.JSONDecodeError as e:
print(f" INVALID JSON in scan_results: {e}")
print(f" tail: ...{raw_json[-200:]}")
ok = False
else:
print(f" args: {json.dumps(args)}")
return ok
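# --- Sketch of the "_raw" fallback shape this validator checks (assumption:
# inferred from the checks above, not the provider's exact code). Unparseable
# streamed arguments arrive under "_raw"; healthy calls parse into real fields.
truncated = {"_raw": '{"scan_results": "{\\"ssl\\": '}  # cut off mid-stream
healthy = {"scan_results": json.dumps({"ssl": FAKE_SSL_RESULT})}
assert "_raw" in truncated
assert json.loads(healthy["scan_results"])["ssl"]["port"] == 443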
async def test_codex_multi_tool_scan():
"""Reproduce the real agent flow: LLM calls 6 scan tools, then set_output with big JSON."""
provider = _make_codex_provider()
if not provider:
print("[scan] SKIP — no Codex subscription")
return
system = (
"You are a security scanning agent. You have access to scanning tools.\n"
"The user will give you scan results. Your job is to consolidate them and "
"call set_output with a JSON string containing ALL the scan results.\n"
"The scan_results value MUST be a valid JSON string containing every scan result provided.\n"
"Do NOT summarize — include the complete data from each scan."
)
# Provide all scan results in a single user message so the LLM has to
# consolidate them into one big set_output call.
all_results = {
"ssl": FAKE_SSL_RESULT,
"headers": FAKE_HEADERS_RESULT,
"dns": FAKE_DNS_RESULT,
"ports": FAKE_PORTS_RESULT,
"tech": FAKE_TECH_RESULT,
"subdomains": FAKE_SUBDOMAIN_RESULT,
}
results_json = json.dumps(all_results, indent=2)
print(f" Input scan data size: {len(results_json)} chars")
messages = [
{
"role": "user",
"content": (
"Here are the completed scan results for example.com. "
"Consolidate ALL of them into a single set_output call. "
"The scan_results argument must be a JSON string containing the complete data.\n\n"
f"```json\n{results_json}\n```"
),
},
]
# --- Turn 1: expect set_output tool call with big JSON ---
text, tool_calls = await _stream_and_collect(provider, messages, system, SCAN_TOOLS)
if text:
print(f" text: {text[:200]}{'...' if len(text) > 200 else ''}")
if not tool_calls:
print(" NO TOOL CALLS — expected set_output")
print(f" full text: {text}")
return
valid = _validate_tool_args(tool_calls)
print(f" RESULT: {'OK' if valid else 'TRUNCATED/MALFORMED'}")
async def test_codex_parallel_tool_calls():
"""Ask the LLM to call multiple scan tools at once — tests parallel tool call streaming."""
provider = _make_codex_provider()
if not provider:
print("[parallel] SKIP — no Codex subscription")
return
system = (
"You are a security scanning agent. When asked to scan a target, "
"call ALL relevant scanning tools in parallel in a single response. "
"Always call: ssl_tls_scan, http_headers_scan, dns_security_scan, "
"port_scan, tech_stack_detect, and subdomain_enumerate."
)
messages = [
{"role": "user", "content": "Run a full security scan on example.com"},
]
text, tool_calls = await _stream_and_collect(provider, messages, system, SCAN_TOOLS)
if text:
print(f" text: {text[:200]}{'...' if len(text) > 200 else ''}")
print(f" Total tool calls: {len(tool_calls)}")
valid = _validate_tool_args(tool_calls)
print(f" RESULT: {'OK' if valid else 'TRUNCATED/MALFORMED'}")
async def test_openai_baseline():
"""OpenAI direct — baseline to compare against."""
api_key = OPENAI_API_KEY or os.environ.get("OPENAI_API_KEY")
if not api_key:
print("[openai] SKIP — OPENAI_API_KEY not set")
return
provider = LiteLLMProvider(model="openai/gpt-4o-mini", api_key=api_key)
messages = [{"role": "user", "content": "What is 3+3? Reply with just the number."}]
response = await provider.acomplete(messages=messages, max_tokens=64)
print(f" Response: {response.content!r}")
print(f" tokens: in={response.input_tokens} out={response.output_tokens}")
print(f" RESULT: {'OK' if response.content else 'EMPTY'}")
async def main():
print("=" * 60)
print("Test 1: Codex — parallel tool calls (6 scan tools)")
print("=" * 60)
await test_codex_parallel_tool_calls()
print()
print("=" * 60)
print("Test 2: Codex — big set_output call (~4KB JSON arg)")
print("=" * 60)
await test_codex_multi_tool_scan()
print()
print("=" * 60)
print("Test 3: OpenAI direct — baseline")
print("=" * 60)
await test_openai_baseline()
print()
if __name__ == "__main__":
asyncio.run(main())