394 lines
14 KiB
Python
394 lines
14 KiB
Python
"""
|
|
Output Cleaner - Framework-level I/O validation and cleaning.
|
|
|
|
Validates node outputs match expected schemas and uses fast LLM
|
|
to clean malformed outputs before they flow to the next node.
|
|
|
|
This prevents cascading failures and dramatically improves execution success rates.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
# Module-level logger, named after this module so logging config can target it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def _heuristic_repair(text: str) -> dict | None:
|
|
"""
|
|
Attempt to repair JSON without an LLM call.
|
|
|
|
Handles common errors:
|
|
- Markdown code blocks
|
|
- Python booleans/None (True -> true)
|
|
- Single quotes instead of double quotes
|
|
"""
|
|
if not isinstance(text, str):
|
|
return None
|
|
|
|
# 1. Strip Markdown code blocks
|
|
text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.MULTILINE)
|
|
text = re.sub(r"\s*```$", "", text, flags=re.MULTILINE)
|
|
text = text.strip()
|
|
|
|
# 2. Find outermost JSON-like structure (greedy match)
|
|
match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
|
|
if match:
|
|
candidate = match.group(1)
|
|
|
|
# 3. Common fixes
|
|
# Fix Python constants
|
|
candidate = re.sub(r"\bTrue\b", "true", candidate)
|
|
candidate = re.sub(r"\bFalse\b", "false", candidate)
|
|
candidate = re.sub(r"\bNone\b", "null", candidate)
|
|
|
|
# 4. Attempt load
|
|
try:
|
|
return json.loads(candidate)
|
|
except json.JSONDecodeError:
|
|
# 5. Advanced: Try swapping single quotes if double quotes fail
|
|
# This is risky but effective for simple dicts
|
|
try:
|
|
if "'" in candidate and '"' not in candidate:
|
|
candidate_swapped = candidate.replace("'", '"')
|
|
return json.loads(candidate_swapped)
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
return None
|
|
|
|
|
|
@dataclass
class CleansingConfig:
    """
    Configuration for output cleansing.

    Attributes:
        enabled: Master switch; when False, cleaning returns outputs unchanged.
        fast_model: Fast, cheap model identifier used for LLM-based cleaning.
        max_retries: Retry budget for cleaning.
            NOTE(review): not read by OutputCleaner in this file — confirm use elsewhere.
        cache_successful_patterns: Whether to cache patterns that cleaned successfully.
            NOTE(review): not read by OutputCleaner in this file — confirm use elsewhere.
        fallback_to_raw: If cleaning fails, pass the raw output through instead of raising.
        log_cleanings: Log when cleansing happens.
    """

    enabled: bool = field(default=True)
    fast_model: str = field(default="cerebras/llama-3.3-70b")
    max_retries: int = field(default=2)
    cache_successful_patterns: bool = field(default=True)
    fallback_to_raw: bool = field(default=True)
    log_cleanings: bool = field(default=True)
|
|
|
|
|
|
@dataclass
|
|
class ValidationResult:
|
|
"""Result of output validation."""
|
|
|
|
valid: bool
|
|
errors: list[str] = field(default_factory=list)
|
|
warnings: list[str] = field(default_factory=list)
|
|
cleaned_output: dict[str, Any] | None = None
|
|
|
|
|
|
class OutputCleaner:
    """
    Framework-level output validation and cleaning.

    Uses heuristics and fast LLM to clean malformed outputs
    before they flow to the next node.

    Cleaning runs in two phases:
      1. Heuristic repair: JSON-encoded string values are parsed locally
         (no LLM call). If the repaired output then validates, it is returned.
      2. LLM repair: the (possibly heuristically improved) output is sent to
         a fast LLM together with the target schema and remaining errors.
    """

    # Type names accepted in ``input_schema[key]["type"]``, mapped to the
    # Python types a decoded JSON value would have.
    _TYPE_MAP: dict[str, Any] = {
        "string": str,
        "str": str,
        "int": int,
        "integer": int,
        "float": float,
        "number": (int, float),
        "bool": bool,
        "boolean": bool,
        "dict": dict,
        "object": dict,
        "list": list,
        "array": list,
        "null": type(None),
        "any": object,  # Matches everything
    }

    def __init__(self, config: "CleansingConfig", llm_provider=None):
        """
        Initialize the output cleaner.

        Args:
            config: Cleansing configuration
            llm_provider: Optional LLM provider. When omitted and cleansing is
                enabled, a LiteLLMProvider for ``config.fast_model`` is created
                if the CEREBRAS_API_KEY environment variable is set; otherwise
                LLM-based cleaning is unavailable (heuristics still run).
        """
        self.config = config
        self.success_cache: dict[str, Any] = {}  # Reported by get_stats(); not populated here
        self.failure_count: dict[str, int] = {}  # Failed cleanings per "source->target" edge
        self.cleansing_count = 0  # Successful LLM cleanings performed

        # Initialize LLM provider for cleaning
        if llm_provider:
            self.llm = llm_provider
        elif config.enabled:
            # Create dedicated fast LLM provider for cleaning
            try:
                import os

                from framework.llm.litellm import LiteLLMProvider

                api_key = os.environ.get("CEREBRAS_API_KEY")
                if api_key:
                    self.llm = LiteLLMProvider(
                        api_key=api_key,
                        model=config.fast_model,
                        temperature=0.0,  # Deterministic cleaning
                    )
                    logger.info(f"✓ Initialized OutputCleaner with {config.fast_model}")
                else:
                    logger.warning("⚠ CEREBRAS_API_KEY not found, output cleaning will be disabled")
                    self.llm = None
            except ImportError:
                logger.warning("⚠ LiteLLMProvider not available, output cleaning disabled")
                self.llm = None
        else:
            self.llm = None

    @staticmethod
    def _key_schema(node_spec: Any, key: str) -> dict[str, Any]:
        """Return the declared schema dict for ``key`` on ``node_spec``, or ``{}``."""
        input_schema = getattr(node_spec, "input_schema", None)
        if not input_schema:
            return {}
        schema = input_schema.get(key)
        # Tolerate entries that are None or not dicts instead of crashing on .get().
        return schema if isinstance(schema, dict) else {}

    @staticmethod
    def _accepts_string(expected_type: Any) -> bool:
        """True if a declared type (a name or a list of names) explicitly allows strings."""
        candidates = expected_type if isinstance(expected_type, (list, tuple)) else [expected_type]
        return any(isinstance(t, str) and t.lower() in ("string", "str") for t in candidates)

    def validate_output(
        self,
        output: dict[str, Any],
        source_node_id: str,
        target_node_spec: Any,  # NodeSpec
    ) -> "ValidationResult":
        """
        Validate output matches target node's expected input schema.

        For every key in ``target_node_spec.input_keys`` this checks:
          1. the key is present in ``output``;
          2. string values that decode to a JSON object containing the same key
             (the "JSON parsing trap") -> error; other long JSON strings or very
             long plain strings -> warning;
          3. if ``input_schema[key]["type"]`` is declared, the value's type
             matches it (see ``_type_matches``).

        Args:
            output: Output dict produced by the source node.
            source_node_id: ID of the source node (used for logging).
            target_node_spec: Target node spec; ``input_keys`` and ``id`` are
                read, ``input_schema`` is optional.

        Returns:
            ValidationResult with errors and warnings. Only errors make it
            invalid; ``cleaned_output`` is not set here.
        """
        errors: list[str] = []
        warnings: list[str] = []

        for key in target_node_spec.input_keys:
            # Check 1: Required input keys present
            if key not in output:
                errors.append(f"Missing required key: '{key}'")
                continue

            value = output[key]

            # Check 2: Detect if value is JSON string (the JSON parsing trap!)
            if isinstance(value, str):
                try:
                    parsed = json.loads(value)
                    if isinstance(parsed, dict):
                        if key in parsed:
                            # Key exists in parsed JSON - classic parsing failure!
                            errors.append(
                                f"Key '{key}' contains JSON string with nested '{key}' field - "
                                f"likely parsing failure from LLM node"
                            )
                        elif len(value) > 100:
                            # Large JSON string, but doesn't contain the key
                            warnings.append(
                                f"Key '{key}' contains JSON string ({len(value)} chars)"
                            )
                except json.JSONDecodeError:
                    # Not JSON, check if suspiciously large
                    if len(value) > 500:
                        warnings.append(
                            f"Key '{key}' contains large string ({len(value)} chars), "
                            f"possibly entire LLM response"
                        )

            # Check 3: Type validation (if schema provided)
            expected_type = self._key_schema(target_node_spec, key).get("type")
            if expected_type and not self._type_matches(value, expected_type):
                actual_type = type(value).__name__
                errors.append(
                    f"Key '{key}': expected type '{expected_type}', got '{actual_type}'"
                )

        # Warnings don't make validation fail, but errors do
        is_valid = not errors

        if not is_valid and self.config.log_cleanings:
            logger.warning(
                f"⚠ Output validation failed for {source_node_id} → {target_node_spec.id}: "
                f"{len(errors)} error(s), {len(warnings)} warning(s)"
            )

        return ValidationResult(
            valid=is_valid,
            errors=errors,
            warnings=warnings,
        )

    def clean_output(
        self,
        output: dict[str, Any],
        source_node_id: str,
        target_node_spec: Any,  # NodeSpec
        validation_errors: list[str],
    ) -> dict[str, Any]:
        """
        Use heuristics and fast LLM to clean malformed output.

        Args:
            output: Raw output from source node
            source_node_id: ID of source node
            target_node_spec: Target node spec (for schema)
            validation_errors: Errors from validation

        Returns:
            Cleaned output matching target schema. If cleaning is disabled,
            ``output`` is returned unchanged. If no LLM is available, or LLM
            cleaning fails and ``config.fallback_to_raw`` is set, the best
            output available is returned: the heuristically repaired copy if
            Phase 1 changed anything, else ``output``.

        Raises:
            json.JSONDecodeError, ValueError, or the provider's exception when
            LLM cleaning fails and ``config.fallback_to_raw`` is False.
        """
        if not self.config.enabled:
            logger.warning("⚠ Output cleansing disabled in config")
            return output

        # --- PHASE 1: Fast Heuristic Repair (Avoids LLM call) ---
        fixed_output, heuristic_fixed = self._heuristic_pass(output, target_node_spec)
        best_effort = output
        if heuristic_fixed:
            logger.info("⚡ Heuristic repair applied (nested JSON expansion)")
            revalidation = self.validate_output(fixed_output, source_node_id, target_node_spec)
            if revalidation.valid:
                return fixed_output
            # Heuristics helped but were not enough: hand the improved output
            # and its remaining errors to the LLM, and prefer it as a fallback.
            best_effort = fixed_output
            validation_errors = revalidation.errors

        # --- PHASE 2: LLM-based Repair ---
        if not self.llm:
            logger.warning("⚠ No LLM provider available for cleansing")
            return best_effort

        return self._llm_clean(best_effort, source_node_id, target_node_spec, validation_errors)

    def _heuristic_pass(
        self, output: dict[str, Any], target_node_spec: Any
    ) -> tuple[dict[str, Any], bool]:
        """Expand JSON-encoded string values locally; return (repaired copy, changed?)."""
        fixed_output = output.copy()
        changed = False

        for key, value in output.items():
            if not isinstance(value, str):
                continue

            expected_type = self._key_schema(target_node_spec, key).get("type")
            if expected_type and self._accepts_string(expected_type):
                # The target explicitly wants text here; parsing it would break the type.
                continue

            repaired = _heuristic_repair(value)
            if repaired is None:
                continue

            if (
                isinstance(repaired, dict)
                and key in repaired
                and not (expected_type and self._type_matches(repaired, expected_type))
            ):
                # JSON parsing trap: {"k": '{"k": ...}'} - the source node serialized
                # its whole response into this field, so promote the nested value.
                fixed_output[key] = repaired[key]
            else:
                fixed_output[key] = repaired
            changed = True

        return fixed_output, changed

    def _record_failure(self, edge: str) -> None:
        """Count a failed cleaning attempt for ``edge`` (surfaced via get_stats())."""
        self.failure_count[edge] = self.failure_count.get(edge, 0) + 1

    def _llm_clean(
        self,
        output: dict[str, Any],
        source_node_id: str,
        target_node_spec: Any,
        validation_errors: list[str],
    ) -> dict[str, Any]:
        """Ask the fast LLM to rewrite ``output`` to the target schema; fall back or re-raise on failure."""
        edge = f"{source_node_id}->{target_node_spec.id}"

        # Build schema description for target node
        schema_desc = self._build_schema_description(target_node_spec)

        # Create cleansing prompt
        prompt = f"""Clean this malformed agent output to match the expected schema.

VALIDATION ERRORS:
{chr(10).join(f"- {e}" for e in validation_errors)}

EXPECTED SCHEMA for node '{target_node_spec.id}':
{schema_desc}

RAW OUTPUT from node '{source_node_id}':
{json.dumps(output, indent=2, default=str)}

INSTRUCTIONS:
1. Extract values that match the expected schema keys
2. If a value is a JSON string, parse it and extract the correct field
3. Convert types to match the schema (string, dict, list, number, boolean)
4. Remove extra fields not in the expected schema
5. Ensure all required keys are present

Return ONLY valid JSON matching the expected schema. No explanations, no markdown."""

        try:
            if self.config.log_cleanings:
                logger.info(
                    f"🧹 Cleaning output from '{source_node_id}' using {self.config.fast_model}"
                )

            response = self.llm.complete(
                messages=[{"role": "user", "content": prompt}],
                system=(
                    "You clean malformed agent outputs. Return only valid JSON matching the schema."
                ),
                max_tokens=2048,  # Sufficient for cleaning most outputs
            )

            cleaned_text = response.content.strip()

            # Apply heuristic repair to the LLM's output too (strips fences, fixes literals).
            cleaned = _heuristic_repair(cleaned_text)
            if cleaned is None:
                # Compare against None, not falsiness: a repaired {} is a valid result.
                # This plain load raises JSONDecodeError when the text is unparseable.
                cleaned = json.loads(cleaned_text)

        except json.JSONDecodeError as e:
            logger.error(f"✗ Failed to parse cleaned JSON: {e}")
            self._record_failure(edge)
            if self.config.fallback_to_raw:
                logger.info("↩ Falling back to raw output")
                return output
            raise

        except Exception as e:
            logger.error(f"✗ Output cleaning failed: {e}")
            self._record_failure(edge)
            if self.config.fallback_to_raw:
                logger.info("↩ Falling back to raw output")
                return output
            raise

        if isinstance(cleaned, dict):
            self.cleansing_count += 1
            if self.config.log_cleanings:
                logger.info(
                    f"✓ Output cleaned successfully (total cleanings: {self.cleansing_count})"
                )
            return cleaned

        logger.warning(f"⚠ Cleaned output is not a dict: {type(cleaned)}")
        self._record_failure(edge)
        if self.config.fallback_to_raw:
            return output
        raise ValueError(f"Cleaning produced {type(cleaned)}, expected dict")

    def _build_schema_description(self, node_spec: Any) -> str:
        """
        Build human-readable schema description from NodeSpec.

        Emits one pseudo-JSON line per input key (type, optional description,
        required marker) wrapped in braces; used inside the LLM cleaning prompt.
        """
        has_schema = bool(getattr(node_spec, "input_schema", None))
        lines = ["{"]

        for key in node_spec.input_keys:
            if not has_schema:
                # No schema, just show the key
                lines.append(f' "{key}": any // (required)')
                continue

            # Get type hint and description if available; tolerate None/non-dict entries
            schema = self._key_schema(node_spec, key)
            type_hint = schema.get("type", "any")
            description = schema.get("description", "")
            required = schema.get("required", True)

            line = f' "{key}": {type_hint}'
            if description:
                line += f" // {description}"
            if required:
                line += " (required)"
            lines.append(line + ",")

        lines.append("}")
        return "\n".join(lines)

    def _type_matches(self, value: Any, expected_type: Any) -> bool:
        """
        Check if value matches expected type.

        ``expected_type`` is a type name (case-insensitive, see ``_TYPE_MAP``)
        or a list of names, as JSON Schema allows (e.g. ``["string", "null"]``);
        a list matches if any member matches. Unknown names match everything.
        """
        if isinstance(expected_type, (list, tuple)):
            return any(self._type_matches(value, t) for t in expected_type)
        if not isinstance(expected_type, str):
            # Unrecognized type specification, allow it
            return True

        expected_class = self._TYPE_MAP.get(expected_type.lower())
        if expected_class is None:
            # Unknown type, allow it
            return True

        if isinstance(value, bool):
            # bool subclasses int, but JSON Schema treats booleans as distinct
            # from integers/numbers.
            return expected_class is bool or expected_class is object

        return isinstance(value, expected_class)

    def get_stats(self) -> dict[str, Any]:
        """Get cleansing statistics."""
        return {
            "total_cleanings": self.cleansing_count,
            "failure_count": dict(self.failure_count),
            "cache_size": len(self.success_cache),
        }
|