"""
|
|
Node Protocol - The building block of agent graphs.
|
|
|
|
A Node is a unit of work that:
|
|
1. Receives context (goal, shared memory, input)
|
|
2. Makes decisions (using LLM, tools, or logic)
|
|
3. Produces results (output, state changes)
|
|
4. Records everything to the Runtime
|
|
|
|
Nodes are composable and reusable. The same node can appear
|
|
in different graphs for different goals.
|
|
|
|
Protocol:
|
|
Every node must implement the NodeProtocol interface.
|
|
The framework provides NodeContext with everything the node needs.
|
|
"""
|
|
|
|
import asyncio
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any

from pydantic import BaseModel, Field

from framework.llm.provider import LLMProvider, Tool
from framework.runtime.core import Runtime

logger = logging.getLogger(__name__)

def _fix_unescaped_newlines_in_json(json_str: str) -> str:
    """Fix unescaped newlines inside JSON string values.

    LLMs sometimes output actual newlines inside JSON strings instead of \\n.
    This function fixes that by properly escaping newlines within string values.
    """
    result = []
    in_string = False
    escape_next = False
    i = 0

    while i < len(json_str):
        char = json_str[i]

        if escape_next:
            result.append(char)
            escape_next = False
            i += 1
            continue

        if char == "\\" and in_string:
            escape_next = True
            result.append(char)
            i += 1
            continue

        if char == '"' and not escape_next:
            in_string = not in_string
            result.append(char)
            i += 1
            continue

        # Fix unescaped newlines inside strings
        if in_string and char == "\n":
            result.append("\\n")
            i += 1
            continue

        # Fix unescaped carriage returns inside strings
        if in_string and char == "\r":
            result.append("\\r")
            i += 1
            continue

        # Fix unescaped tabs inside strings
        if in_string and char == "\t":
            result.append("\\t")
            i += 1
            continue

        result.append(char)
        i += 1

    return "".join(result)

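# Example (illustrative): a literal newline inside a JSON string value makes
# json.loads fail; after fixing, the value round-trips correctly.
#
#     import json
#     raw = '{"text": "line one\nline two"}'  # the \n is a real newline here
#     json.loads(_fix_unescaped_newlines_in_json(raw))["text"]
#     # -> 'line one\nline two'
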
def find_json_object(text: str) -> str | None:
    """Find the first valid JSON object in text using balanced brace matching.

    This handles nested objects correctly, unlike a simple regex such as r'\\{[^{}]*\\}'.
    """
    start = text.find("{")
    if start == -1:
        return None

    depth = 0
    in_string = False
    escape_next = False

    for i, char in enumerate(text[start:], start):
        if escape_next:
            escape_next = False
            continue

        if char == "\\" and in_string:
            escape_next = True
            continue

        if char == '"' and not escape_next:
            in_string = not in_string
            continue

        if in_string:
            continue

        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return text[start : i + 1]

    return None

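# Example (illustrative): balanced-brace matching returns the outermost object
# and ignores braces that appear inside string values.
#
#     find_json_object('prefix {"a": {"b": 1}, "s": "{not json}"} suffix')
#     # -> '{"a": {"b": 1}, "s": "{not json}"}'
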
class NodeSpec(BaseModel):
    """
    Specification for a node in the graph.

    This is the declarative definition of a node - what it does,
    what it needs, and what it produces. The actual implementation
    is separate (NodeProtocol).

    Example:
        NodeSpec(
            id="calculator",
            name="Calculator Node",
            description="Performs mathematical calculations",
            node_type="llm_tool_use",
            input_keys=["expression"],
            output_keys=["result"],
            tools=["calculate", "math_function"],
            system_prompt="You are a calculator..."
        )
    """

    id: str
    name: str
    description: str

    # Node behavior type
    node_type: str = Field(
        default="llm_tool_use",
        description="Type: 'llm_tool_use', 'llm_generate', 'function', 'router', 'human_input'",
    )

    # Data flow
    input_keys: list[str] = Field(
        default_factory=list, description="Keys this node reads from shared memory or input"
    )
    output_keys: list[str] = Field(
        default_factory=list, description="Keys this node writes to shared memory or output"
    )

    # Optional schemas for validation and cleansing
    input_schema: dict[str, dict] = Field(
        default_factory=dict,
        description=(
            "Optional schema for input validation. "
            "Format: {key: {type: 'string', required: True, description: '...'}}"
        ),
    )
    output_schema: dict[str, dict] = Field(
        default_factory=dict,
        description=(
            "Optional schema for output validation. "
            "Format: {key: {type: 'dict', required: True, description: '...'}}"
        ),
    )

    # For LLM nodes
    system_prompt: str | None = Field(default=None, description="System prompt for LLM nodes")
    tools: list[str] = Field(default_factory=list, description="Tool names this node can use")
    model: str | None = Field(
        default=None, description="Specific model to use (defaults to graph default)"
    )

    # For function nodes
    function: str | None = Field(
        default=None, description="Function name or path for function nodes"
    )

    # For router nodes
    routes: dict[str, str] = Field(
        default_factory=dict, description="Condition -> target_node_id mapping for routers"
    )

    # Retry behavior
    max_retries: int = Field(default=3)
    retry_on: list[str] = Field(default_factory=list, description="Error types to retry on")

    # Pydantic model for output validation
    output_model: type[BaseModel] | None = Field(
        default=None,
        description=(
            "Optional Pydantic model class for validating and parsing LLM output. "
            "When set, the LLM response will be validated against this model."
        ),
    )
    max_validation_retries: int = Field(
        default=2,
        description="Maximum retries when Pydantic validation fails (with feedback to LLM)",
    )

    model_config = {"extra": "allow", "arbitrary_types_allowed": True}

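# A router variant of the spec, for illustration (the node ids "triage",
# "answer", and "escalate" are hypothetical):
#
#     router_spec = NodeSpec(
#         id="triage",
#         name="Triage Router",
#         description="Routes questions to the right handler",
#         node_type="router",
#         routes={"simple": "answer", "complex": "escalate", "default": "answer"},
#     )
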
class MemoryWriteError(Exception):
    """Raised when an invalid value is written to memory."""

    pass

@dataclass
class SharedMemory:
    """
    Shared state between nodes in a graph execution.

    Nodes read and write to shared memory using typed keys.
    The memory is scoped to a single run.

    For parallel execution, use write_async(), which provides per-key locking
    to prevent race conditions when multiple nodes write concurrently.
    """

    _data: dict[str, Any] = field(default_factory=dict)
    _allowed_read: set[str] = field(default_factory=set)
    _allowed_write: set[str] = field(default_factory=set)
    # Locks for thread-safe parallel execution
    _lock: asyncio.Lock | None = field(default=None, repr=False)
    _key_locks: dict[str, asyncio.Lock] = field(default_factory=dict, repr=False)

    def __post_init__(self) -> None:
        """Initialize the main lock if not provided."""
        if self._lock is None:
            self._lock = asyncio.Lock()

    def read(self, key: str) -> Any:
        """Read a value from shared memory."""
        if self._allowed_read and key not in self._allowed_read:
            raise PermissionError(f"Node not allowed to read key: {key}")
        return self._data.get(key)

    def _check_write_allowed(self, key: str) -> None:
        """Raise PermissionError if this view may not write the key."""
        if self._allowed_write and key not in self._allowed_write:
            raise PermissionError(f"Node not allowed to write key: {key}")

    def _validate_value(self, key: str, value: Any) -> None:
        """Reject long string values that look like hallucinated code."""
        if isinstance(value, str) and len(value) > 5000 and self._contains_code_indicators(value):
            logger.warning(
                f"⚠ Suspicious write to key '{key}': appears to be code "
                f"({len(value)} chars). Consider using validate=False if intended."
            )
            raise MemoryWriteError(
                f"Rejected suspicious content for key '{key}': "
                f"appears to be hallucinated code ({len(value)} chars). "
                "If this is intentional, use validate=False."
            )

    def write(self, key: str, value: Any, validate: bool = True) -> None:
        """
        Write a value to shared memory.

        Args:
            key: The memory key to write to
            value: The value to write
            validate: If True, check for suspicious content (default True)

        Raises:
            PermissionError: If node doesn't have write permission
            MemoryWriteError: If value appears to be hallucinated content
        """
        self._check_write_allowed(key)
        if validate:
            self._validate_value(key, value)
        self._data[key] = value

    async def write_async(self, key: str, value: Any, validate: bool = True) -> None:
        """
        Thread-safe async write with per-key locking.

        Use this method when multiple nodes may write concurrently during
        parallel execution. Each key has its own lock to minimize contention.

        Args:
            key: The memory key to write to
            value: The value to write
            validate: If True, check for suspicious content (default True)

        Raises:
            PermissionError: If node doesn't have write permission
            MemoryWriteError: If value appears to be hallucinated content
        """
        # Check permissions first (no lock needed)
        self._check_write_allowed(key)

        # Ensure key has a lock (double-checked locking pattern)
        if key not in self._key_locks:
            async with self._lock:
                if key not in self._key_locks:
                    self._key_locks[key] = asyncio.Lock()

        # Acquire the per-key lock and write
        async with self._key_locks[key]:
            if validate:
                self._validate_value(key, value)
            self._data[key] = value

    def _contains_code_indicators(self, value: str) -> bool:
        """
        Check for code patterns in a string using sampling for efficiency.

        For strings under 10KB, checks the entire content.
        For longer strings, samples at strategic positions to balance
        performance with detection accuracy.

        Args:
            value: The string to check for code indicators

        Returns:
            True if code indicators are found, False otherwise
        """
        code_indicators = [
            # Python
            "```python",
            "def ",
            "class ",
            "import ",
            "async def ",
            "from ",
            # JavaScript/TypeScript
            "function ",
            "const ",
            "let ",
            "=> {",
            "require(",
            "export ",
            # SQL
            "SELECT ",
            "INSERT ",
            "UPDATE ",
            "DELETE ",
            "DROP ",
            # HTML/Script injection
            "<script",
            "<?php",
            "<%",
        ]

        # For strings under 10KB, check the entire content
        if len(value) < 10000:
            return any(indicator in value for indicator in code_indicators)

        # For longer strings, sample at strategic positions
        sample_positions = [
            0,  # Start
            len(value) // 4,  # 25%
            len(value) // 2,  # 50%
            3 * len(value) // 4,  # 75%
            max(0, len(value) - 2000),  # Near end
        ]

        for pos in sample_positions:
            chunk = value[pos : pos + 2000]
            if any(indicator in chunk for indicator in code_indicators):
                return True

        return False

    def read_all(self) -> dict[str, Any]:
        """Read all accessible data."""
        if self._allowed_read:
            return {k: v for k, v in self._data.items() if k in self._allowed_read}
        return dict(self._data)

    def with_permissions(
        self,
        read_keys: list[str],
        write_keys: list[str],
    ) -> "SharedMemory":
        """Create a view with restricted permissions for a specific node.

        The scoped view shares the same underlying data and locks,
        enabling thread-safe parallel execution across scoped views.
        """
        return SharedMemory(
            _data=self._data,
            _allowed_read=set(read_keys) if read_keys else set(),
            _allowed_write=set(write_keys) if write_keys else set(),
            _lock=self._lock,  # Share lock for thread safety
            _key_locks=self._key_locks,  # Share key locks
        )

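# Example (illustrative): a scoped view enforces per-node permissions while
# sharing the same backing dict and locks as the parent memory.
#
#     memory = SharedMemory()
#     memory.write("query", "What is 2 + 2?")
#     scoped = memory.with_permissions(read_keys=["query"], write_keys=["answer"])
#     scoped.read("query")           # -> "What is 2 + 2?"
#     scoped.write("answer", "4")    # allowed
#     scoped.write("query", "oops")  # raises PermissionError
#     # Inside parallel branches, prefer: await scoped.write_async("answer", "4")
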
@dataclass
class NodeContext:
    """
    Everything a node needs to execute.

    This is passed to every node and provides:
    - Access to the runtime (for decision logging)
    - Access to shared memory (for state)
    - Access to LLM (for generation)
    - Access to tools (for actions)
    - The goal context (for guidance)
    """

    # Core runtime
    runtime: Runtime

    # Node identity
    node_id: str
    node_spec: NodeSpec

    # State
    memory: SharedMemory
    input_data: dict[str, Any] = field(default_factory=dict)

    # LLM access (if applicable)
    llm: LLMProvider | None = None
    available_tools: list[Tool] = field(default_factory=list)

    # Goal context
    goal_context: str = ""
    goal: Any = None  # Goal object for LLM-powered routers

    # LLM configuration
    max_tokens: int = 4096  # Maximum tokens for LLM responses

    # Execution metadata
    attempt: int = 1
    max_attempts: int = 3

@dataclass
class NodeResult:
    """
    The output of a node execution.

    Contains:
    - Success/failure status
    - Output data
    - State changes made
    - Route decision (for routers)
    """

    success: bool
    output: dict[str, Any] = field(default_factory=dict)
    error: str | None = None

    # For routing decisions
    next_node: str | None = None
    route_reason: str | None = None

    # Metadata
    tokens_used: int = 0
    latency_ms: int = 0

    # Pydantic validation errors (if any)
    validation_errors: list[str] = field(default_factory=list)

    def to_summary(self, node_spec: Any = None) -> str:
        """
        Generate a human-readable summary of this node's execution and output.

        This is like toString() - it describes what the node produced in its
        current state. Uses Haiku to intelligently summarize complex outputs,
        falling back to a simple key-value listing when no API key is available.
        """
        if not self.success:
            return f"❌ Failed: {self.error}"

        if not self.output:
            return "✓ Completed (no output)"

        import os

        api_key = os.environ.get("ANTHROPIC_API_KEY")

        if not api_key:
            # Fallback: simple key-value listing
            parts = [f"✓ Completed with {len(self.output)} outputs:"]
            for key, value in list(self.output.items())[:5]:  # Limit to 5 keys
                value_str = str(value)[:100]
                if len(str(value)) > 100:
                    value_str += "..."
                parts.append(f"  • {key}: {value_str}")
            return "\n".join(parts)

        # Use Haiku to generate an intelligent summary
        try:
            import json

            import anthropic

            node_context = ""
            if node_spec:
                node_context = f"\nNode: {node_spec.name}\nPurpose: {node_spec.description}"

            output_json = json.dumps(self.output, indent=2, default=str)[:2000]
            prompt = (
                f"Generate a 1-2 sentence human-readable summary of "
                f"what this node produced.{node_context}\n\n"
                f"Node output:\n{output_json}\n\n"
                "Provide a concise, clear summary that a human can quickly "
                "understand. Focus on the key information produced."
            )

            client = anthropic.Anthropic(api_key=api_key)
            message = client.messages.create(
                model="claude-3-5-haiku-20241022",
                max_tokens=200,
                messages=[{"role": "user", "content": prompt}],
            )

            summary = message.content[0].text.strip()
            return f"✓ {summary}"

        except Exception:
            # Fallback on error
            parts = [f"✓ Completed with {len(self.output)} outputs:"]
            for key, value in list(self.output.items())[:3]:
                value_str = str(value)[:80]
                if len(str(value)) > 80:
                    value_str += "..."
                parts.append(f"  • {key}: {value_str}")
            return "\n".join(parts)

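# Example (illustrative): with no ANTHROPIC_API_KEY in the environment,
# to_summary() uses the key-value fallback rather than calling Haiku.
#
#     result = NodeResult(success=True, output={"answer": "42"})
#     print(result.to_summary())
#     # ✓ Completed with 1 outputs:
#     #   • answer: 42
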
class NodeProtocol(ABC):
    """
    The interface all nodes must implement.

    To create a node:
    1. Subclass NodeProtocol
    2. Implement execute()
    3. Register with the executor

    Example:
        class CalculatorNode(NodeProtocol):
            async def execute(self, ctx: NodeContext) -> NodeResult:
                expression = ctx.input_data.get("expression")

                # Record decision
                decision_id = ctx.runtime.decide(
                    intent="Calculate expression",
                    options=[...],
                    chosen="evaluate",
                    reasoning="Direct evaluation"
                )

                # Do the work
                result = eval(expression)

                # Record outcome
                ctx.runtime.record_outcome(decision_id, success=True, result=result)

                return NodeResult(success=True, output={"result": result})
    """

    @abstractmethod
    async def execute(self, ctx: NodeContext) -> NodeResult:
        """
        Execute this node's logic.

        Args:
            ctx: NodeContext with everything needed

        Returns:
            NodeResult with output and status
        """
        pass

    def validate_input(self, ctx: NodeContext) -> list[str]:
        """
        Validate that required inputs are present.

        Override to add custom validation.

        Returns:
            List of validation error messages (empty if valid)
        """
        errors = []
        for key in ctx.node_spec.input_keys:
            if key not in ctx.input_data and ctx.memory.read(key) is None:
                errors.append(f"Missing required input: {key}")
        return errors

class LLMNode(NodeProtocol):
    """
    A node that uses an LLM with tools.

    This is the most common node type. It:
    1. Builds a prompt from context
    2. Calls the LLM with available tools
    3. Executes tool calls
    4. Returns the final result

    The LLM decides how to achieve the goal within constraints.
    """

    # Stop reasons indicating truncation (varies by provider)
    TRUNCATION_STOP_REASONS = {"length", "max_tokens", "token_limit"}

    # Compaction instruction added when a response is truncated
    COMPACTION_INSTRUCTION = """
IMPORTANT: Your previous response was truncated because it exceeded the token limit.
Please provide a MORE CONCISE response that fits within the limit.
Focus on the essential information and omit verbose details.
Keep the same JSON structure but with shorter content values.
"""

    def __init__(
        self,
        tool_executor: Callable | None = None,
        require_tools: bool = False,
        cleanup_llm_model: str | None = None,
        max_compaction_retries: int = 2,
    ):
        self.tool_executor = tool_executor
        self.require_tools = require_tools
        self.cleanup_llm_model = cleanup_llm_model
        self.max_compaction_retries = max_compaction_retries

    def _is_truncated(self, response) -> bool:
        """Check if the LLM response was truncated due to the token limit."""
        # Guard against providers that set stop_reason to None
        stop_reason = (getattr(response, "stop_reason", "") or "").lower()
        return stop_reason in self.TRUNCATION_STOP_REASONS

    def _strip_code_blocks(self, content: str) -> str:
        """Strip markdown code block wrappers from content.

        LLMs often wrap JSON output in ```json...``` blocks.
        This method removes those wrappers to get clean content.
        """
        import re

        content = content.strip()
        # Match ```json or ``` at start and ``` at end (greedy to handle nested)
        match = re.match(r"^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$", content, re.DOTALL)
        if match:
            return match.group(1).strip()
        return content

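    # Example (illustrative):
    #
    #     node._strip_code_blocks('```json\n{"a": 1}\n```')  # -> '{"a": 1}'
    #     node._strip_code_blocks("plain text")              # -> "plain text"
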
    async def execute(self, ctx: NodeContext) -> NodeResult:
        """Execute the LLM node."""
        import time

        if ctx.llm is None:
            return NodeResult(success=False, error="LLM not available")

        # Fail fast if tools are required but not available
        if self.require_tools and not ctx.available_tools:
            return NodeResult(
                success=False,
                error=f"Node '{ctx.node_spec.name}' requires tools but none are available. "
                f"Declared tools: {ctx.node_spec.tools}. "
                "Register tools via ToolRegistry before running the agent.",
            )

        ctx.runtime.set_node(ctx.node_id)

        # Record the decision to use the LLM
        decision_id = ctx.runtime.decide(
            intent=f"Execute {ctx.node_spec.name}",
            options=[
                {
                    "id": "llm_execute",
                    "description": f"Use LLM to {ctx.node_spec.description}",
                    "action_type": "llm_call",
                }
            ],
            chosen="llm_execute",
            reasoning=f"Node type is {ctx.node_spec.node_type}",
            context={"input": ctx.input_data},
        )

        start = time.time()

        try:
            # Build messages
            messages = self._build_messages(ctx)

            # Build system prompt
            system = self._build_system_prompt(ctx)

            # Log the LLM call details
            logger.info("  🤖 LLM Call:")
            logger.info(
                f"    System: {system[:150]}..."
                if len(system) > 150
                else f"    System: {system}"
            )
            logger.info(
                f"    User message: {messages[-1]['content'][:150]}..."
                if len(messages[-1]["content"]) > 150
                else f"    User message: {messages[-1]['content']}"
            )
            if ctx.available_tools:
                logger.info(f"    Tools available: {[t.name for t in ctx.available_tools]}")

            # Use JSON mode for llm_generate nodes with output_keys.
            # Skip strict schema validation - just validate keys after parsing.
            use_json_mode = bool(
                ctx.node_spec.node_type == "llm_generate" and ctx.node_spec.output_keys
            )

            # Call LLM
            if ctx.available_tools and self.tool_executor:
                from framework.llm.provider import ToolResult, ToolUse

                def executor(tool_use: ToolUse) -> ToolResult:
                    args = ", ".join(f"{k}={v}" for k, v in tool_use.input.items())
                    logger.info(f"    🔧 Tool call: {tool_use.name}({args})")
                    result = self.tool_executor(tool_use)
                    # Truncate long results
                    result_str = str(result.content)[:150]
                    if len(str(result.content)) > 150:
                        result_str += "..."
                    logger.info(f"    ✓ Tool result: {result_str}")
                    return result

                response = ctx.llm.complete_with_tools(
                    messages=messages,
                    system=system,
                    tools=ctx.available_tools,
                    tool_executor=executor,
                    max_tokens=ctx.max_tokens,
                )
            else:
                if use_json_mode:
                    logger.info(
                        f"    📋 Expecting JSON output with keys: {ctx.node_spec.output_keys}"
                    )

                response = ctx.llm.complete(
                    messages=messages,
                    system=system,
                    json_mode=use_json_mode,
                    max_tokens=ctx.max_tokens,
                )

            # Check for truncation and retry with compaction if needed
            expects_json = bool(
                ctx.node_spec.node_type in ("llm_generate", "llm_tool_use")
                and ctx.node_spec.output_keys
            )

            compaction_attempt = 0
            while (
                self._is_truncated(response)
                and expects_json
                and compaction_attempt < self.max_compaction_retries
            ):
                compaction_attempt += 1
                logger.warning(
                    f"    ⚠ Response truncated (stop_reason: {response.stop_reason}), "
                    f"retrying with compaction ({compaction_attempt}/{self.max_compaction_retries})"
                )

                # Add the compaction instruction to the messages
                compaction_messages = messages + [
                    {"role": "assistant", "content": response.content},
                    {"role": "user", "content": self.COMPACTION_INSTRUCTION},
                ]

                # Retry the call with the compaction instruction
                if ctx.available_tools and self.tool_executor:
                    response = ctx.llm.complete_with_tools(
                        messages=compaction_messages,
                        system=system,
                        tools=ctx.available_tools,
                        tool_executor=executor,
                        max_tokens=ctx.max_tokens,
                    )
                else:
                    response = ctx.llm.complete(
                        messages=compaction_messages,
                        system=system,
                        json_mode=use_json_mode,
                        max_tokens=ctx.max_tokens,
                    )

            if self._is_truncated(response) and expects_json:
                logger.warning(
                    f"    ⚠ Response still truncated after {compaction_attempt} compaction attempts"
                )

            # Phase 2: validation retry loop for Pydantic models
            max_validation_retries = (
                ctx.node_spec.max_validation_retries if ctx.node_spec.output_model else 0
            )
            validation_attempt = 0
            total_input_tokens = 0
            total_output_tokens = 0
            current_messages = messages.copy()

            while True:
                total_input_tokens += response.input_tokens
                total_output_tokens += response.output_tokens

                # Log the response
                response_preview = response.content[:200]
                if len(response.content) > 200:
                    response_preview += "..."
                logger.info(f"    ← Response: {response_preview}")

                # If no output_model, break immediately (no validation needed)
                if ctx.node_spec.output_model is None:
                    break

                # Try to parse and validate the response
                try:
                    parsed = self._extract_json(response.content, ctx.node_spec.output_keys)

                    if isinstance(parsed, dict):
                        from framework.graph.validator import OutputValidator

                        validator = OutputValidator()
                        validation_result, validated_model = validator.validate_with_pydantic(
                            parsed, ctx.node_spec.output_model
                        )

                        if validation_result.success:
                            # Validation passed, break out of the retry loop
                            model_name = ctx.node_spec.output_model.__name__
                            logger.info(f"    ✓ Pydantic validation passed for {model_name}")
                            break

                        # Validation failed
                        validation_attempt += 1

                        if validation_attempt <= max_validation_retries:
                            # Add validation feedback to the messages and retry
                            feedback = validator.format_validation_feedback(
                                validation_result, ctx.node_spec.output_model
                            )
                            logger.warning(
                                f"    ⚠ Pydantic validation failed "
                                f"(attempt {validation_attempt}/{max_validation_retries}): "
                                f"{validation_result.error}"
                            )
                            logger.info("    🔄 Retrying with validation feedback...")

                            # Add the assistant's failed response and the feedback
                            current_messages.append(
                                {"role": "assistant", "content": response.content}
                            )
                            current_messages.append({"role": "user", "content": feedback})

                            # Re-call the LLM with feedback
                            if ctx.available_tools and self.tool_executor:
                                response = ctx.llm.complete_with_tools(
                                    messages=current_messages,
                                    system=system,
                                    tools=ctx.available_tools,
                                    tool_executor=executor,
                                    max_tokens=ctx.max_tokens,
                                )
                            else:
                                response = ctx.llm.complete(
                                    messages=current_messages,
                                    system=system,
                                    json_mode=use_json_mode,
                                    max_tokens=ctx.max_tokens,
                                )
                            continue  # Retry validation

                        # Max retries exceeded
                        latency_ms = int((time.time() - start) * 1000)
                        err = validation_result.error
                        logger.error(
                            f"    ✗ Pydantic validation failed after "
                            f"{max_validation_retries} retries: {err}"
                        )
                        ctx.runtime.record_outcome(
                            decision_id=decision_id,
                            success=False,
                            error=f"Validation failed: {err}",
                            tokens_used=total_input_tokens + total_output_tokens,
                            latency_ms=latency_ms,
                        )
                        error_msg = (
                            f"Pydantic validation failed after "
                            f"{max_validation_retries} retries: {err}"
                        )
                        return NodeResult(
                            success=False,
                            error=error_msg,
                            output=parsed,
                            tokens_used=total_input_tokens + total_output_tokens,
                            latency_ms=latency_ms,
                            validation_errors=validation_result.errors,
                        )
                    else:
                        # Not a dict, can't validate - break and let downstream handle it
                        break
                except Exception:
                    # JSON extraction failed - break and let downstream handle it
                    break

            latency_ms = int((time.time() - start) * 1000)

            ctx.runtime.record_outcome(
                decision_id=decision_id,
                success=True,
                result=response.content,
                tokens_used=response.input_tokens + response.output_tokens,
                latency_ms=latency_ms,
            )

            # Write to output keys
            output = self._parse_output(response.content, ctx.node_spec)

            # For llm_generate and llm_tool_use nodes, try to parse JSON and extract fields
            if (
                ctx.node_spec.node_type in ("llm_generate", "llm_tool_use")
                and ctx.node_spec.output_keys
            ):
                try:
                    # Try to extract JSON from the response
                    parsed = self._extract_json(
                        response.content, ctx.node_spec.output_keys, self.cleanup_llm_model
                    )

                    # If parsed successfully, write each field to its corresponding output key.
                    # Use validate=False since LLM output legitimately contains text that
                    # may trigger false positives (e.g., "from OpenAI" matches "from ").
                    if isinstance(parsed, dict):
                        # If we have an output_model, validation already happened in the retry loop
                        if ctx.node_spec.output_model is not None:
                            from framework.graph.validator import OutputValidator

                            validator = OutputValidator()
                            validation_result, validated_model = validator.validate_with_pydantic(
                                parsed, ctx.node_spec.output_model
                            )
                            # Use the validated model's dict representation
                            if validated_model:
                                parsed = validated_model.model_dump()

                        for key in ctx.node_spec.output_keys:
                            if key in parsed:
                                value = parsed[key]
                                # Strip code block wrappers from string values
                                if isinstance(value, str):
                                    value = self._strip_code_blocks(value)
                                ctx.memory.write(key, value, validate=False)
                                output[key] = value
                            elif key in ctx.input_data:
                                # Key not in JSON but exists in input - pass through
                                ctx.memory.write(key, ctx.input_data[key], validate=False)
                                output[key] = ctx.input_data[key]
                            else:
                                # Key not in JSON or input, write the whole response (stripped)
                                stripped_content = self._strip_code_blocks(response.content)
                                ctx.memory.write(key, stripped_content, validate=False)
                                output[key] = stripped_content
                    else:
                        # Not a dict, fall back to writing the entire response to all keys (stripped)
                        stripped_content = self._strip_code_blocks(response.content)
                        for key in ctx.node_spec.output_keys:
                            ctx.memory.write(key, stripped_content, validate=False)
                            output[key] = stripped_content

                except Exception as e:
                    # JSON extraction failed - fail explicitly instead of polluting memory
                    logger.error(f"    ✗ Failed to extract structured output: {e}")
                    logger.error(
                        f"    Raw response (first 500 chars): {response.content[:500]}..."
                    )

                    # Return failure instead of writing garbage to all keys
                    return NodeResult(
                        success=False,
                        error=(
                            f"Output extraction failed: {e}. LLM returned non-JSON response. "
                            f"Expected keys: {ctx.node_spec.output_keys}"
                        ),
                        output={},
                        tokens_used=response.input_tokens + response.output_tokens,
                        latency_ms=latency_ms,
                    )
            else:
                # For other node types (or nodes with no declared output keys),
                # write the entire response (stripped) to each output key
                stripped_content = self._strip_code_blocks(response.content)
                for key in ctx.node_spec.output_keys:
                    ctx.memory.write(key, stripped_content, validate=False)
                    output[key] = stripped_content

            return NodeResult(
                success=True,
                output=output,
                tokens_used=response.input_tokens + response.output_tokens,
                latency_ms=latency_ms,
            )

        except Exception as e:
            latency_ms = int((time.time() - start) * 1000)
            ctx.runtime.record_outcome(
                decision_id=decision_id,
                success=False,
                error=str(e),
                latency_ms=latency_ms,
            )
            return NodeResult(success=False, error=str(e), latency_ms=latency_ms)

    def _parse_output(self, content: str, node_spec: NodeSpec) -> dict[str, Any]:
        """
        Build the default output mapping for an LLM response.

        Returns the raw content under the "result" key. Structured JSON
        extraction for llm_generate/llm_tool_use nodes happens separately
        in _extract_json().
        """
        return {"result": content}

    def _extract_json(
        self, raw_response: str, output_keys: list[str], cleanup_llm_model: str | None = None
    ) -> dict[str, Any]:
        """Extract clean JSON from a potentially verbose LLM response.

        Tries multiple extraction strategies in order:
        1. Direct JSON parse
        2. Markdown code block extraction
        3. Balanced brace matching
        4. Configured LLM fallback (last resort)

        Args:
            raw_response: The raw LLM response text
            output_keys: Expected output keys for the JSON
            cleanup_llm_model: Optional model to use for the LLM cleanup fallback
        """
        import json
        import re

        content = raw_response.strip()

        # Try direct JSON parse first (fast path)
        try:
            # Remove markdown code blocks if present - more robust extraction
            if content.startswith("```"):
                # Pattern 1: ```json\n...\n``` or ```\n...\n```
                match = re.search(r"^```(?:json)?\s*\n([\s\S]*?)\n```\s*$", content)
                if match:
                    content = match.group(1).strip()
                else:
                    # Pattern 2: just strip the first and last lines if they're ```
                    lines = content.split("\n")
                    if lines[0].startswith("```") and lines[-1].strip() == "```":
                        content = "\n".join(lines[1:-1]).strip()

            parsed = json.loads(content)
            if isinstance(parsed, dict):
                return parsed
        except json.JSONDecodeError as e:
            logger.info(f"    Direct JSON parse failed: {e}")
            logger.info(f"    Content first 200 chars repr: {repr(content[:200])}")
            # Try fixing unescaped newlines in string values
            try:
                fixed = _fix_unescaped_newlines_in_json(content)
                logger.info(f"    Fixed content first 200 chars repr: {repr(fixed[:200])}")
                parsed = json.loads(fixed)
                if isinstance(parsed, dict):
                    logger.info("    ✓ Parsed JSON after fixing unescaped newlines")
                    return parsed
            except json.JSONDecodeError as e2:
                logger.info(f"    Newline fix also failed: {e2}")

        # Try to extract JSON from markdown code blocks
        # (multiple patterns to handle different LLM formatting styles)
        code_block_patterns = [
            # Anchored match from the first ``` to the last ``` (greedy, handles nested blocks)
            r"^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$",
            # Non-anchored: find ```json anywhere and extract to the closing ```
            r"```(?:json|JSON)?\s*\n([\s\S]*?)\n```",
        ]
        for pattern in code_block_patterns:
            code_block_match = re.search(pattern, content, re.DOTALL)
            if code_block_match:
                try:
                    extracted = code_block_match.group(1).strip()
                    if extracted:  # Skip empty matches
                        # Try direct parse first, then with the newline fix
                        try:
                            parsed = json.loads(extracted)
                        except json.JSONDecodeError:
                            parsed = json.loads(_fix_unescaped_newlines_in_json(extracted))
                        if isinstance(parsed, dict):
                            return parsed
                except json.JSONDecodeError:
                    pass

        # Try to find a JSON object by matching balanced braces (module-level helper)
        json_str = find_json_object(content)
        if json_str:
            try:
                # Try direct parse first, then with the newline fix
                try:
                    parsed = json.loads(json_str)
                except json.JSONDecodeError:
                    parsed = json.loads(_fix_unescaped_newlines_in_json(json_str))
                if isinstance(parsed, dict):
                    return parsed
            except json.JSONDecodeError:
                pass

        # Try stripping the markdown prefix and finding JSON from there.
        # This handles cases like "```json\n{...}" where the regex might fail.
        if "```" in content:
            # Find the first { after the ``` marker
            json_start = content.find("{")
            if json_start > 0:
                # Extract from the first { to the end, then find balanced JSON
                json_str = find_json_object(content[json_start:])
                if json_str:
                    try:
                        # Try direct parse first, then with the newline fix
                        try:
                            parsed = json.loads(json_str)
                        except json.JSONDecodeError:
                            parsed = json.loads(_fix_unescaped_newlines_in_json(json_str))
                        if isinstance(parsed, dict):
                            logger.info(
                                "    ✓ Extracted JSON via brace matching after markdown strip"
                            )
                            return parsed
                    except json.JSONDecodeError:
                        pass

        # All local extraction failed - use an LLM as a last resort
        import os

        from framework.llm.litellm import LiteLLMProvider

        logger.info(f"    cleanup_llm_model param: {cleanup_llm_model}")

        # Use the configured cleanup model, or fall back to defaults
        if cleanup_llm_model:
            # Use the configured cleanup model (LiteLLM handles API keys via env vars)
            cleaner_llm = LiteLLMProvider(
                model=cleanup_llm_model,
                temperature=0.0,
            )
            logger.info(f"    Using configured cleanup LLM: {cleanup_llm_model}")
        else:
            # Fall back to default logic: Cerebras preferred, then Haiku
            api_key = os.environ.get("CEREBRAS_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
            if not api_key:
                raise ValueError(
                    "Cannot parse JSON and no API key for LLM cleanup "
                    "(set CEREBRAS_API_KEY or ANTHROPIC_API_KEY, or configure cleanup_llm_model)"
                )

            if os.environ.get("CEREBRAS_API_KEY"):
                cleaner_llm = LiteLLMProvider(
                    api_key=os.environ.get("CEREBRAS_API_KEY"),
                    model="cerebras/llama-3.3-70b",
                    temperature=0.0,
                )
            else:
                cleaner_llm = LiteLLMProvider(
                    api_key=api_key,
                    model="claude-3-5-haiku-20241022",
                    temperature=0.0,
                )

        prompt = f"""Extract the JSON object from this LLM response.

Expected output keys: {output_keys}

LLM Response:
{raw_response}

Output ONLY the JSON object, nothing else."""

        try:
            result = cleaner_llm.complete(
                messages=[{"role": "user", "content": prompt}],
                system="Extract JSON from text. Output only valid JSON.",
                json_mode=True,
            )

            cleaned = result.content.strip() if result.content else ""

            # Check for an empty response
            if not cleaned:
                logger.warning("    ⚠ LLM cleanup returned empty response")
                raise ValueError(
                    f"LLM cleanup returned empty response. "
                    f"Raw response starts with: {raw_response[:200]}..."
                )

            # Remove markdown if the LLM added it
            if cleaned.startswith("```"):
                match = re.search(r"^```(?:json)?\s*\n([\s\S]*?)\n```\s*$", cleaned)
                if match:
                    cleaned = match.group(1).strip()
                else:
                    # Fallback: strip first/last lines
                    lines = cleaned.split("\n")
                    if lines[0].startswith("```") and lines[-1].strip() == "```":
                        cleaned = "\n".join(lines[1:-1]).strip()

            # Try balanced brace extraction if still not valid JSON
            if not cleaned.startswith("{"):
                json_str = find_json_object(cleaned)
                if json_str:
                    cleaned = json_str

            if not cleaned:
                raise ValueError(
                    f"Could not extract JSON from LLM cleanup response. "
                    f"Raw response starts with: {raw_response[:200]}..."
                )

            # Try direct parse first, then with the newline fix
            try:
                parsed = json.loads(cleaned)
            except json.JSONDecodeError:
                parsed = json.loads(_fix_unescaped_newlines_in_json(cleaned))
            logger.info("    ✓ LLM cleaned JSON output")
            return parsed

        except json.JSONDecodeError as e:
            logger.warning(f"    ⚠ LLM cleanup response not valid JSON: {e}")
            raise ValueError(
                f"LLM cleanup response not valid JSON: {e}. "
                f"Expected keys: {output_keys}"
            ) from e
        except ValueError:
            raise  # Re-raise our descriptive error
        except Exception as e:
            logger.warning(f"    ⚠ LLM JSON extraction failed: {e}")
            raise

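    # Illustrative walk through the fallback ladder above: a chatty response
    # with a fenced JSON block never reaches the LLM cleanup step.
    #
    #     node._extract_json(
    #         'Sure! ```json\n{"answer": "42"}\n``` Hope that helps.', ["answer"]
    #     )
    #     # -> {"answer": "42"} via the markdown code-block strategy
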
    def _build_messages(self, ctx: NodeContext) -> list[dict]:
        """Build the message list for the LLM."""
        # Use Haiku to intelligently format inputs from memory
        user_content = self._format_inputs_with_haiku(ctx)
        return [{"role": "user", "content": user_content}]

    def _format_inputs_with_haiku(self, ctx: NodeContext) -> str:
        """Use Haiku to intelligently extract and format inputs from memory."""
        if not ctx.node_spec.input_keys:
            return str(ctx.input_data)

        # Read all memory for context
        memory_data = ctx.memory.read_all()

        # If memory is empty or very simple, just use the raw data
        if not memory_data or len(memory_data) <= 2:
            # Simple case - format the input keys directly
            parts = []
            for key in ctx.node_spec.input_keys:
                value = ctx.memory.read(key)
                if value is not None:
                    parts.append(f"{key}: {value}")
            return "\n".join(parts) if parts else str(ctx.input_data)

        # Use Haiku to intelligently extract relevant data
        import os

        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            # Fallback to simple formatting if no API key
            parts = []
            for key in ctx.node_spec.input_keys:
                value = ctx.memory.read(key)
                if value is not None:
                    parts.append(f"{key}: {value}")
            return "\n".join(parts)

        # Build the prompt for Haiku to extract clean values
        import json

        # Smart truncation: truncate values rather than corrupting the JSON
        def truncate_value(v, max_len=500):
            s = str(v)
            return s[:max_len] + "..." if len(s) > max_len else v

        truncated_data = {k: truncate_value(v) for k, v in memory_data.items()}
        memory_json = json.dumps(truncated_data, indent=2, default=str)

        required_fields = ", ".join(ctx.node_spec.input_keys)
        prompt = (
            f"Extract the following information from the memory context:\n\n"
            f"Required fields: {required_fields}\n\n"
            f"Memory context (may contain nested data, JSON strings, "
            f"or extra information):\n{memory_json}\n\n"
            "Extract ONLY the clean values for the required fields. "
            "Ignore nested structures, JSON wrappers, and irrelevant data.\n\n"
            "Output as JSON with the exact field names requested."
        )

        try:
            import anthropic

            client = anthropic.Anthropic(api_key=api_key)
            message = client.messages.create(
                model="claude-3-5-haiku-20241022",
                max_tokens=1000,
                messages=[{"role": "user", "content": prompt}],
            )

            # Parse Haiku's response
            response_text = message.content[0].text.strip()

            # Try to extract JSON using balanced brace matching
            json_str = find_json_object(response_text)
            if json_str:
                extracted = json.loads(json_str)
                # Format as key: value pairs
                parts = [
                    f"{k}: {v}" for k, v in extracted.items() if k in ctx.node_spec.input_keys
                ]
                if parts:
                    return "\n".join(parts)

        except Exception as e:
            # Fall back to simple formatting on error
            logger.warning(f"Haiku formatting failed: {e}, falling back to simple format")

        # Fallback: simple key-value formatting
        parts = []
        for key in ctx.node_spec.input_keys:
            value = ctx.memory.read(key)
            if value is not None:
                parts.append(f"{key}: {value}")
        return "\n".join(parts) if parts else str(ctx.input_data)

    def _build_system_prompt(self, ctx: NodeContext) -> str:
        """Build the system prompt."""
        parts = []

        if ctx.node_spec.system_prompt:
            # Format the system prompt with values from memory (for input_keys placeholders)
            prompt = ctx.node_spec.system_prompt
            if ctx.node_spec.input_keys:
                # Build the formatting context from memory
                format_context = {}
                for key in ctx.node_spec.input_keys:
                    value = ctx.memory.read(key)
                    if value is not None:
                        format_context[key] = value

                # Try to format, but fall back to the raw prompt if formatting fails
                try:
                    prompt = prompt.format(**format_context)
                except (KeyError, ValueError):
                    # Placeholders don't match or formatting error - use the raw prompt
                    pass

            parts.append(prompt)

        if ctx.goal_context:
            parts.append("\n# Goal Context")
            parts.append(ctx.goal_context)

        return "\n".join(parts)

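# Example (illustrative): wiring an LLMNode with a tool executor. ToolUse.name
# and ToolUse.input are used by the executor above; ToolResult's constructor
# fields are assumed here, and the calculator logic is purely hypothetical.
#
#     from framework.llm.provider import ToolResult, ToolUse
#
#     def run_tool(tool_use: ToolUse) -> ToolResult:
#         if tool_use.name == "calculate":
#             value = eval(tool_use.input["expression"])  # demo only; never eval untrusted input
#             return ToolResult(tool_use_id=tool_use.id, content=str(value))
#         return ToolResult(tool_use_id=tool_use.id, content="unknown tool")
#
#     node = LLMNode(tool_executor=run_tool, require_tools=True)
#     # result = await node.execute(ctx)  # ctx is a fully populated NodeContext
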
class RouterNode(NodeProtocol):
    """
    A node that routes to different next nodes based on conditions.

    The router examines the current state and decides which
    node should execute next.

    Can use either:
    1. Simple condition matching (deterministic)
    2. LLM-based routing (goal-aware, adaptive)

    Set node_spec.routes to a dict of conditions -> target nodes.
    If node_spec.system_prompt is provided, the LLM will choose the route.
    """

    async def execute(self, ctx: NodeContext) -> NodeResult:
        """Execute routing logic."""
        ctx.runtime.set_node(ctx.node_id)

        # Build options from routes
        options = []
        for condition, target in ctx.node_spec.routes.items():
            options.append(
                {
                    "id": condition,
                    "description": f"Route to {target} when condition '{condition}' is met",
                    "target": target,
                }
            )

        # Check if we should use LLM-based routing
        if ctx.node_spec.system_prompt and ctx.llm:
            # LLM-based routing (goal-aware)
            chosen_route = await self._llm_route(ctx, options)
        else:
            # Simple condition-based routing (deterministic)
            route_value = ctx.input_data.get("route_on") or ctx.memory.read("route_on")
            chosen_route = None
            for condition, target in ctx.node_spec.routes.items():
                if self._check_condition(condition, route_value, ctx):
                    chosen_route = (condition, target)
                    break

            if chosen_route is None:
                # Default route
                chosen_route = ("default", ctx.node_spec.routes.get("default", "end"))

        decision_id = ctx.runtime.decide(
            intent="Determine next node in graph",
            options=options,
            chosen=chosen_route[0],
            reasoning=f"Routing decision: {chosen_route[0]}",
        )

        ctx.runtime.record_outcome(
            decision_id=decision_id,
            success=True,
            result=chosen_route[1],
            summary=f"Routing to {chosen_route[1]}",
        )

        return NodeResult(
            success=True,
            next_node=chosen_route[1],
            route_reason=f"Chose route: {chosen_route[0]}",
        )

    async def _llm_route(
        self,
        ctx: NodeContext,
        options: list[dict[str, Any]],
    ) -> tuple[str, str]:
        """
        Use the LLM to choose the best route based on the goal and context.

        Returns:
            Tuple of (chosen_condition, target_node)
        """
        import json

        # Build the routing options description
        options_desc = "\n".join(
            [f"- {opt['id']}: {opt['description']} → goes to '{opt['target']}'" for opt in options]
        )

        # Build the context
        context_data = {
            "input": ctx.input_data,
            "memory_keys": list(ctx.memory.read_all().keys())[:10],
        }

        prompt = f"""You are a routing agent deciding which path to take in a workflow.

**Goal**: {ctx.goal.name}
{ctx.goal.description}

**Current Context**:
{json.dumps(context_data, indent=2, default=str)}

**Available Routes**:
{options_desc}

Based on the goal and current context, which route should we take?

Respond with ONLY a JSON object:
{{"chosen": "route_id", "reasoning": "brief explanation"}}"""

        logger.info("  🤔 Router using LLM to choose path...")

        try:
            response = ctx.llm.complete(
                messages=[{"role": "user", "content": prompt}],
                system=ctx.node_spec.system_prompt
                or "You are a routing agent. Respond with JSON only.",
                max_tokens=150,
            )

            # Parse the response using balanced brace matching
            json_str = find_json_object(response.content)
            if json_str:
                data = json.loads(json_str)
                chosen = data.get("chosen", "default")
                reasoning = data.get("reasoning", "")

                logger.info(f"    → Chose: {chosen}")
                logger.info(f"    Reason: {reasoning}")

                # Find the target for this choice
                target = ctx.node_spec.routes.get(
                    chosen, ctx.node_spec.routes.get("default", "end")
                )
                return (chosen, target)

        except Exception as e:
            logger.warning(f"    ⚠ LLM routing failed, using default: {e}")

        # Fall back to the default route (also reached when the LLM response
        # contained no parseable JSON object)
        default_target = ctx.node_spec.routes.get("default", "end")
        return ("default", default_target)

    def _check_condition(
        self,
        condition: str,
        value: Any,
        ctx: NodeContext,
    ) -> bool:
        """Check if a routing condition is met."""
        if condition == "default":
            return True
        if condition == "success" and value is True:
            return True
        if condition == "failure" and value is False:
            return True
        if condition == "error" and isinstance(value, Exception):
            return True

        # String matching
        if isinstance(value, str) and condition in value:
            return True

        return False

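# Example (illustrative): deterministic routing reads "route_on" from the
# node's input (or memory) and substring-matches it against the conditions.
# The node ids "publish" and "revise" are hypothetical.
#
#     spec = NodeSpec(
#         id="gate",
#         name="Review Gate",
#         description="Route on review outcome",
#         node_type="router",
#         routes={"approved": "publish", "rejected": "revise", "default": "end"},
#     )
#     # With input_data={"route_on": "approved by reviewer"}, the substring
#     # match on "approved" routes to "publish".
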
class FunctionNode(NodeProtocol):
    """
    A node that executes a Python function.

    For deterministic operations that don't need LLM reasoning.
    """

    def __init__(self, func: Callable):
        self.func = func

    async def execute(self, ctx: NodeContext) -> NodeResult:
        """Execute the function."""
        import time

        ctx.runtime.set_node(ctx.node_id)

        decision_id = ctx.runtime.decide(
            intent=f"Execute function {ctx.node_spec.function or 'unknown'}",
            options=[
                {
                    "id": "execute",
                    "description": f"Run function with inputs: {list(ctx.input_data.keys())}",
                }
            ],
            chosen="execute",
            reasoning="Deterministic function execution",
        )

        start = time.time()

        try:
            # Call the function
            result = self.func(**ctx.input_data)

            latency_ms = int((time.time() - start) * 1000)

            ctx.runtime.record_outcome(
                decision_id=decision_id,
                success=True,
                result=result,
                latency_ms=latency_ms,
            )

            # Write to output keys
            output = {}
            if ctx.node_spec.output_keys:
                key = ctx.node_spec.output_keys[0]
                output[key] = result
                ctx.memory.write(key, result)
            else:
                output = {"result": result}

            return NodeResult(success=True, output=output, latency_ms=latency_ms)

        except Exception as e:
            latency_ms = int((time.time() - start) * 1000)
            ctx.runtime.record_outcome(
                decision_id=decision_id,
                success=False,
                error=str(e),
                latency_ms=latency_ms,
            )
            return NodeResult(success=False, error=str(e), latency_ms=latency_ms)

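# Example (illustrative): a deterministic node wrapping a plain function.
# Input keys must match the function's parameter names, since execute()
# calls self.func(**ctx.input_data).
#
#     def add(a: int, b: int) -> int:
#         return a + b
#
#     node = FunctionNode(add)
#     # With input_data={"a": 2, "b": 3} and output_keys=["sum"], the result is
#     # NodeResult(success=True, output={"sum": 5}); with no output_keys it is
#     # output={"result": 5}.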