Files
hive/core/framework/graph/output_cleaner.py
T
2026-01-27 10:45:49 -08:00

394 lines
14 KiB
Python

"""
Output Cleaner - Framework-level I/O validation and cleaning.
Validates node outputs match expected schemas and uses fast LLM
to clean malformed outputs before they flow to the next node.
This prevents cascading failures and dramatically improves execution success rates.
"""
import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
def _heuristic_repair(text: str) -> dict | None:
"""
Attempt to repair JSON without an LLM call.
Handles common errors:
- Markdown code blocks
- Python booleans/None (True -> true)
- Single quotes instead of double quotes
"""
if not isinstance(text, str):
return None
# 1. Strip Markdown code blocks
text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.MULTILINE)
text = re.sub(r"\s*```$", "", text, flags=re.MULTILINE)
text = text.strip()
# 2. Find outermost JSON-like structure (greedy match)
match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
if match:
candidate = match.group(1)
# 3. Common fixes
# Fix Python constants
candidate = re.sub(r"\bTrue\b", "true", candidate)
candidate = re.sub(r"\bFalse\b", "false", candidate)
candidate = re.sub(r"\bNone\b", "null", candidate)
# 4. Attempt load
try:
return json.loads(candidate)
except json.JSONDecodeError:
# 5. Advanced: Try swapping single quotes if double quotes fail
# This is risky but effective for simple dicts
try:
if "'" in candidate and '"' not in candidate:
candidate_swapped = candidate.replace("'", '"')
return json.loads(candidate_swapped)
except json.JSONDecodeError:
pass
return None
@dataclass
class CleansingConfig:
    """Settings that control framework-level output cleansing."""

    enabled: bool = True
    # Fast, inexpensive model dedicated to repairing malformed outputs.
    fast_model: str = "cerebras/llama-3.3-70b"
    max_retries: int = 2
    cache_successful_patterns: bool = True
    # When cleaning fails, pass the raw output through instead of raising.
    fallback_to_raw: bool = True
    # Emit a log line whenever a cleansing is performed.
    log_cleanings: bool = True
@dataclass
class ValidationResult:
"""Result of output validation."""
valid: bool
errors: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
cleaned_output: dict[str, Any] | None = None
class OutputCleaner:
    """
    Framework-level output validation and cleaning.

    Uses heuristics and fast LLM to clean malformed outputs
    before they flow to the next node.
    """

    def __init__(self, config: CleansingConfig, llm_provider=None):
        """
        Initialize the output cleaner.

        Args:
            config: Cleansing configuration.
            llm_provider: Optional LLM provider. When omitted and cleansing is
                enabled, a dedicated fast provider is created (requires
                CEREBRAS_API_KEY and the LiteLLM integration to be available).
        """
        self.config = config
        self.success_cache: dict[str, Any] = {}  # Cache successful patterns
        self.failure_count: dict[str, int] = {}  # Failed cleanings per source node
        self.cleansing_count = 0  # Track total cleanings performed

        # Initialize LLM provider for cleaning
        if llm_provider:
            self.llm = llm_provider
        elif config.enabled:
            # Create dedicated fast LLM provider for cleaning
            try:
                import os

                from framework.llm.litellm import LiteLLMProvider

                api_key = os.environ.get("CEREBRAS_API_KEY")
                if api_key:
                    self.llm = LiteLLMProvider(
                        api_key=api_key,
                        model=config.fast_model,
                        temperature=0.0,  # Deterministic cleaning
                    )
                    logger.info(f"✓ Initialized OutputCleaner with {config.fast_model}")
                else:
                    logger.warning("⚠ CEREBRAS_API_KEY not found, output cleaning will be disabled")
                    self.llm = None
            except ImportError:
                logger.warning("⚠ LiteLLMProvider not available, output cleaning disabled")
                self.llm = None
        else:
            self.llm = None

    def validate_output(
        self,
        output: dict[str, Any],
        source_node_id: str,
        target_node_spec: Any,  # NodeSpec
    ) -> ValidationResult:
        """
        Validate output matches target node's expected input schema.

        Args:
            output: Raw output produced by the source node.
            source_node_id: ID of the node that produced the output.
            target_node_spec: Spec of the node that will consume the output
                (must expose ``input_keys``; may expose ``input_schema``).

        Returns:
            ValidationResult with errors and optionally cleaned output.
        """
        errors = []
        warnings = []

        # Check 1: Required input keys present
        for key in target_node_spec.input_keys:
            if key not in output:
                errors.append(f"Missing required key: '{key}'")
                continue

            value = output[key]

            # Check 2: Detect if value is JSON string (the JSON parsing trap!)
            if isinstance(value, str):
                # Try parsing as JSON to detect the trap
                try:
                    parsed = json.loads(value)
                    if isinstance(parsed, dict):
                        if key in parsed:
                            # Key exists in parsed JSON - classic parsing failure!
                            errors.append(
                                f"Key '{key}' contains JSON string with nested '{key}' field - "
                                f"likely parsing failure from LLM node"
                            )
                        elif len(value) > 100:
                            # Large JSON string, but doesn't contain the key
                            warnings.append(
                                f"Key '{key}' contains JSON string ({len(value)} chars)"
                            )
                except json.JSONDecodeError:
                    # Not JSON, check if suspiciously large
                    if len(value) > 500:
                        warnings.append(
                            f"Key '{key}' contains large string ({len(value)} chars), "
                            f"possibly entire LLM response"
                        )

            # Check 3: Type validation (if schema provided)
            if hasattr(target_node_spec, "input_schema") and target_node_spec.input_schema:
                expected_schema = target_node_spec.input_schema.get(key)
                if expected_schema:
                    expected_type = expected_schema.get("type")
                    if expected_type and not self._type_matches(value, expected_type):
                        actual_type = type(value).__name__
                        errors.append(
                            f"Key '{key}': expected type '{expected_type}', got '{actual_type}'"
                        )

        # Warnings don't make validation fail, but errors do
        is_valid = len(errors) == 0
        if not is_valid and self.config.log_cleanings:
            # FIX: the two node ids previously ran together with no separator,
            # making the edge unreadable in logs.
            logger.warning(
                f"⚠ Output validation failed for {source_node_id} -> {target_node_spec.id}: "
                f"{len(errors)} error(s), {len(warnings)} warning(s)"
            )

        return ValidationResult(
            valid=is_valid,
            errors=errors,
            warnings=warnings,
        )

    def clean_output(
        self,
        output: dict[str, Any],
        source_node_id: str,
        target_node_spec: Any,  # NodeSpec
        validation_errors: list[str],
    ) -> dict[str, Any]:
        """
        Use heuristics and fast LLM to clean malformed output.

        Args:
            output: Raw output from source node
            source_node_id: ID of source node
            target_node_spec: Target node spec (for schema)
            validation_errors: Errors from validation

        Returns:
            Cleaned output matching target schema (or the raw output when
            cleansing is disabled / fails and ``fallback_to_raw`` is set).
        """
        if not self.config.enabled:
            logger.warning("⚠ Output cleansing disabled in config")
            return output

        # --- PHASE 1: Fast Heuristic Repair (Avoids LLM call) ---
        # Often the output is just a string containing JSON, or has minor
        # syntax errors. This handles the "JSON Parsing Trap" where an LLM
        # returns {"key": "{\"nested\": ...}"} - a JSON string that should
        # have been a parsed structure.
        heuristic_fixed = False
        fixed_output = output.copy()
        for key, value in output.items():
            if isinstance(value, str):
                repaired = _heuristic_repair(value)
                # FIX: compare against None rather than truthiness, so a value
                # that legitimately repairs to an empty {} or [] is promoted too.
                if repaired is not None and isinstance(repaired, (dict, list)):
                    fixed_output[key] = repaired
                    heuristic_fixed = True

        if heuristic_fixed:
            logger.info("⚡ Heuristic repair applied (nested JSON expansion)")
            return fixed_output

        # --- PHASE 2: LLM-based Repair ---
        if not self.llm:
            logger.warning("⚠ No LLM provider available for cleansing")
            return output

        # Build schema description for target node
        schema_desc = self._build_schema_description(target_node_spec)

        # Create cleansing prompt
        prompt = f"""Clean this malformed agent output to match the expected schema.
VALIDATION ERRORS:
{chr(10).join(f"- {e}" for e in validation_errors)}
EXPECTED SCHEMA for node '{target_node_spec.id}':
{schema_desc}
RAW OUTPUT from node '{source_node_id}':
{json.dumps(output, indent=2, default=str)}
INSTRUCTIONS:
1. Extract values that match the expected schema keys
2. If a value is a JSON string, parse it and extract the correct field
3. Convert types to match the schema (string, dict, list, number, boolean)
4. Remove extra fields not in the expected schema
5. Ensure all required keys are present
Return ONLY valid JSON matching the expected schema. No explanations, no markdown."""

        try:
            if self.config.log_cleanings:
                logger.info(
                    f"🧹 Cleaning output from '{source_node_id}' using {self.config.fast_model}"
                )

            response = self.llm.complete(
                messages=[{"role": "user", "content": prompt}],
                system=(
                    "You clean malformed agent outputs. Return only valid JSON matching the schema."
                ),
                max_tokens=2048,  # Sufficient for cleaning most outputs
            )

            # Parse cleaned output; apply heuristic repair to the LLM's own
            # output too (it may wrap the JSON in markdown, for instance).
            cleaned_text = response.content.strip()
            cleaned = _heuristic_repair(cleaned_text)
            if not cleaned:
                # Fallback to a standard load if the heuristic returns None
                cleaned = json.loads(cleaned_text)

            if isinstance(cleaned, dict):
                self.cleansing_count += 1
                if self.config.log_cleanings:
                    logger.info(
                        f"✓ Output cleaned successfully (total cleanings: {self.cleansing_count})"
                    )
                return cleaned
            else:
                logger.warning(f"⚠ Cleaned output is not a dict: {type(cleaned)}")
                self._record_failure(source_node_id)
                if self.config.fallback_to_raw:
                    return output
                else:
                    raise ValueError(f"Cleaning produced {type(cleaned)}, expected dict")

        except json.JSONDecodeError as e:
            logger.error(f"✗ Failed to parse cleaned JSON: {e}")
            self._record_failure(source_node_id)
            if self.config.fallback_to_raw:
                logger.info("↩ Falling back to raw output")
                return output
            else:
                raise
        except Exception as e:
            logger.error(f"✗ Output cleaning failed: {e}")
            self._record_failure(source_node_id)
            if self.config.fallback_to_raw:
                logger.info("↩ Falling back to raw output")
                return output
            else:
                raise

    def _record_failure(self, source_node_id: str) -> None:
        """Increment the per-source-node failed-cleaning counter.

        FIX: ``failure_count`` was declared and reported by ``get_stats`` but
        never actually incremented anywhere.
        """
        self.failure_count[source_node_id] = self.failure_count.get(source_node_id, 0) + 1

    def _build_schema_description(self, node_spec: Any) -> str:
        """Build a human-readable, JSON-like schema description from a NodeSpec."""
        lines = ["{"]
        for key in node_spec.input_keys:
            # Get type hint and description if available
            if hasattr(node_spec, "input_schema") and node_spec.input_schema:
                schema = node_spec.input_schema.get(key, {})
                type_hint = schema.get("type", "any")
                description = schema.get("description", "")
                required = schema.get("required", True)
                line = f' "{key}": {type_hint}'
                if description:
                    line += f" // {description}"
                if required:
                    line += " (required)"
                lines.append(line + ",")
            else:
                # No schema, just show the key
                lines.append(f' "{key}": any // (required)')
        lines.append("}")
        return "\n".join(lines)

    def _type_matches(self, value: Any, expected_type: str) -> bool:
        """Check if ``value`` matches the named ``expected_type``.

        Unknown type names are permissive (return True).
        """
        type_map = {
            "string": str,
            "str": str,
            "int": int,
            "integer": int,
            "float": float,
            "number": (int, float),
            "bool": bool,
            "boolean": bool,
            "dict": dict,
            "object": dict,
            "list": list,
            "array": list,
            "any": object,  # Matches everything
        }
        normalized = expected_type.lower()
        expected_class = type_map.get(normalized)
        if expected_class:
            # FIX: bool is a subclass of int in Python, so True/False would
            # otherwise satisfy int/float/number expectations.
            if isinstance(value, bool) and normalized in ("int", "integer", "float", "number"):
                return False
            return isinstance(value, expected_class)
        # Unknown type, allow it
        return True

    def get_stats(self) -> dict[str, Any]:
        """Get cleansing statistics (totals, per-node failures, cache size)."""
        return {
            "total_cleanings": self.cleansing_count,
            "failure_count": dict(self.failure_count),
            "cache_size": len(self.success_cache),
        }