Compare commits

...

5 Commits

Author SHA1 Message Date
Richard T 0770a8161a fix: lint issues 2026-01-26 14:26:58 -08:00
Himanshu Chauhan e7be02b421 feat(validation): add Pydantic model validation for LLM outputs
- Add output_model field to NodeSpec for specifying Pydantic model
- Add max_validation_retries field (default: 2) for retry configuration
- Add validation_errors field to NodeResult for error tracking
- Implement validate_with_pydantic() in OutputValidator
- Implement format_validation_feedback() for LLM retry prompts
- Auto-generate JSON schema from Pydantic model for response_format
- Add retry loop that feeds validation errors back to LLM
- Add 28 comprehensive tests covering all new functionality
2026-01-26 14:26:58 -08:00
Richard T f0453d0a73 Merge branch 'feat/add-deepseek-docs' of https://github.com/SoulSniper-V2/hive into apply-deepseek-docsMerge branch 'feat/add-deepseek-docs' of https://github.com/SoulSniper-V2/hive into apply-deepseek-docs 2026-01-26 14:09:01 -08:00
Arush Wadhawan ae73866f94 docs(llm): add DeepSeek models support documentation and examples
Signed-off-by: Arush Wadhawan <warush23+github@gmail.com>
2026-01-26 14:04:42 -08:00
Arush Wadhawan 40e39d29f8 docs(llm): add DeepSeek models support documentation and examples
Signed-off-by: Arush Wadhawan <warush23+github@gmail.com>
2026-01-26 12:24:51 -05:00
7 changed files with 684 additions and 19 deletions
+1 -1
View File
@@ -331,7 +331,7 @@ No. Aden is built from the ground up with no dependencies on LangChain, CrewAI,
**Q: What LLM providers does Aden support?**
Aden supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
Aden supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
**Q: Can I use Aden with local AI models like Ollama?**
+155 -14
View File
@@ -144,7 +144,20 @@ class NodeSpec(BaseModel):
max_retries: int = Field(default=3)
retry_on: list[str] = Field(default_factory=list, description="Error types to retry on")
model_config = {"extra": "allow"}
# Pydantic model for output validation
output_model: type[BaseModel] | None = Field(
default=None,
description=(
"Optional Pydantic model class for validating and parsing LLM output. "
"When set, the LLM response will be validated against this model."
),
)
max_validation_retries: int = Field(
default=2,
description="Maximum retries when Pydantic validation fails (with feedback to LLM)"
)
model_config = {"extra": "allow", "arbitrary_types_allowed": True}
class MemoryWriteError(Exception):
@@ -346,6 +359,9 @@ class NodeResult:
tokens_used: int = 0
latency_ms: int = 0
# Pydantic validation errors (if any)
validation_errors: list[str] = field(default_factory=list)
def to_summary(self, node_spec: Any = None) -> str:
"""
Generate a human-readable summary of this node's execution and output.
@@ -597,19 +613,133 @@ class LLMNode(NodeProtocol):
f" 📋 Expecting JSON output with keys: {ctx.node_spec.output_keys}"
)
response = ctx.llm.complete(
messages=messages,
system=system,
json_mode=use_json_mode,
)
# Phase 3: Auto-generate JSON schema from Pydantic model
response_format = None
if ctx.node_spec.output_model is not None:
json_schema = ctx.node_spec.output_model.model_json_schema()
response_format = {
"type": "json_schema",
"json_schema": {
"name": ctx.node_spec.output_model.__name__,
"schema": json_schema,
"strict": True,
}
}
logger.info(
f" 📐 Using JSON schema from Pydantic model: "
f"{ctx.node_spec.output_model.__name__}"
)
# Log the response
response_preview = (
response.content[:200] if len(response.content) > 200 else response.content
)
if len(response.content) > 200:
response_preview += "..."
logger.info(f" ← Response: {response_preview}")
# Phase 2: Retry loop for Pydantic validation
max_validation_retries = (
ctx.node_spec.max_validation_retries if ctx.node_spec.output_model else 0
)
validation_attempt = 0
total_input_tokens = 0
total_output_tokens = 0
current_messages = messages.copy()
while True:
response = ctx.llm.complete(
messages=current_messages,
system=system,
json_mode=use_json_mode,
response_format=response_format,
)
total_input_tokens += response.input_tokens
total_output_tokens += response.output_tokens
# Log the response
response_preview = (
response.content[:200] if len(response.content) > 200 else response.content
)
if len(response.content) > 200:
response_preview += "..."
logger.info(f" ← Response: {response_preview}")
# If no output_model, break immediately (no validation needed)
if ctx.node_spec.output_model is None:
break
# Try to parse and validate the response
try:
import json
parsed = self._extract_json(response.content, ctx.node_spec.output_keys)
if isinstance(parsed, dict):
from framework.graph.validator import OutputValidator
validator = OutputValidator()
validation_result, validated_model = validator.validate_with_pydantic(
parsed, ctx.node_spec.output_model
)
if validation_result.success:
# Validation passed, break out of retry loop
logger.info(
f" ✓ Pydantic validation passed for "
f"{ctx.node_spec.output_model.__name__}"
)
break
else:
# Validation failed
validation_attempt += 1
if validation_attempt <= max_validation_retries:
# Add validation feedback to messages and retry
feedback = validator.format_validation_feedback(
validation_result, ctx.node_spec.output_model
)
logger.warning(
f" ⚠ Pydantic validation failed "
f"(attempt {validation_attempt}/{max_validation_retries}): "
f"{validation_result.error}"
)
logger.info(" 🔄 Retrying with validation feedback...")
# Add the assistant's failed response and feedback
current_messages.append({
"role": "assistant",
"content": response.content
})
current_messages.append({
"role": "user",
"content": feedback
})
continue # Retry the LLM call
else:
# Max retries exceeded
latency_ms = int((time.time() - start) * 1000)
logger.error(
f" ✗ Pydantic validation failed after "
f"{max_validation_retries} retries: "
f"{validation_result.error}"
)
ctx.runtime.record_outcome(
decision_id=decision_id,
success=False,
error=f"Validation failed: {validation_result.error}",
tokens_used=total_input_tokens + total_output_tokens,
latency_ms=latency_ms,
)
return NodeResult(
success=False,
error=(
f"Pydantic validation failed after "
f"{max_validation_retries} retries: "
f"{validation_result.error}"
),
output=parsed,
tokens_used=total_input_tokens + total_output_tokens,
latency_ms=latency_ms,
validation_errors=validation_result.errors,
)
else:
# Not a dict, can't validate - break and let downstream handle
break
except Exception:
# JSON extraction failed - break and let downstream handle
break
latency_ms = int((time.time() - start) * 1000)
@@ -635,8 +765,19 @@ class LLMNode(NodeProtocol):
# Try to extract JSON from response
parsed = self._extract_json(response.content, ctx.node_spec.output_keys)
# If parsed successfully, write each field to its corresponding output key
# If parsed successfully, validate against Pydantic model if specified
if isinstance(parsed, dict):
# If we have output_model, the validation already happened in the retry loop
if ctx.node_spec.output_model is not None:
from framework.graph.validator import OutputValidator
validator = OutputValidator()
validation_result, validated_model = validator.validate_with_pydantic(
parsed, ctx.node_spec.output_model
)
# Use validated model's dict representation
if validated_model:
parsed = validated_model.model_dump()
for key in ctx.node_spec.output_keys:
if key in parsed:
value = parsed[key]
+67
View File
@@ -8,6 +8,8 @@ import logging
from dataclasses import dataclass
from typing import Any
from pydantic import BaseModel, ValidationError
logger = logging.getLogger(__name__)
@@ -131,6 +133,71 @@ class OutputValidator:
return ValidationResult(success=len(errors) == 0, errors=errors)
def validate_with_pydantic(
self,
output: dict[str, Any],
model: type[BaseModel],
) -> tuple[ValidationResult, BaseModel | None]:
"""
Validate output against a Pydantic model.
Args:
output: The output dict to validate
model: Pydantic model class to validate against
Returns:
Tuple of (ValidationResult, validated_model_instance or None)
"""
try:
validated = model.model_validate(output)
return ValidationResult(success=True, errors=[]), validated
except ValidationError as e:
errors = []
for error in e.errors():
field_path = ".".join(str(loc) for loc in error["loc"])
msg = error["msg"]
error_type = error["type"]
errors.append(f"{field_path}: {msg} (type: {error_type})")
return ValidationResult(success=False, errors=errors), None
def format_validation_feedback(
self,
validation_result: ValidationResult,
model: type[BaseModel],
) -> str:
"""
Format validation errors as feedback for LLM retry.
Args:
validation_result: The failed validation result
model: The Pydantic model that was used for validation
Returns:
Formatted feedback string to include in retry prompt
"""
# Get the model's JSON schema for reference
schema = model.model_json_schema()
feedback = "Your previous response had validation errors:\n\n"
feedback += "ERRORS:\n"
for error in validation_result.errors:
feedback += f" - {error}\n"
feedback += "\nEXPECTED SCHEMA:\n"
feedback += f" Model: {model.__name__}\n"
if "properties" in schema:
feedback += " Required fields:\n"
required = schema.get("required", [])
for prop_name, prop_info in schema["properties"].items():
req_marker = " (required)" if prop_name in required else ""
prop_type = prop_info.get("type", "any")
feedback += f" - {prop_name}: {prop_type}{req_marker}\n"
feedback += "\nPlease fix the errors and respond with valid JSON matching the schema."
return feedback
def validate_no_hallucination(
self,
output: dict[str, Any],
+4
View File
@@ -27,6 +27,7 @@ class LiteLLMProvider(LLMProvider):
- OpenAI: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo
- Anthropic: claude-3-opus, claude-3-sonnet, claude-3-haiku
- Google: gemini-pro, gemini-1.5-pro, gemini-1.5-flash
- DeepSeek: deepseek-chat, deepseek-coder, deepseek-reasoner
- Mistral: mistral-large, mistral-medium, mistral-small
- Groq: llama3-70b, mixtral-8x7b
- Local: ollama/llama3, ollama/mistral
@@ -42,6 +43,9 @@ class LiteLLMProvider(LLMProvider):
# Google Gemini
provider = LiteLLMProvider(model="gemini/gemini-1.5-flash")
# DeepSeek
provider = LiteLLMProvider(model="deepseek/deepseek-chat")
# Local Ollama
provider = LiteLLMProvider(model="ollama/llama3")
+6
View File
@@ -34,6 +34,12 @@ class TestLiteLLMProviderInit:
provider = LiteLLMProvider(model="claude-3-haiku-20240307")
assert provider.model == "claude-3-haiku-20240307"
def test_init_deepseek_model(self):
"""Test initialization with DeepSeek model."""
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
provider = LiteLLMProvider(model="deepseek/deepseek-chat")
assert provider.model == "deepseek/deepseek-chat"
def test_init_with_api_key(self):
"""Test initialization with explicit API key."""
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="my-api-key")
+10 -4
View File
@@ -100,12 +100,18 @@ class TestJsonExtraction:
result = node._extract_json(input_text, ["count", "price"])
assert result == {"count": 42, "price": 19.99}
def test_invalid_json_raises_error(self, node):
"""Test that completely invalid JSON raises an error."""
def test_invalid_json_raises_error(self, node, monkeypatch):
"""Test that completely invalid JSON raises an error when no LLM fallback available."""
# Remove API keys so LLM fallback is not attempted
monkeypatch.delenv("CEREBRAS_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
with pytest.raises(ValueError, match="Cannot parse JSON"):
node._extract_json("This is not JSON at all", ["key"])
def test_empty_string_raises_error(self, node):
"""Test that empty string raises an error."""
def test_empty_string_raises_error(self, node, monkeypatch):
"""Test that empty string raises an error when no LLM fallback available."""
# Remove API keys so LLM fallback is not attempted
monkeypatch.delenv("CEREBRAS_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
with pytest.raises(ValueError, match="Cannot parse JSON"):
node._extract_json("", ["key"])
+441
View File
@@ -0,0 +1,441 @@
"""
Tests for Pydantic validation of LLM outputs.
Tests the new output_model feature in NodeSpec that allows
validating LLM responses against Pydantic models.
"""
from pydantic import BaseModel, Field
from framework.graph.node import NodeResult, NodeSpec
from framework.graph.validator import OutputValidator, ValidationResult
# Test Pydantic models
class SimpleOutput(BaseModel):
"""Simple test model."""
message: str
count: int
class ComplexOutput(BaseModel):
"""Complex test model with nested types."""
query: str
results: list[str] = Field(min_length=1)
confidence: float = Field(ge=0, le=1)
metadata: dict[str, str] = Field(default_factory=dict)
class TicketAnalysis(BaseModel):
"""Realistic use case model."""
category: str
priority: int = Field(ge=1, le=5)
summary: str = Field(min_length=10)
suggested_action: str
class TestNodeSpecOutputModel:
"""Tests for output_model field in NodeSpec."""
def test_nodespec_accepts_output_model(self):
"""NodeSpec should accept a Pydantic model class."""
node = NodeSpec(
id="test_node",
name="Test Node",
description="A test node",
node_type="llm_generate",
output_model=SimpleOutput,
)
assert node.output_model == SimpleOutput
assert node.max_validation_retries == 2 # default
def test_nodespec_output_model_optional(self):
"""output_model should be optional (None by default)."""
node = NodeSpec(
id="test_node",
name="Test Node",
description="A test node",
)
assert node.output_model is None
def test_nodespec_custom_validation_retries(self):
"""Should support custom max_validation_retries."""
node = NodeSpec(
id="test_node",
name="Test Node",
description="A test node",
output_model=SimpleOutput,
max_validation_retries=5,
)
assert node.max_validation_retries == 5
class TestOutputValidatorPydantic:
"""Tests for validate_with_pydantic method."""
def test_validate_valid_output(self):
"""Should pass for valid output matching model."""
validator = OutputValidator()
output = {"message": "Hello", "count": 5}
result, validated = validator.validate_with_pydantic(output, SimpleOutput)
assert result.success is True
assert len(result.errors) == 0
assert validated is not None
assert validated.message == "Hello"
assert validated.count == 5
def test_validate_missing_required_field(self):
"""Should fail when required field is missing."""
validator = OutputValidator()
output = {"message": "Hello"} # missing 'count'
result, validated = validator.validate_with_pydantic(output, SimpleOutput)
assert result.success is False
assert len(result.errors) > 0
assert "count" in result.errors[0]
assert validated is None
def test_validate_wrong_type(self):
"""Should fail when field has wrong type."""
validator = OutputValidator()
output = {"message": "Hello", "count": "five"} # count should be int
result, validated = validator.validate_with_pydantic(output, SimpleOutput)
assert result.success is False
assert len(result.errors) > 0
assert validated is None
def test_validate_complex_model(self):
"""Should validate complex nested models."""
validator = OutputValidator()
output = {
"query": "test query",
"results": ["result1", "result2"],
"confidence": 0.85,
"metadata": {"source": "test"},
}
result, validated = validator.validate_with_pydantic(output, ComplexOutput)
assert result.success is True
assert validated is not None
assert validated.query == "test query"
assert len(validated.results) == 2
assert validated.confidence == 0.85
def test_validate_field_constraints(self):
"""Should validate field constraints (min_length, ge, le, etc.)."""
validator = OutputValidator()
# Empty results list (violates min_length=1)
output = {
"query": "test",
"results": [], # should have at least 1 item
"confidence": 0.5,
}
result, validated = validator.validate_with_pydantic(output, ComplexOutput)
assert result.success is False
assert "results" in result.error
def test_validate_range_constraints(self):
"""Should validate range constraints (ge, le)."""
validator = OutputValidator()
# Confidence out of range
output = {
"query": "test",
"results": ["r1"],
"confidence": 1.5, # should be <= 1
}
result, validated = validator.validate_with_pydantic(output, ComplexOutput)
assert result.success is False
assert "confidence" in result.error
def test_validate_realistic_model(self):
"""Should work with realistic use case models."""
validator = OutputValidator()
output = {
"category": "Technical Support",
"priority": 3,
"summary": "User is experiencing login issues with error 401",
"suggested_action": "Reset password and verify account status",
}
result, validated = validator.validate_with_pydantic(output, TicketAnalysis)
assert result.success is True
assert validated is not None
assert validated.category == "Technical Support"
assert validated.priority == 3
class TestValidationFeedback:
"""Tests for format_validation_feedback method."""
def test_format_feedback_includes_errors(self):
"""Feedback should include validation errors."""
validator = OutputValidator()
output = {"message": "Hello"} # missing count
result, _ = validator.validate_with_pydantic(output, SimpleOutput)
feedback = validator.format_validation_feedback(result, SimpleOutput)
assert "validation errors" in feedback.lower()
assert "count" in feedback
assert "SimpleOutput" in feedback
def test_format_feedback_includes_schema(self):
"""Feedback should include expected schema information."""
validator = OutputValidator()
result = ValidationResult(success=False, errors=["test error"])
feedback = validator.format_validation_feedback(result, SimpleOutput)
assert "message" in feedback
assert "count" in feedback
assert "required" in feedback.lower()
class TestNodeResultValidationErrors:
"""Tests for validation_errors field in NodeResult."""
def test_noderesult_includes_validation_errors(self):
"""NodeResult should store validation errors."""
result = NodeResult(
success=False,
error="Pydantic validation failed",
validation_errors=["count: field required", "priority: must be >= 1"],
)
assert len(result.validation_errors) == 2
assert "count" in result.validation_errors[0]
def test_noderesult_empty_validation_errors_by_default(self):
"""validation_errors should be empty list by default."""
result = NodeResult(success=True, output={"key": "value"})
assert result.validation_errors == []
# Integration-style tests
class TestPydanticValidationIntegration:
"""Integration tests for Pydantic validation in node execution."""
def test_nodespec_serialization_with_output_model(self):
"""NodeSpec with output_model should serialize correctly."""
node = NodeSpec(
id="test",
name="Test",
description="Test node",
output_model=SimpleOutput,
)
# model_dump should work (Pydantic serialization)
dumped = node.model_dump()
assert "output_model" in dumped
# The model class itself is stored, not serialized
assert dumped["output_model"] == SimpleOutput
# Phase 3: JSON Schema Generation Tests
class TestJSONSchemaGeneration:
"""Tests for auto-generating JSON schema from Pydantic model."""
def test_simple_model_schema_generation(self):
"""Should generate correct JSON schema for simple model."""
schema = SimpleOutput.model_json_schema()
assert "properties" in schema
assert "message" in schema["properties"]
assert "count" in schema["properties"]
assert schema["properties"]["message"]["type"] == "string"
assert schema["properties"]["count"]["type"] == "integer"
def test_complex_model_schema_generation(self):
"""Should generate correct JSON schema for complex model."""
schema = ComplexOutput.model_json_schema()
assert "properties" in schema
assert "query" in schema["properties"]
assert "results" in schema["properties"]
assert "confidence" in schema["properties"]
# Check constraints are in schema
assert (
"minimum" in schema["properties"]["confidence"]
or "exclusiveMinimum" in schema["properties"]["confidence"]
)
def test_schema_includes_required_fields(self):
"""JSON schema should include required fields."""
schema = SimpleOutput.model_json_schema()
assert "required" in schema
assert "message" in schema["required"]
assert "count" in schema["required"]
def test_schema_can_be_used_in_response_format(self):
"""Schema should be usable in LLM response_format parameter."""
schema = TicketAnalysis.model_json_schema()
response_format = {
"type": "json_schema",
"json_schema": {
"name": TicketAnalysis.__name__,
"schema": schema,
"strict": True,
}
}
# Should be valid structure
assert response_format["type"] == "json_schema"
assert response_format["json_schema"]["name"] == "TicketAnalysis"
assert "properties" in response_format["json_schema"]["schema"]
# Phase 2: Retry with Feedback Tests
class TestRetryWithFeedback:
"""Tests for retry-with-feedback functionality."""
def test_validation_feedback_format(self):
"""Feedback should be properly formatted for LLM retry."""
validator = OutputValidator()
output = {"priority": 10} # Invalid: missing fields and priority > 5
result, _ = validator.validate_with_pydantic(output, TicketAnalysis)
feedback = validator.format_validation_feedback(result, TicketAnalysis)
# Should include error details
assert "ERRORS:" in feedback
assert "EXPECTED SCHEMA:" in feedback
assert "TicketAnalysis" in feedback
# Should mention missing required fields
assert "category" in feedback or "summary" in feedback
def test_feedback_mentions_fix_instruction(self):
"""Feedback should include instruction to fix errors."""
validator = OutputValidator()
result = ValidationResult(success=False, errors=["test error"])
feedback = validator.format_validation_feedback(result, SimpleOutput)
assert "fix" in feedback.lower() or "valid JSON" in feedback
def test_max_validation_retries_default(self):
"""Default max_validation_retries should be 2."""
node = NodeSpec(
id="test",
name="Test",
description="Test node",
output_model=SimpleOutput,
)
assert node.max_validation_retries == 2
def test_max_validation_retries_customizable(self):
"""max_validation_retries should be customizable."""
node = NodeSpec(
id="test",
name="Test",
description="Test node",
output_model=SimpleOutput,
max_validation_retries=5,
)
assert node.max_validation_retries == 5
def test_zero_retries_allowed(self):
"""Should allow 0 retries (immediate failure on validation error)."""
node = NodeSpec(
id="test",
name="Test",
description="Test node",
output_model=SimpleOutput,
max_validation_retries=0,
)
assert node.max_validation_retries == 0
def test_feedback_includes_all_error_types(self):
"""Feedback should include various error types."""
validator = OutputValidator()
# Create output with multiple errors
output = {
"category": "X", # too short if there was min_length
"priority": 10, # out of range (should be 1-5)
"summary": "short", # too short (min_length=10)
# missing suggested_action
}
result, _ = validator.validate_with_pydantic(output, TicketAnalysis)
feedback = validator.format_validation_feedback(result, TicketAnalysis)
# Should contain error details
assert "ERRORS:" in feedback
# Should list multiple errors
assert result.errors is not None
assert len(result.errors) >= 1
# Extended Integration Tests
class TestPydanticValidationIntegrationExtended:
"""Extended integration tests for the complete validation flow."""
def test_nodespec_with_all_validation_options(self):
"""NodeSpec should accept all validation-related options."""
node = NodeSpec(
id="full_test",
name="Full Validation Test",
description="Tests all validation options",
node_type="llm_generate",
output_keys=["category", "priority", "summary", "suggested_action"],
output_model=TicketAnalysis,
max_validation_retries=3,
)
assert node.output_model == TicketAnalysis
assert node.max_validation_retries == 3
assert len(node.output_keys) == 4
def test_validator_preserves_model_defaults(self):
"""Validated model should preserve default values."""
validator = OutputValidator()
# metadata has a default (default_factory=dict)
output = {
"query": "test",
"results": ["r1"],
"confidence": 0.5,
# metadata not provided, should use default
}
result, validated = validator.validate_with_pydantic(output, ComplexOutput)
assert result.success is True
assert validated.metadata == {} # default value
def test_validation_result_error_property(self):
"""ValidationResult.error should combine all errors."""
result = ValidationResult(
success=False,
errors=["error1", "error2", "error3"]
)
error_str = result.error
assert "error1" in error_str
assert "error2" in error_str
assert "error3" in error_str
assert "; " in error_str # errors joined with "; "