Compare commits
5 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 0770a8161a | |
| | e7be02b421 | |
| | f0453d0a73 | |
| | ae73866f94 | |
| | 40e39d29f8 | |
@@ -331,7 +331,7 @@ No. Aden is built from the ground up with no dependencies on LangChain, CrewAI,
**Q: What LLM providers does Aden support?**

Aden supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.

Aden supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
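As a concrete sketch of that setup (the `LiteLLMProvider` class appears later in this change; the import path shown here is an assumption, and `DEEPSEEK_API_KEY` is the variable used in the new tests):

```python
import os

# Import path is illustrative; use wherever LiteLLMProvider lives in your install.
from framework.llm.litellm_provider import LiteLLMProvider

# 1. Set the provider's API key environment variable.
os.environ["DEEPSEEK_API_KEY"] = "sk-..."  # placeholder key

# 2. Specify the model name (LiteLLM's "provider/model" form).
provider = LiteLLMProvider(model="deepseek/deepseek-chat")
```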

**Q: Can I use Aden with local AI models like Ollama?**
+155
-14
@@ -144,7 +144,20 @@ class NodeSpec(BaseModel):
    max_retries: int = Field(default=3)
    retry_on: list[str] = Field(default_factory=list, description="Error types to retry on")

    model_config = {"extra": "allow"}
    # Pydantic model for output validation
    output_model: type[BaseModel] | None = Field(
        default=None,
        description=(
            "Optional Pydantic model class for validating and parsing LLM output. "
            "When set, the LLM response will be validated against this model."
        ),
    )
    max_validation_retries: int = Field(
        default=2,
        description="Maximum retries when Pydantic validation fails (with feedback to LLM)"
    )

    model_config = {"extra": "allow", "arbitrary_types_allowed": True}


class MemoryWriteError(Exception):
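As a usage sketch of the two new fields (mirroring the tests added later in this change; the node id, name, and model here are made-up example values):

```python
from pydantic import BaseModel, Field

from framework.graph.node import NodeSpec


class TicketAnalysis(BaseModel):
    category: str
    priority: int = Field(ge=1, le=5)


# A node that asks the LLM for structured output and validates it.
node = NodeSpec(
    id="analyze_ticket",
    name="Analyze Ticket",
    description="Classifies an incoming support ticket",
    node_type="llm_generate",
    output_model=TicketAnalysis,   # LLM output is parsed and validated against this model
    max_validation_retries=3,      # on failure, feedback is sent back to the LLM up to 3 times
)
```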
@@ -346,6 +359,9 @@ class NodeResult:
    tokens_used: int = 0
    latency_ms: int = 0

    # Pydantic validation errors (if any)
    validation_errors: list[str] = field(default_factory=list)

    def to_summary(self, node_spec: Any = None) -> str:
        """
        Generate a human-readable summary of this node's execution and output.
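For illustration, a failed run can now report the individual Pydantic errors alongside the top-level error message (a sketch based on the tests added in this change):

```python
from framework.graph.node import NodeResult

# Hypothetical failure carrying per-field validation errors.
result = NodeResult(
    success=False,
    error="Pydantic validation failed",
    validation_errors=["count: Field required (type: missing)"],
)
assert result.validation_errors  # non-empty only when validation failed
```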
@@ -597,19 +613,133 @@ class LLMNode(NodeProtocol):
            f" 📋 Expecting JSON output with keys: {ctx.node_spec.output_keys}"
        )

        response = ctx.llm.complete(
            messages=messages,
            system=system,
            json_mode=use_json_mode,
        )
        # Phase 3: Auto-generate JSON schema from Pydantic model
        response_format = None
        if ctx.node_spec.output_model is not None:
            json_schema = ctx.node_spec.output_model.model_json_schema()
            response_format = {
                "type": "json_schema",
                "json_schema": {
                    "name": ctx.node_spec.output_model.__name__,
                    "schema": json_schema,
                    "strict": True,
                }
            }
            logger.info(
                f" 📐 Using JSON schema from Pydantic model: "
                f"{ctx.node_spec.output_model.__name__}"
            )

        # Log the response
        response_preview = (
            response.content[:200] if len(response.content) > 200 else response.content
        )
        if len(response.content) > 200:
            response_preview += "..."
        logger.info(f" ← Response: {response_preview}")
        # Phase 2: Retry loop for Pydantic validation
        max_validation_retries = (
            ctx.node_spec.max_validation_retries if ctx.node_spec.output_model else 0
        )
        validation_attempt = 0
        total_input_tokens = 0
        total_output_tokens = 0
        current_messages = messages.copy()

        while True:
            response = ctx.llm.complete(
                messages=current_messages,
                system=system,
                json_mode=use_json_mode,
                response_format=response_format,
            )

            total_input_tokens += response.input_tokens
            total_output_tokens += response.output_tokens

            # Log the response
            response_preview = (
                response.content[:200] if len(response.content) > 200 else response.content
            )
            if len(response.content) > 200:
                response_preview += "..."
            logger.info(f" ← Response: {response_preview}")

            # If no output_model, break immediately (no validation needed)
            if ctx.node_spec.output_model is None:
                break

            # Try to parse and validate the response
            try:
                import json
                parsed = self._extract_json(response.content, ctx.node_spec.output_keys)

                if isinstance(parsed, dict):
                    from framework.graph.validator import OutputValidator
                    validator = OutputValidator()
                    validation_result, validated_model = validator.validate_with_pydantic(
                        parsed, ctx.node_spec.output_model
                    )

                    if validation_result.success:
                        # Validation passed, break out of retry loop
                        logger.info(
                            f" ✓ Pydantic validation passed for "
                            f"{ctx.node_spec.output_model.__name__}"
                        )
                        break
                    else:
                        # Validation failed
                        validation_attempt += 1

                        if validation_attempt <= max_validation_retries:
                            # Add validation feedback to messages and retry
                            feedback = validator.format_validation_feedback(
                                validation_result, ctx.node_spec.output_model
                            )
                            logger.warning(
                                f" ⚠ Pydantic validation failed "
                                f"(attempt {validation_attempt}/{max_validation_retries}): "
                                f"{validation_result.error}"
                            )
                            logger.info(" 🔄 Retrying with validation feedback...")

                            # Add the assistant's failed response and feedback
                            current_messages.append({
                                "role": "assistant",
                                "content": response.content
                            })
                            current_messages.append({
                                "role": "user",
                                "content": feedback
                            })
                            continue  # Retry the LLM call
                        else:
                            # Max retries exceeded
                            latency_ms = int((time.time() - start) * 1000)
                            logger.error(
                                f" ✗ Pydantic validation failed after "
                                f"{max_validation_retries} retries: "
                                f"{validation_result.error}"
                            )
                            ctx.runtime.record_outcome(
                                decision_id=decision_id,
                                success=False,
                                error=f"Validation failed: {validation_result.error}",
                                tokens_used=total_input_tokens + total_output_tokens,
                                latency_ms=latency_ms,
                            )
                            return NodeResult(
                                success=False,
                                error=(
                                    f"Pydantic validation failed after "
                                    f"{max_validation_retries} retries: "
                                    f"{validation_result.error}"
                                ),
                                output=parsed,
                                tokens_used=total_input_tokens + total_output_tokens,
                                latency_ms=latency_ms,
                                validation_errors=validation_result.errors,
                            )
                else:
                    # Not a dict, can't validate - break and let downstream handle
                    break
            except Exception:
                # JSON extraction failed - break and let downstream handle
                break

        latency_ms = int((time.time() - start) * 1000)
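Stripped of logging, token accounting, and outcome recording, the retry-with-feedback loop above reduces to roughly the following sketch (simplified names; `json.loads` stands in for the node's more forgiving `_extract_json`):

```python
import json

from framework.graph.validator import OutputValidator


def complete_with_validation(llm, messages, system, output_model, max_retries=2):
    """Sketch: call the LLM, validate against a Pydantic model, retry with feedback."""
    validator = OutputValidator()
    result = None
    for _ in range(max_retries + 1):
        response = llm.complete(messages=messages, system=system, json_mode=True)
        parsed = json.loads(response.content)
        result, instance = validator.validate_with_pydantic(parsed, output_model)
        if result.success:
            return instance
        # Append the failed attempt plus structured feedback, then try again.
        messages = messages + [
            {"role": "assistant", "content": response.content},
            {"role": "user", "content": validator.format_validation_feedback(result, output_model)},
        ]
    raise ValueError(f"Validation failed after {max_retries} retries: {result.error}")
```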
@@ -635,8 +765,19 @@ class LLMNode(NodeProtocol):
        # Try to extract JSON from response
        parsed = self._extract_json(response.content, ctx.node_spec.output_keys)

        # If parsed successfully, write each field to its corresponding output key
        # If parsed successfully, validate against Pydantic model if specified
        if isinstance(parsed, dict):
            # If we have output_model, the validation already happened in the retry loop
            if ctx.node_spec.output_model is not None:
                from framework.graph.validator import OutputValidator
                validator = OutputValidator()
                validation_result, validated_model = validator.validate_with_pydantic(
                    parsed, ctx.node_spec.output_model
                )
                # Use validated model's dict representation
                if validated_model:
                    parsed = validated_model.model_dump()

            for key in ctx.node_spec.output_keys:
                if key in parsed:
                    value = parsed[key]
@@ -8,6 +8,8 @@ import logging
from dataclasses import dataclass
from typing import Any

from pydantic import BaseModel, ValidationError

logger = logging.getLogger(__name__)
@@ -131,6 +133,71 @@ class OutputValidator:
        return ValidationResult(success=len(errors) == 0, errors=errors)

    def validate_with_pydantic(
        self,
        output: dict[str, Any],
        model: type[BaseModel],
    ) -> tuple[ValidationResult, BaseModel | None]:
        """
        Validate output against a Pydantic model.

        Args:
            output: The output dict to validate
            model: Pydantic model class to validate against

        Returns:
            Tuple of (ValidationResult, validated_model_instance or None)
        """
        try:
            validated = model.model_validate(output)
            return ValidationResult(success=True, errors=[]), validated
        except ValidationError as e:
            errors = []
            for error in e.errors():
                field_path = ".".join(str(loc) for loc in error["loc"])
                msg = error["msg"]
                error_type = error["type"]
                errors.append(f"{field_path}: {msg} (type: {error_type})")
            return ValidationResult(success=False, errors=errors), None

    def format_validation_feedback(
        self,
        validation_result: ValidationResult,
        model: type[BaseModel],
    ) -> str:
        """
        Format validation errors as feedback for LLM retry.

        Args:
            validation_result: The failed validation result
            model: The Pydantic model that was used for validation

        Returns:
            Formatted feedback string to include in retry prompt
        """
        # Get the model's JSON schema for reference
        schema = model.model_json_schema()

        feedback = "Your previous response had validation errors:\n\n"
        feedback += "ERRORS:\n"
        for error in validation_result.errors:
            feedback += f" - {error}\n"

        feedback += "\nEXPECTED SCHEMA:\n"
        feedback += f" Model: {model.__name__}\n"

        if "properties" in schema:
            feedback += " Required fields:\n"
            required = schema.get("required", [])
            for prop_name, prop_info in schema["properties"].items():
                req_marker = " (required)" if prop_name in required else ""
                prop_type = prop_info.get("type", "any")
                feedback += f" - {prop_name}: {prop_type}{req_marker}\n"

        feedback += "\nPlease fix the errors and respond with valid JSON matching the schema."

        return feedback

    def validate_no_hallucination(
        self,
        output: dict[str, Any],
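To make the feedback concrete: for the `SimpleOutput` model used in the new tests below (required `message: str` and `count: int` fields) where the response omitted `count`, the string built above would look roughly like this (shape only; exact wording comes from `format_validation_feedback`):

```python
EXAMPLE_FEEDBACK = """Your previous response had validation errors:

ERRORS:
 - count: Field required (type: missing)

EXPECTED SCHEMA:
 Model: SimpleOutput
 Required fields:
 - message: string (required)
 - count: integer (required)

Please fix the errors and respond with valid JSON matching the schema."""
```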
@@ -27,6 +27,7 @@ class LiteLLMProvider(LLMProvider):
    - OpenAI: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo
    - Anthropic: claude-3-opus, claude-3-sonnet, claude-3-haiku
    - Google: gemini-pro, gemini-1.5-pro, gemini-1.5-flash
    - DeepSeek: deepseek-chat, deepseek-coder, deepseek-reasoner
    - Mistral: mistral-large, mistral-medium, mistral-small
    - Groq: llama3-70b, mixtral-8x7b
    - Local: ollama/llama3, ollama/mistral
@@ -42,6 +43,9 @@ class LiteLLMProvider(LLMProvider):
        # Google Gemini
        provider = LiteLLMProvider(model="gemini/gemini-1.5-flash")

        # DeepSeek
        provider = LiteLLMProvider(model="deepseek/deepseek-chat")

        # Local Ollama
        provider = LiteLLMProvider(model="ollama/llama3")
@@ -34,6 +34,12 @@ class TestLiteLLMProviderInit:
            provider = LiteLLMProvider(model="claude-3-haiku-20240307")
            assert provider.model == "claude-3-haiku-20240307"

    def test_init_deepseek_model(self):
        """Test initialization with DeepSeek model."""
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
            provider = LiteLLMProvider(model="deepseek/deepseek-chat")
            assert provider.model == "deepseek/deepseek-chat"

    def test_init_with_api_key(self):
        """Test initialization with explicit API key."""
        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="my-api-key")
@@ -100,12 +100,18 @@ class TestJsonExtraction:
        result = node._extract_json(input_text, ["count", "price"])
        assert result == {"count": 42, "price": 19.99}

    def test_invalid_json_raises_error(self, node):
        """Test that completely invalid JSON raises an error."""
    def test_invalid_json_raises_error(self, node, monkeypatch):
        """Test that completely invalid JSON raises an error when no LLM fallback available."""
        # Remove API keys so LLM fallback is not attempted
        monkeypatch.delenv("CEREBRAS_API_KEY", raising=False)
        monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
        with pytest.raises(ValueError, match="Cannot parse JSON"):
            node._extract_json("This is not JSON at all", ["key"])

    def test_empty_string_raises_error(self, node):
        """Test that empty string raises an error."""
    def test_empty_string_raises_error(self, node, monkeypatch):
        """Test that empty string raises an error when no LLM fallback available."""
        # Remove API keys so LLM fallback is not attempted
        monkeypatch.delenv("CEREBRAS_API_KEY", raising=False)
        monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
        with pytest.raises(ValueError, match="Cannot parse JSON"):
            node._extract_json("", ["key"])
@@ -0,0 +1,441 @@
"""
Tests for Pydantic validation of LLM outputs.

Tests the new output_model feature in NodeSpec that allows
validating LLM responses against Pydantic models.
"""

from pydantic import BaseModel, Field

from framework.graph.node import NodeResult, NodeSpec
from framework.graph.validator import OutputValidator, ValidationResult


# Test Pydantic models
class SimpleOutput(BaseModel):
    """Simple test model."""
    message: str
    count: int


class ComplexOutput(BaseModel):
    """Complex test model with nested types."""
    query: str
    results: list[str] = Field(min_length=1)
    confidence: float = Field(ge=0, le=1)
    metadata: dict[str, str] = Field(default_factory=dict)


class TicketAnalysis(BaseModel):
    """Realistic use case model."""
    category: str
    priority: int = Field(ge=1, le=5)
    summary: str = Field(min_length=10)
    suggested_action: str


class TestNodeSpecOutputModel:
    """Tests for output_model field in NodeSpec."""

    def test_nodespec_accepts_output_model(self):
        """NodeSpec should accept a Pydantic model class."""
        node = NodeSpec(
            id="test_node",
            name="Test Node",
            description="A test node",
            node_type="llm_generate",
            output_model=SimpleOutput,
        )

        assert node.output_model == SimpleOutput
        assert node.max_validation_retries == 2  # default

    def test_nodespec_output_model_optional(self):
        """output_model should be optional (None by default)."""
        node = NodeSpec(
            id="test_node",
            name="Test Node",
            description="A test node",
        )

        assert node.output_model is None

    def test_nodespec_custom_validation_retries(self):
        """Should support custom max_validation_retries."""
        node = NodeSpec(
            id="test_node",
            name="Test Node",
            description="A test node",
            output_model=SimpleOutput,
            max_validation_retries=5,
        )

        assert node.max_validation_retries == 5


class TestOutputValidatorPydantic:
    """Tests for validate_with_pydantic method."""

    def test_validate_valid_output(self):
        """Should pass for valid output matching model."""
        validator = OutputValidator()
        output = {"message": "Hello", "count": 5}

        result, validated = validator.validate_with_pydantic(output, SimpleOutput)

        assert result.success is True
        assert len(result.errors) == 0
        assert validated is not None
        assert validated.message == "Hello"
        assert validated.count == 5

    def test_validate_missing_required_field(self):
        """Should fail when required field is missing."""
        validator = OutputValidator()
        output = {"message": "Hello"}  # missing 'count'

        result, validated = validator.validate_with_pydantic(output, SimpleOutput)

        assert result.success is False
        assert len(result.errors) > 0
        assert "count" in result.errors[0]
        assert validated is None

    def test_validate_wrong_type(self):
        """Should fail when field has wrong type."""
        validator = OutputValidator()
        output = {"message": "Hello", "count": "five"}  # count should be int

        result, validated = validator.validate_with_pydantic(output, SimpleOutput)

        assert result.success is False
        assert len(result.errors) > 0
        assert validated is None

    def test_validate_complex_model(self):
        """Should validate complex nested models."""
        validator = OutputValidator()
        output = {
            "query": "test query",
            "results": ["result1", "result2"],
            "confidence": 0.85,
            "metadata": {"source": "test"},
        }

        result, validated = validator.validate_with_pydantic(output, ComplexOutput)

        assert result.success is True
        assert validated is not None
        assert validated.query == "test query"
        assert len(validated.results) == 2
        assert validated.confidence == 0.85

    def test_validate_field_constraints(self):
        """Should validate field constraints (min_length, ge, le, etc.)."""
        validator = OutputValidator()

        # Empty results list (violates min_length=1)
        output = {
            "query": "test",
            "results": [],  # should have at least 1 item
            "confidence": 0.5,
        }

        result, validated = validator.validate_with_pydantic(output, ComplexOutput)

        assert result.success is False
        assert "results" in result.error

    def test_validate_range_constraints(self):
        """Should validate range constraints (ge, le)."""
        validator = OutputValidator()

        # Confidence out of range
        output = {
            "query": "test",
            "results": ["r1"],
            "confidence": 1.5,  # should be <= 1
        }

        result, validated = validator.validate_with_pydantic(output, ComplexOutput)

        assert result.success is False
        assert "confidence" in result.error

    def test_validate_realistic_model(self):
        """Should work with realistic use case models."""
        validator = OutputValidator()

        output = {
            "category": "Technical Support",
            "priority": 3,
            "summary": "User is experiencing login issues with error 401",
            "suggested_action": "Reset password and verify account status",
        }

        result, validated = validator.validate_with_pydantic(output, TicketAnalysis)

        assert result.success is True
        assert validated is not None
        assert validated.category == "Technical Support"
        assert validated.priority == 3


class TestValidationFeedback:
    """Tests for format_validation_feedback method."""

    def test_format_feedback_includes_errors(self):
        """Feedback should include validation errors."""
        validator = OutputValidator()
        output = {"message": "Hello"}  # missing count

        result, _ = validator.validate_with_pydantic(output, SimpleOutput)
        feedback = validator.format_validation_feedback(result, SimpleOutput)

        assert "validation errors" in feedback.lower()
        assert "count" in feedback
        assert "SimpleOutput" in feedback

    def test_format_feedback_includes_schema(self):
        """Feedback should include expected schema information."""
        validator = OutputValidator()
        result = ValidationResult(success=False, errors=["test error"])

        feedback = validator.format_validation_feedback(result, SimpleOutput)

        assert "message" in feedback
        assert "count" in feedback
        assert "required" in feedback.lower()


class TestNodeResultValidationErrors:
    """Tests for validation_errors field in NodeResult."""

    def test_noderesult_includes_validation_errors(self):
        """NodeResult should store validation errors."""
        result = NodeResult(
            success=False,
            error="Pydantic validation failed",
            validation_errors=["count: field required", "priority: must be >= 1"],
        )

        assert len(result.validation_errors) == 2
        assert "count" in result.validation_errors[0]

    def test_noderesult_empty_validation_errors_by_default(self):
        """validation_errors should be empty list by default."""
        result = NodeResult(success=True, output={"key": "value"})

        assert result.validation_errors == []


# Integration-style tests
class TestPydanticValidationIntegration:
    """Integration tests for Pydantic validation in node execution."""

    def test_nodespec_serialization_with_output_model(self):
        """NodeSpec with output_model should serialize correctly."""
        node = NodeSpec(
            id="test",
            name="Test",
            description="Test node",
            output_model=SimpleOutput,
        )

        # model_dump should work (Pydantic serialization)
        dumped = node.model_dump()
        assert "output_model" in dumped
        # The model class itself is stored, not serialized
        assert dumped["output_model"] == SimpleOutput


# Phase 3: JSON Schema Generation Tests
class TestJSONSchemaGeneration:
    """Tests for auto-generating JSON schema from Pydantic model."""

    def test_simple_model_schema_generation(self):
        """Should generate correct JSON schema for simple model."""
        schema = SimpleOutput.model_json_schema()

        assert "properties" in schema
        assert "message" in schema["properties"]
        assert "count" in schema["properties"]
        assert schema["properties"]["message"]["type"] == "string"
        assert schema["properties"]["count"]["type"] == "integer"

    def test_complex_model_schema_generation(self):
        """Should generate correct JSON schema for complex model."""
        schema = ComplexOutput.model_json_schema()

        assert "properties" in schema
        assert "query" in schema["properties"]
        assert "results" in schema["properties"]
        assert "confidence" in schema["properties"]
        # Check constraints are in schema
        assert (
            "minimum" in schema["properties"]["confidence"]
            or "exclusiveMinimum" in schema["properties"]["confidence"]
        )

    def test_schema_includes_required_fields(self):
        """JSON schema should include required fields."""
        schema = SimpleOutput.model_json_schema()

        assert "required" in schema
        assert "message" in schema["required"]
        assert "count" in schema["required"]

    def test_schema_can_be_used_in_response_format(self):
        """Schema should be usable in LLM response_format parameter."""
        schema = TicketAnalysis.model_json_schema()

        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": TicketAnalysis.__name__,
                "schema": schema,
                "strict": True,
            }
        }

        # Should be valid structure
        assert response_format["type"] == "json_schema"
        assert response_format["json_schema"]["name"] == "TicketAnalysis"
        assert "properties" in response_format["json_schema"]["schema"]


# Phase 2: Retry with Feedback Tests
class TestRetryWithFeedback:
    """Tests for retry-with-feedback functionality."""

    def test_validation_feedback_format(self):
        """Feedback should be properly formatted for LLM retry."""
        validator = OutputValidator()
        output = {"priority": 10}  # Invalid: missing fields and priority > 5

        result, _ = validator.validate_with_pydantic(output, TicketAnalysis)
        feedback = validator.format_validation_feedback(result, TicketAnalysis)

        # Should include error details
        assert "ERRORS:" in feedback
        assert "EXPECTED SCHEMA:" in feedback
        assert "TicketAnalysis" in feedback
        # Should mention missing required fields
        assert "category" in feedback or "summary" in feedback

    def test_feedback_mentions_fix_instruction(self):
        """Feedback should include instruction to fix errors."""
        validator = OutputValidator()
        result = ValidationResult(success=False, errors=["test error"])

        feedback = validator.format_validation_feedback(result, SimpleOutput)

        assert "fix" in feedback.lower() or "valid JSON" in feedback

    def test_max_validation_retries_default(self):
        """Default max_validation_retries should be 2."""
        node = NodeSpec(
            id="test",
            name="Test",
            description="Test node",
            output_model=SimpleOutput,
        )

        assert node.max_validation_retries == 2

    def test_max_validation_retries_customizable(self):
        """max_validation_retries should be customizable."""
        node = NodeSpec(
            id="test",
            name="Test",
            description="Test node",
            output_model=SimpleOutput,
            max_validation_retries=5,
        )

        assert node.max_validation_retries == 5

    def test_zero_retries_allowed(self):
        """Should allow 0 retries (immediate failure on validation error)."""
        node = NodeSpec(
            id="test",
            name="Test",
            description="Test node",
            output_model=SimpleOutput,
            max_validation_retries=0,
        )

        assert node.max_validation_retries == 0

    def test_feedback_includes_all_error_types(self):
        """Feedback should include various error types."""
        validator = OutputValidator()

        # Create output with multiple errors
        output = {
            "category": "X",  # too short if there was min_length
            "priority": 10,  # out of range (should be 1-5)
            "summary": "short",  # too short (min_length=10)
            # missing suggested_action
        }

        result, _ = validator.validate_with_pydantic(output, TicketAnalysis)
        feedback = validator.format_validation_feedback(result, TicketAnalysis)

        # Should contain error details
        assert "ERRORS:" in feedback
        # Should list multiple errors
        assert result.errors is not None
        assert len(result.errors) >= 1


# Extended Integration Tests
class TestPydanticValidationIntegrationExtended:
    """Extended integration tests for the complete validation flow."""

    def test_nodespec_with_all_validation_options(self):
        """NodeSpec should accept all validation-related options."""
        node = NodeSpec(
            id="full_test",
            name="Full Validation Test",
            description="Tests all validation options",
            node_type="llm_generate",
            output_keys=["category", "priority", "summary", "suggested_action"],
            output_model=TicketAnalysis,
            max_validation_retries=3,
        )

        assert node.output_model == TicketAnalysis
        assert node.max_validation_retries == 3
        assert len(node.output_keys) == 4

    def test_validator_preserves_model_defaults(self):
        """Validated model should preserve default values."""
        validator = OutputValidator()

        # metadata has a default (default_factory=dict)
        output = {
            "query": "test",
            "results": ["r1"],
            "confidence": 0.5,
            # metadata not provided, should use default
        }

        result, validated = validator.validate_with_pydantic(output, ComplexOutput)

        assert result.success is True
        assert validated.metadata == {}  # default value

    def test_validation_result_error_property(self):
        """ValidationResult.error should combine all errors."""
        result = ValidationResult(
            success=False,
            errors=["error1", "error2", "error3"]
        )

        error_str = result.error

        assert "error1" in error_str
        assert "error2" in error_str
        assert "error3" in error_str
        assert "; " in error_str  # errors joined with "; "