Merge pull request #159 from bryanadenhq/fix-json-output
Chore: Small bug fixes with JSON output
@@ -23,7 +23,9 @@
      "mcp__agent-builder__generate_success_tests",
      "mcp__agent-builder__debug_test",
      "mcp__agent-builder__run_tests",
      "mcp__agent-builder__list_mcp_tools"
      "mcp__agent-builder__list_mcp_tools",
      "mcp__agent-builder__test_graph",
      "Bash(python:*)"
    ]
  }
}

+1 -1
@@ -267,7 +267,7 @@ If you prefer to build agents manually:
{
  "node_id": "analyze",
  "name": "Analyze Ticket",
  "node_type": "llm",
  "node_type": "llm_generate",
  "system_prompt": "Analyze this support ticket...",
  "input_keys": ["ticket_content"],
  "output_keys": ["category", "priority"]

+148 -59
@@ -28,6 +28,45 @@ from framework.llm.provider import LLMProvider, Tool
logger = logging.getLogger(__name__)


def find_json_object(text: str) -> str | None:
    """Find the first valid JSON object in text using balanced brace matching.

    This handles nested objects correctly, unlike simple regex like r'\\{[^{}]*\\}'.
    """
    start = text.find('{')
    if start == -1:
        return None

    depth = 0
    in_string = False
    escape_next = False

    for i, char in enumerate(text[start:], start):
        if escape_next:
            escape_next = False
            continue

        if char == '\\' and in_string:
            escape_next = True
            continue

        if char == '"' and not escape_next:
            in_string = not in_string
            continue

        if in_string:
            continue

        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return text[start:i + 1]

    return None


class NodeSpec(BaseModel):
    """
    Specification for a node in the graph.
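A quick usage sketch of the new helper. The import path is assumed from the test file added in this PR; the inputs are illustrative:

from framework.graph.node import find_json_object

# Nested braces are matched correctly; the naive regex r'\{[^{}]*\}' would
# stop at the first closing brace and return an unparseable fragment.
print(find_json_object('Sure: {"outer": {"inner": "value"}} Let me know!'))
# -> {"outer": {"inner": "value"}}

# Braces inside string literals are skipped thanks to the in_string tracking.
print(find_json_object('{"code": "function() { return 1; }"}'))
# -> {"code": "function() { return 1; }"}

print(find_json_object("no braces here"))
# -> None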
@@ -356,6 +395,20 @@ class LLMNode(NodeProtocol):
    def __init__(self, tool_executor: Callable | None = None):
        self.tool_executor = tool_executor

    def _strip_code_blocks(self, content: str) -> str:
        """Strip markdown code block wrappers from content.

        LLMs often wrap JSON output in ```json...``` blocks.
        This method removes those wrappers to get clean content.
        """
        import re
        content = content.strip()
        # Match ```json or ``` at start and ``` at end (greedy to handle nested)
        match = re.match(r'^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$', content, re.DOTALL)
        if match:
            return match.group(1).strip()
        return content

    async def execute(self, ctx: NodeContext) -> NodeResult:
        """Execute the LLM node."""
        import time
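For reference, a minimal sketch of what the new _strip_code_blocks wrapper does with a fenced response (illustrative values):

from framework.graph.node import LLMNode

node = LLMNode()

wrapped = '```json\n{"category": "billing", "priority": "high"}\n```'
print(node._strip_code_blocks(wrapped))
# -> {"category": "billing", "priority": "high"}

# Content without a fence is returned unchanged, apart from strip().
print(node._strip_code_blocks('  plain text  '))
# -> plain text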
@@ -417,9 +470,15 @@ class LLMNode(NodeProtocol):
                tool_executor=executor,
            )
        else:
            # Use JSON mode for llm_generate nodes with structured output
            use_json_mode = (
                ctx.node_spec.node_type == "llm_generate"
                and len(ctx.node_spec.output_keys) >= 1
            )
            response = ctx.llm.complete(
                messages=messages,
                system=system,
                json_mode=use_json_mode,
            )

        # Log the response
@@ -442,44 +501,52 @@ class LLMNode(NodeProtocol):
        output = self._parse_output(response.content, ctx.node_spec)

        # For llm_generate and llm_tool_use nodes, try to parse JSON and extract fields
        if ctx.node_spec.node_type in ("llm_generate", "llm_tool_use") and len(ctx.node_spec.output_keys) > 1:
        if ctx.node_spec.node_type in ("llm_generate", "llm_tool_use") and len(ctx.node_spec.output_keys) >= 1:
            try:
                import json

                # Try direct JSON parse first
                parsed = self._extract_json_with_haiku(response.content, ctx.node_spec.output_keys)
                # Try to extract JSON from response
                parsed = self._extract_json(response.content, ctx.node_spec.output_keys)

                # If parsed successfully, write each field to its corresponding output key
                if isinstance(parsed, dict):
                    for key in ctx.node_spec.output_keys:
                        if key in parsed:
                            ctx.memory.write(key, parsed[key])
                            output[key] = parsed[key]
                            value = parsed[key]
                            # Strip code block wrappers from string values
                            if isinstance(value, str):
                                value = self._strip_code_blocks(value)
                            ctx.memory.write(key, value)
                            output[key] = value
                        elif key in ctx.input_data:
                            # Key not in parsed JSON but exists in input - pass through input value
                            ctx.memory.write(key, ctx.input_data[key])
                            output[key] = ctx.input_data[key]
                        else:
                            # Key not in parsed JSON or input, write the whole response
                            ctx.memory.write(key, response.content)
                            output[key] = response.content
                            # Key not in parsed JSON or input, write the whole response (stripped)
                            stripped_content = self._strip_code_blocks(response.content)
                            ctx.memory.write(key, stripped_content)
                            output[key] = stripped_content
                else:
                    # Not a dict, fall back to writing entire response to all keys
                    # Not a dict, fall back to writing entire response to all keys (stripped)
                    stripped_content = self._strip_code_blocks(response.content)
                    for key in ctx.node_spec.output_keys:
                        ctx.memory.write(key, response.content)
                        output[key] = response.content
                        ctx.memory.write(key, stripped_content)
                        output[key] = stripped_content

            except (json.JSONDecodeError, Exception) as e:
                # JSON extraction failed completely
                # JSON extraction failed completely - still strip code blocks
                logger.warning(f" ⚠ Failed to extract JSON output: {e}")
                stripped_content = self._strip_code_blocks(response.content)
                for key in ctx.node_spec.output_keys:
                    ctx.memory.write(key, response.content)
                    output[key] = response.content
                    ctx.memory.write(key, stripped_content)
                    output[key] = stripped_content
        else:
            # For non-llm_generate or single output nodes, write entire response to all keys
            # For non-llm_generate or single output nodes, write entire response (stripped)
            stripped_content = self._strip_code_blocks(response.content)
            for key in ctx.node_spec.output_keys:
                ctx.memory.write(key, response.content)
                output[key] = response.content
                ctx.memory.write(key, stripped_content)
                output[key] = stripped_content

        return NodeResult(
            success=True,
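The per-key logic above reduces to a three-step precedence. A condensed sketch with a hypothetical helper name, not the actual method structure: prefer the parsed JSON value (stripped if it is a string), then pass through the matching input value, then fall back to the whole code-block-stripped response.

def resolve_output_value(key, parsed, input_data, response_content, strip):
    # 1. Value the LLM actually produced for this key
    if isinstance(parsed, dict) and key in parsed:
        value = parsed[key]
        return strip(value) if isinstance(value, str) else value
    # 2. Pass-through of the corresponding input value
    if key in input_data:
        return input_data[key]
    # 3. Last resort: the whole response, with ``` wrappers removed
    return strip(response_content)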
@@ -508,41 +575,55 @@ class LLMNode(NodeProtocol):
        # Default output
        return {"result": content}

    def _extract_json_with_haiku(self, raw_response: str, output_keys: list[str]) -> dict[str, Any]:
        """Use Haiku to extract clean JSON from potentially verbose LLM response."""
    def _extract_json(self, raw_response: str, output_keys: list[str]) -> dict[str, Any]:
        """Extract clean JSON from potentially verbose LLM response.

        Tries multiple extraction strategies in order:
        1. Direct JSON parse
        2. Markdown code block extraction
        3. Balanced brace matching
        4. Haiku LLM fallback (last resort)
        """
        import json
        import re

        content = raw_response.strip()

        # Try direct JSON parse first (fast path)
        try:
            content = raw_response.strip()
            # Remove markdown code blocks if present
            if content.startswith("```"):
                match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', content, re.DOTALL)
                if match:
                    content = match.group(1).strip()

            parsed = json.loads(content)
            if isinstance(parsed, dict):
                return parsed
        except json.JSONDecodeError:
            pass

        # JSON parse failed - use OutputCleaner to extract clean JSON
        # Try to extract JSON from markdown code blocks (greedy match to handle nested blocks)
        # Use anchored match to capture from first ``` to last ```
        code_block_match = re.match(r'^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$', content, re.DOTALL)
        if code_block_match:
            try:
                parsed = json.loads(code_block_match.group(1).strip())
                if isinstance(parsed, dict):
                    return parsed
            except json.JSONDecodeError:
                pass

        # Try to find JSON object by matching balanced braces (use module-level helper)
        json_str = find_json_object(content)
        if json_str:
            try:
                parsed = json.loads(json_str)
                if isinstance(parsed, dict):
                    return parsed
            except json.JSONDecodeError:
                pass

        # All local extraction methods failed - use LLM as last resort
        # Prefer Cerebras (faster/cheaper), fallback to Anthropic Haiku
        import os
        api_key = os.environ.get("CEREBRAS_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            # No API key, try one more simple extraction
            try:
                # Find first { and last }
                start = raw_response.find('{')
                end = raw_response.rfind('}')
                if start != -1 and end != -1:
                    json_str = raw_response[start:end+1]
                    return json.loads(json_str)
            except (ValueError, json.JSONDecodeError):
                pass
            raise ValueError("Cannot parse JSON and no API key for OutputCleaner (set CEREBRAS_API_KEY)")
            raise ValueError("Cannot parse JSON and no API key for LLM cleanup (set CEREBRAS_API_KEY or ANTHROPIC_API_KEY)")

        # Use fast LLM to clean the response (Cerebras llama-3.3-70b preferred)
        from framework.llm.litellm import LiteLLMProvider
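A short sketch of the fallback chain in action. No API key is needed as long as one of the three local strategies succeeds; inputs are illustrative:

node = LLMNode()

# 1. Direct parse
node._extract_json('{"category": "billing"}', ["category"])

# 2. Markdown code block extraction
node._extract_json('```json\n{"category": "billing"}\n```', ["category"])

# 3. Balanced brace matching inside prose
node._extract_json('Here you go: {"category": "billing"} hope that helps!', ["category"])

# All three return {"category": "billing"}. Only if every local strategy fails
# does the method call Cerebras/Haiku, and without CEREBRAS_API_KEY or
# ANTHROPIC_API_KEY it raises ValueError instead.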
@@ -557,38 +638,37 @@ class LLMNode(NodeProtocol):
        from framework.llm.anthropic import AnthropicProvider
        cleaner_llm = AnthropicProvider(model="claude-3-5-haiku-20241022")

        prompt = f"""Extract the JSON object from this LLM response. Extract ONLY the values that the LLM actually generated.
        prompt = f"""Extract the JSON object from this LLM response.

Expected output keys: {output_keys}

LLM Response:
{raw_response}

IMPORTANT:
- Only extract keys that the LLM explicitly output in its response
- Do NOT include keys that were just mentioned or passed through from input
- If the LLM output multiple pieces of text/JSON, extract the LAST JSON object only
- Output ONLY valid JSON with no extra text, no markdown, no explanations"""
Output ONLY the JSON object, nothing else."""

        try:
            result = cleaner_llm.complete(
                messages=[{"role": "user", "content": prompt}],
                system="You extract clean JSON from messy responses. Output only valid JSON, nothing else.",
                system="Extract JSON from text. Output only valid JSON.",
                json_mode=True,
            )

            cleaned = result.content.strip()
            # Remove markdown if OutputCleaner added it
            # Remove markdown if LLM added it
            if cleaned.startswith("```"):
                match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', cleaned, re.DOTALL)
                if match:
                    cleaned = match.group(1).strip()

            parsed = json.loads(cleaned)
            logger.info(" ✓ OutputCleaner extracted JSON")
            logger.info(" ✓ LLM cleaned JSON output")
            return parsed

        except ValueError:
            raise  # Re-raise our descriptive error
        except Exception as e:
            logger.warning(f" ⚠ OutputCleaner JSON extraction failed: {e}")
            logger.warning(f" ⚠ LLM JSON extraction failed: {e}")
            raise

    def _build_messages(self, ctx: NodeContext) -> list[dict]:
@@ -629,12 +709,23 @@ IMPORTANT:

        # Build prompt for Haiku to extract clean values
        import json

        # Smart truncation: truncate individual values rather than corrupting JSON structure
        def truncate_value(v, max_len=500):
            s = str(v)
            return s[:max_len] + "..." if len(s) > max_len else v

        truncated_data = {
            k: truncate_value(v) for k, v in memory_data.items()
        }
        memory_json = json.dumps(truncated_data, indent=2, default=str)

        prompt = f"""Extract the following information from the memory context:

Required fields: {', '.join(ctx.node_spec.input_keys)}

Memory context (may contain nested data, JSON strings, or extra information):
{json.dumps(memory_data, indent=2, default=str)[:3000]}
{memory_json}

Extract ONLY the clean values for the required fields. Ignore nested structures, JSON wrappers, and irrelevant data.

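A tiny worked example of the smart truncation above: long values are clipped per value, so the surrounding JSON stays parseable, while short and non-string values pass through untouched (illustrative values):

def truncate_value(v, max_len=500):
    s = str(v)
    return s[:max_len] + "..." if len(s) > max_len else v

print(truncate_value("short"))          # short
print(len(truncate_value("x" * 2000)))  # 503 -> 500 chars plus "..."
print(truncate_value({"a": 1}))         # {'a': 1} (kept as-is, not stringified)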
@@ -652,11 +743,10 @@ Output as JSON with the exact field names requested."""
        # Parse Haiku's response
        response_text = message.content[0].text.strip()

        # Try to extract JSON
        import re
        json_match = re.search(r'\{[^{}]*\}', response_text, re.DOTALL)
        if json_match:
            extracted = json.loads(json_match.group())
        # Try to extract JSON using balanced brace matching
        json_str = find_json_object(response_text)
        if json_str:
            extracted = json.loads(json_str)
            # Format as key: value pairs
            parts = [f"{k}: {v}" for k, v in extracted.items() if k in ctx.node_spec.input_keys]
            if parts:
@@ -820,11 +910,10 @@ Respond with ONLY a JSON object:
            max_tokens=150,
        )

        # Parse response
        import re
        json_match = re.search(r'\{[^{}]*\}', response.content, re.DOTALL)
        if json_match:
            data = json.loads(json_match.group())
        # Parse response using balanced brace matching
        json_str = find_json_object(response.content)
        if json_str:
            data = json.loads(json_str)
            chosen = data.get("chosen", "default")
            reasoning = data.get("reasoning", "")

@@ -4,7 +4,6 @@ Test OutputCleaner with real Cerebras LLM.
Demonstrates how OutputCleaner fixes the JSON parsing trap using llama-3.3-70b.
"""

import asyncio
import json
import os
from framework.graph.output_cleaner import OutputCleaner, CleansingConfig
@@ -68,7 +67,7 @@ def test_cleaning_with_cerebras():
        target_node_spec=target_spec,
    )

    print(f"\nMalformed output:")
    print("\nMalformed output:")
    print(json.dumps(malformed_output, indent=2))
    print(f"\nValidation errors: {validation.errors}")

@@ -81,7 +80,7 @@ def test_cleaning_with_cerebras():
        validation_errors=validation.errors,
    )

    print(f"\n✓ Cleaned output:")
    print("\n✓ Cleaned output:")
    print(json.dumps(cleaned, indent=2))

    assert isinstance(cleaned, dict), "Should return dict"
@@ -114,7 +113,7 @@ def test_cleaning_with_cerebras():
        target_node_spec=target_spec2,
    )

    print(f"\nMalformed output:")
    print("\nMalformed output:")
    print(json.dumps(malformed_output2, indent=2))
    print(f"\nValidation errors: {validation2.errors}")

@@ -127,7 +126,7 @@ def test_cleaning_with_cerebras():
        validation_errors=validation2.errors,
    )

    print(f"\n✓ Cleaned output:")
    print("\n✓ Cleaned output:")
    print(json.dumps(cleaned2, indent=2))

    assert isinstance(cleaned2, dict), "Should return dict"
@@ -138,7 +137,7 @@ def test_cleaning_with_cerebras():

    # Stats
    stats = cleaner.get_stats()
    print(f"\n\nCleaner Statistics:")
    print("\n\nCleaner Statistics:")
    print(f" Total cleanings: {stats['total_cleanings']}")
    print(f" Cache size: {stats['cache_size']}")

@@ -67,6 +67,7 @@ class AnthropicProvider(LLMProvider):
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 1024,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Generate a completion from Claude (via LiteLLM)."""
        return self._provider.complete(
@@ -74,6 +75,7 @@ class AnthropicProvider(LLMProvider):
            system=system,
            tools=tools,
            max_tokens=max_tokens,
            json_mode=json_mode,
        )

    def complete_with_tools(

@@ -78,6 +78,7 @@ class LiteLLMProvider(LLMProvider):
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 1024,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Generate a completion using LiteLLM."""
        # Prepare messages with system prompt
@@ -86,6 +87,17 @@ class LiteLLMProvider(LLMProvider):
            full_messages.append({"role": "system", "content": system})
        full_messages.extend(messages)

        # Add JSON mode via prompt engineering (works across all providers)
        if json_mode:
            json_instruction = (
                "\n\nPlease respond with a valid JSON object."
            )
            # Append to system message if present, otherwise add as system message
            if full_messages and full_messages[0]["role"] == "system":
                full_messages[0]["content"] += json_instruction
            else:
                full_messages.insert(0, {"role": "system", "content": json_instruction.strip()})

        # Build kwargs
        kwargs: dict[str, Any] = {
            "model": self.model,

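In practice the new flag only rewrites the outgoing message list; no response_format parameter is set, so it works with any LiteLLM backend. A minimal sketch (model and key are illustrative):

from framework.llm.litellm import LiteLLMProvider

provider = LiteLLMProvider(model="gpt-4o-mini", api_key="...")
provider.complete(
    messages=[{"role": "user", "content": "Classify this ticket"}],
    system="You are helpful.",
    json_mode=True,
)
# The request is sent with the system message rewritten to:
#   "You are helpful.\n\nPlease respond with a valid JSON object."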
@@ -58,6 +58,7 @@ class LLMProvider(ABC):
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 1024,
        json_mode: bool = False,
    ) -> LLMResponse:
        """
        Generate a completion from the LLM.
@@ -67,6 +68,7 @@ class LLMProvider(ABC):
            system: System prompt
            tools: Available tools for the LLM to use
            max_tokens: Maximum tokens to generate
            json_mode: If True, request structured JSON output from the LLM

        Returns:
            LLMResponse with content and metadata

@@ -310,11 +310,68 @@ def set_goal(
    """Define the goal for the agent. Goals are the source of truth - they define what success looks like."""
    session = get_session()

    # Parse JSON inputs
    criteria_list = json.loads(success_criteria)
    constraint_list = json.loads(constraints)
    # Parse JSON inputs with error handling
    try:
        criteria_list = json.loads(success_criteria)
    except json.JSONDecodeError as e:
        return json.dumps({
            "valid": False,
            "errors": [f"Invalid JSON in success_criteria: {e}"],
            "warnings": [],
        })

    # Convert to proper objects
    try:
        constraint_list = json.loads(constraints)
    except json.JSONDecodeError as e:
        return json.dumps({
            "valid": False,
            "errors": [f"Invalid JSON in constraints: {e}"],
            "warnings": [],
        })

    # Validate BEFORE object creation
    errors = []
    warnings = []

    if not goal_id:
        errors.append("Goal must have an id")
    if not name:
        errors.append("Goal must have a name")
    if not description:
        errors.append("Goal must have a description")
    if not criteria_list:
        errors.append("Goal must have at least one success criterion")
    if not constraint_list:
        warnings.append("Consider adding constraints")

    # Validate required fields in criteria and constraints
    for i, sc in enumerate(criteria_list):
        if not isinstance(sc, dict):
            errors.append(f"success_criteria[{i}] must be an object")
        else:
            if "id" not in sc:
                errors.append(f"success_criteria[{i}] missing required field 'id'")
            if "description" not in sc:
                errors.append(f"success_criteria[{i}] missing required field 'description'")

    for i, c in enumerate(constraint_list):
        if not isinstance(c, dict):
            errors.append(f"constraints[{i}] must be an object")
        else:
            if "id" not in c:
                errors.append(f"constraints[{i}] missing required field 'id'")
            if "description" not in c:
                errors.append(f"constraints[{i}] missing required field 'description'")

    # Return early if validation failed
    if errors:
        return json.dumps({
            "valid": False,
            "errors": errors,
            "warnings": warnings,
        })

    # Convert to proper objects (now safe - we validated required fields)
    criteria = [
        SuccessCriterion(
            id=sc["id"],
@@ -345,21 +402,6 @@ def set_goal(
        constraints=constraint_objs,
    )

    # Validate
    errors = []
    warnings = []

    if not goal_id:
        errors.append("Goal must have an id")
    if not name:
        errors.append("Goal must have a name")
    if not description:
        errors.append("Goal must have a description")
    if not criteria_list:
        errors.append("Goal must have at least one success criterion")
    if not constraint_list:
        warnings.append("Consider adding constraints")

    _save_session(session)  # Auto-save

    return json.dumps({

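With validation moved ahead of object creation, malformed tool input now comes back as a structured error payload instead of an exception. A sketch of the expected shape, assuming an active builder session and keyword arguments matching the tool signature (values illustrative):

result = set_goal(
    goal_id="support-triage",
    name="Triage tickets",
    description="Categorize and prioritize incoming tickets",
    success_criteria="not valid json",
    constraints="[]",
)
print(result)
# {"valid": false, "errors": ["Invalid JSON in success_criteria: Expecting value: ..."], "warnings": []}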
@@ -329,3 +329,135 @@ class TestAnthropicProviderBackwardCompatibility:

        assert result.content == "The time is 3:00 PM."
        mock_completion.assert_called_once()


class TestJsonMode:
    """Test json_mode parameter for structured JSON output via prompt engineering."""

    @patch("litellm.completion")
    def test_json_mode_adds_instruction_to_system_prompt(self, mock_completion):
        """Test that json_mode=True adds JSON instruction to system prompt."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = '{"key": "value"}'
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        provider.complete(
            messages=[{"role": "user", "content": "Return JSON"}],
            system="You are helpful.",
            json_mode=True
        )

        call_kwargs = mock_completion.call_args[1]
        # Should NOT use response_format (prompt engineering instead)
        assert "response_format" not in call_kwargs
        # Should have JSON instruction appended to system message
        messages = call_kwargs["messages"]
        assert messages[0]["role"] == "system"
        assert "You are helpful." in messages[0]["content"]
        assert "Please respond with a valid JSON object" in messages[0]["content"]

    @patch("litellm.completion")
    def test_json_mode_creates_system_prompt_if_none(self, mock_completion):
        """Test that json_mode=True creates system prompt if none provided."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = '{"key": "value"}'
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        provider.complete(
            messages=[{"role": "user", "content": "Return JSON"}],
            json_mode=True
        )

        call_kwargs = mock_completion.call_args[1]
        messages = call_kwargs["messages"]
        # Should insert a system message with JSON instruction
        assert messages[0]["role"] == "system"
        assert "Please respond with a valid JSON object" in messages[0]["content"]

    @patch("litellm.completion")
    def test_json_mode_false_no_instruction(self, mock_completion):
        """Test that json_mode=False does not add JSON instruction."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Hello"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        provider.complete(
            messages=[{"role": "user", "content": "Hello"}],
            system="You are helpful.",
            json_mode=False
        )

        call_kwargs = mock_completion.call_args[1]
        assert "response_format" not in call_kwargs
        messages = call_kwargs["messages"]
        assert messages[0]["role"] == "system"
        assert "Please respond with a valid JSON object" not in messages[0]["content"]

    @patch("litellm.completion")
    def test_json_mode_default_is_false(self, mock_completion):
        """Test that json_mode defaults to False (no JSON instruction)."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Hello"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "gpt-4o-mini"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
        provider.complete(
            messages=[{"role": "user", "content": "Hello"}],
            system="You are helpful."
        )

        call_kwargs = mock_completion.call_args[1]
        assert "response_format" not in call_kwargs
        messages = call_kwargs["messages"]
        # System prompt should be unchanged
        assert messages[0]["content"] == "You are helpful."

    @patch("litellm.completion")
    def test_anthropic_provider_passes_json_mode(self, mock_completion):
        """Test that AnthropicProvider passes json_mode through (prompt engineering)."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = '{"result": "ok"}'
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "claude-haiku-4-5-20251001"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_completion.return_value = mock_response

        provider = AnthropicProvider(api_key="test-key")
        provider.complete(
            messages=[{"role": "user", "content": "Return JSON"}],
            system="You are helpful.",
            json_mode=True
        )

        call_kwargs = mock_completion.call_args[1]
        # Should NOT use response_format
        assert "response_format" not in call_kwargs
        # Should have JSON instruction in system prompt
        messages = call_kwargs["messages"]
        assert messages[0]["role"] == "system"
        assert "Please respond with a valid JSON object" in messages[0]["content"]

@@ -0,0 +1,110 @@
"""Tests for LLMNode JSON extraction logic.

Run with:
    cd core
    pytest tests/test_node_json_extraction.py -v
"""

import pytest
from framework.graph.node import LLMNode


class TestJsonExtraction:
    """Test _extract_json JSON extraction without LLM calls."""

    @pytest.fixture
    def node(self):
        """Create an LLMNode instance for testing."""
        return LLMNode()

    def test_clean_json(self, node):
        """Test parsing clean JSON directly."""
        result = node._extract_json('{"key": "value"}', ["key"])
        assert result == {"key": "value"}

    def test_json_with_whitespace(self, node):
        """Test parsing JSON with surrounding whitespace."""
        result = node._extract_json(' {"key": "value"} \n', ["key"])
        assert result == {"key": "value"}

    def test_markdown_code_block_at_start(self, node):
        """Test extracting JSON from markdown code block at start."""
        input_text = '```json\n{"key": "value"}\n```'
        result = node._extract_json(input_text, ["key"])
        assert result == {"key": "value"}

    def test_markdown_code_block_without_json_label(self, node):
        """Test extracting JSON from markdown code block without 'json' label."""
        input_text = '```\n{"key": "value"}\n```'
        result = node._extract_json(input_text, ["key"])
        assert result == {"key": "value"}

    def test_prose_around_markdown_block(self, node):
        """Test extracting JSON when prose surrounds the markdown block."""
        input_text = 'Here is the result:\n```json\n{"key": "value"}\n```\nHope this helps!'
        result = node._extract_json(input_text, ["key"])
        assert result == {"key": "value"}

    def test_json_embedded_in_prose(self, node):
        """Test extracting JSON embedded in prose text."""
        input_text = 'The answer is {"key": "value"} as requested.'
        result = node._extract_json(input_text, ["key"])
        assert result == {"key": "value"}

    def test_nested_json(self, node):
        """Test parsing nested JSON objects."""
        input_text = '{"outer": {"inner": "value"}}'
        result = node._extract_json(input_text, ["outer"])
        assert result == {"outer": {"inner": "value"}}

    def test_deeply_nested_json(self, node):
        """Test parsing deeply nested JSON objects."""
        input_text = '{"a": {"b": {"c": {"d": "deep"}}}}'
        result = node._extract_json(input_text, ["a"])
        assert result == {"a": {"b": {"c": {"d": "deep"}}}}

    def test_json_with_array(self, node):
        """Test parsing JSON with array values."""
        input_text = '{"items": [1, 2, 3]}'
        result = node._extract_json(input_text, ["items"])
        assert result == {"items": [1, 2, 3]}

    def test_json_with_string_containing_braces(self, node):
        """Test parsing JSON where string values contain braces."""
        input_text = '{"code": "function() { return 1; }"}'
        result = node._extract_json(input_text, ["code"])
        assert result == {"code": "function() { return 1; }"}

    def test_json_with_escaped_quotes(self, node):
        """Test parsing JSON with escaped quotes in strings."""
        input_text = '{"message": "He said \\"hello\\""}'
        result = node._extract_json(input_text, ["message"])
        assert result == {"message": 'He said "hello"'}

    def test_multiple_json_objects_takes_first(self, node):
        """Test that when multiple JSON objects exist, first is taken."""
        input_text = '{"first": 1} and then {"second": 2}'
        result = node._extract_json(input_text, ["first"])
        assert result == {"first": 1}

    def test_json_with_boolean_and_null(self, node):
        """Test parsing JSON with boolean and null values."""
        input_text = '{"active": true, "deleted": false, "data": null}'
        result = node._extract_json(input_text, ["active", "deleted", "data"])
        assert result == {"active": True, "deleted": False, "data": None}

    def test_json_with_numbers(self, node):
        """Test parsing JSON with integer and float values."""
        input_text = '{"count": 42, "price": 19.99}'
        result = node._extract_json(input_text, ["count", "price"])
        assert result == {"count": 42, "price": 19.99}

    def test_invalid_json_raises_error(self, node):
        """Test that completely invalid JSON raises an error."""
        with pytest.raises(ValueError, match="Cannot parse JSON"):
            node._extract_json("This is not JSON at all", ["key"])

    def test_empty_string_raises_error(self, node):
        """Test that empty string raises an error."""
        with pytest.raises(ValueError, match="Cannot parse JSON"):
            node._extract_json("", ["key"])