Files
hive/core/framework/graph/node.py
T
2026-02-18 20:29:39 -08:00

650 lines
20 KiB
Python

"""
Node Protocol - The building block of agent graphs.
A Node is a unit of work that:
1. Receives context (goal, shared memory, input)
2. Makes decisions (using LLM, tools, or logic)
3. Produces results (output, state changes)
4. Records everything to the Runtime
Nodes are composable and reusable. The same node can appear
in different graphs for different goals.
Protocol:
Every node must implement the NodeProtocol interface.
The framework provides NodeContext with everything the node needs.
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from pydantic import BaseModel, Field
from framework.llm.provider import LLMProvider, Tool
from framework.runtime.core import Runtime
logger = logging.getLogger(__name__)
def _fix_unescaped_newlines_in_json(json_str: str) -> str:
    """Escape raw newline, carriage-return and tab characters inside JSON strings.

    LLMs sometimes emit literal control characters inside a JSON string value
    instead of the two-character escapes (``\\n``, ``\\r``, ``\\t``), which makes
    the payload unparseable. This walks the text once, tracking whether the
    cursor is inside a double-quoted string, and rewrites only the offending
    characters found there. Whitespace between tokens is left untouched.

    Args:
        json_str: Possibly malformed JSON text.

    Returns:
        The text with raw ``\\n``, ``\\r`` and ``\\t`` inside string values
        replaced by their escaped forms; every other character is unchanged.
    """
    replacements = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}
    pieces: list[str] = []
    inside_string = False
    pending_escape = False

    for ch in json_str:
        if pending_escape:
            # Character following a backslash: copy verbatim, it cannot
            # open or close a string.
            pieces.append(ch)
            pending_escape = False
        elif inside_string and ch == "\\":
            pieces.append(ch)
            pending_escape = True
        elif ch == '"':
            inside_string = not inside_string
            pieces.append(ch)
        elif inside_string:
            pieces.append(replacements.get(ch, ch))
        else:
            pieces.append(ch)

    return "".join(pieces)
def find_json_object(text: str) -> str | None:
"""Find the first valid JSON object in text using balanced brace matching.
This handles nested objects correctly, unlike simple regex like r'\\{[^{}]*\\}'.
"""
start = text.find("{")
if start == -1:
return None
depth = 0
in_string = False
escape_next = False
for i, char in enumerate(text[start:], start):
if escape_next:
escape_next = False
continue
if char == "\\" and in_string:
escape_next = True
continue
if char == '"' and not escape_next:
in_string = not in_string
continue
if in_string:
continue
if char == "{":
depth += 1
elif char == "}":
depth -= 1
if depth == 0:
return text[start : i + 1]
return None
class NodeSpec(BaseModel):
    """
    Specification for a node in the graph.

    This is the declarative definition of a node - what it does,
    what it needs, and what it produces. The actual implementation
    is separate (NodeProtocol).

    Only ``id``, ``name`` and ``description`` are required; every other
    field has a default.

    Example:
        NodeSpec(
            id="calculator",
            name="Calculator Node",
            description="Performs mathematical calculations",
            node_type="event_loop",
            input_keys=["expression"],
            output_keys=["result"],
            tools=["calculate", "math_function"],
            system_prompt="You are a calculator..."
        )
    """

    # Required identity fields
    id: str
    name: str
    description: str
    # Node behavior type (free-form string; the description lists the known values)
    node_type: str = Field(
        default="event_loop",
        description="Type: 'event_loop' (recommended), 'router', 'human_input'.",
    )
    # Data flow
    input_keys: list[str] = Field(
        default_factory=list, description="Keys this node reads from shared memory or input"
    )
    output_keys: list[str] = Field(
        default_factory=list, description="Keys this node writes to shared memory or output"
    )
    nullable_output_keys: list[str] = Field(
        default_factory=list,
        description="Output keys that can be None without triggering validation errors",
    )
    # Optional schemas for validation and cleansing
    input_schema: dict[str, dict] = Field(
        default_factory=dict,
        description=(
            "Optional schema for input validation. "
            "Format: {key: {type: 'string', required: True, description: '...'}}"
        ),
    )
    output_schema: dict[str, dict] = Field(
        default_factory=dict,
        description=(
            "Optional schema for output validation. "
            "Format: {key: {type: 'dict', required: True, description: '...'}}"
        ),
    )
    # For LLM nodes
    system_prompt: str | None = Field(default=None, description="System prompt for LLM nodes")
    tools: list[str] = Field(default_factory=list, description="Tool names this node can use")
    model: str | None = Field(
        default=None, description="Specific model to use (defaults to graph default)"
    )
    # For router nodes
    routes: dict[str, str] = Field(
        default_factory=dict, description="Condition -> target_node_id mapping for routers"
    )
    # Retry behavior
    max_retries: int = Field(default=3)
    retry_on: list[str] = Field(default_factory=list, description="Error types to retry on")
    # Visit limits (for feedback/callback edges)
    max_node_visits: int = Field(
        default=1,
        description=(
            "Max times this node executes in one graph run. "
            "Set >1 for feedback loops. 0 = unlimited (max_steps guards)."
        ),
    )
    # Pydantic model for output validation
    output_model: type[BaseModel] | None = Field(
        default=None,
        description=(
            "Optional Pydantic model class for validating and parsing LLM output. "
            "When set, the LLM response will be validated against this model."
        ),
    )
    max_validation_retries: int = Field(
        default=2,
        description="Maximum retries when Pydantic validation fails (with feedback to LLM)",
    )
    # Client-facing behavior
    client_facing: bool = Field(
        default=False,
        description="If True, this node streams output to the end user and can request input.",
    )
    # Phase completion criteria for conversation-aware judge (Level 2)
    success_criteria: str | None = Field(
        default=None,
        description=(
            "Natural-language criteria for phase completion. When set, the "
            "implicit judge upgrades to Level 2: after output keys are satisfied, "
            "a fast LLM evaluates whether the conversation meets these criteria."
        ),
    )
    # Pydantic model settings: "extra": "allow" keeps undeclared fields passed
    # to the constructor instead of rejecting them; "arbitrary_types_allowed"
    # lets fields use types pydantic has no built-in validator for.
    model_config = {"extra": "allow", "arbitrary_types_allowed": True}
class MemoryWriteError(Exception):
    """Signals that a value was rejected as invalid when written to memory."""
@dataclass
class SharedMemory:
"""
Shared state between nodes in a graph execution.
Nodes read and write to shared memory using typed keys.
The memory is scoped to a single run.
For parallel execution, use write_async() which provides per-key locking
to prevent race conditions when multiple nodes write concurrently.
"""
_data: dict[str, Any] = field(default_factory=dict)
_allowed_read: set[str] = field(default_factory=set)
_allowed_write: set[str] = field(default_factory=set)
# Locks for thread-safe parallel execution
_lock: asyncio.Lock | None = field(default=None, repr=False)
_key_locks: dict[str, asyncio.Lock] = field(default_factory=dict, repr=False)
def __post_init__(self) -> None:
"""Initialize the main lock if not provided."""
if self._lock is None:
self._lock = asyncio.Lock()
def read(self, key: str) -> Any:
"""Read a value from shared memory."""
if self._allowed_read and key not in self._allowed_read:
raise PermissionError(f"Node not allowed to read key: {key}")
return self._data.get(key)
def write(self, key: str, value: Any, validate: bool = True) -> None:
"""
Write a value to shared memory.
Args:
key: The memory key to write to
value: The value to write
validate: If True, check for suspicious content (default True)
Raises:
PermissionError: If node doesn't have write permission
MemoryWriteError: If value appears to be hallucinated content
"""
if self._allowed_write and key not in self._allowed_write:
raise PermissionError(f"Node not allowed to write key: {key}")
if validate and isinstance(value, str):
# Check for obviously hallucinated content
if len(value) > 5000:
# Long strings that look like code are suspicious
if self._contains_code_indicators(value):
logger.warning(
f"⚠ Suspicious write to key '{key}': appears to be code "
f"({len(value)} chars). Consider using validate=False if intended."
)
raise MemoryWriteError(
f"Rejected suspicious content for key '{key}': "
f"appears to be hallucinated code ({len(value)} chars). "
"If this is intentional, use validate=False."
)
self._data[key] = value
async def write_async(self, key: str, value: Any, validate: bool = True) -> None:
"""
Thread-safe async write with per-key locking.
Use this method when multiple nodes may write concurrently during
parallel execution. Each key has its own lock to minimize contention.
Args:
key: The memory key to write to
value: The value to write
validate: If True, check for suspicious content (default True)
Raises:
PermissionError: If node doesn't have write permission
MemoryWriteError: If value appears to be hallucinated content
"""
# Check permissions first (no lock needed)
if self._allowed_write and key not in self._allowed_write:
raise PermissionError(f"Node not allowed to write key: {key}")
# Ensure key has a lock (double-checked locking pattern)
if key not in self._key_locks:
async with self._lock:
if key not in self._key_locks:
self._key_locks[key] = asyncio.Lock()
# Acquire per-key lock and write
async with self._key_locks[key]:
if validate and isinstance(value, str):
if len(value) > 5000:
if self._contains_code_indicators(value):
logger.warning(
f"⚠ Suspicious write to key '{key}': appears to be code "
f"({len(value)} chars). Consider using validate=False if intended."
)
raise MemoryWriteError(
f"Rejected suspicious content for key '{key}': "
f"appears to be hallucinated code ({len(value)} chars). "
"If this is intentional, use validate=False."
)
self._data[key] = value
def _contains_code_indicators(self, value: str) -> bool:
"""
Check for code patterns in a string using sampling for efficiency.
For strings under 10KB, checks the entire content.
For longer strings, samples at strategic positions to balance
performance with detection accuracy.
Args:
value: The string to check for code indicators
Returns:
True if code indicators are found, False otherwise
"""
code_indicators = [
# Python
"```python",
"def ",
"class ",
"import ",
"async def ",
"from ",
# JavaScript/TypeScript
"function ",
"const ",
"let ",
"=> {",
"require(",
"export ",
# SQL
"SELECT ",
"INSERT ",
"UPDATE ",
"DELETE ",
"DROP ",
# HTML/Script injection
"<script",
"<?php",
"<%",
]
# For strings under 10KB, check the entire content
if len(value) < 10000:
return any(indicator in value for indicator in code_indicators)
# For longer strings, sample at strategic positions
sample_positions = [
0, # Start
len(value) // 4, # 25%
len(value) // 2, # 50%
3 * len(value) // 4, # 75%
max(0, len(value) - 2000), # Near end
]
for pos in sample_positions:
chunk = value[pos : pos + 2000]
if any(indicator in chunk for indicator in code_indicators):
return True
return False
def read_all(self) -> dict[str, Any]:
"""Read all accessible data."""
if self._allowed_read:
return {k: v for k, v in self._data.items() if k in self._allowed_read}
return dict(self._data)
def with_permissions(
self,
read_keys: list[str],
write_keys: list[str],
) -> "SharedMemory":
"""Create a view with restricted permissions for a specific node.
The scoped view shares the same underlying data and locks,
enabling thread-safe parallel execution across scoped views.
"""
return SharedMemory(
_data=self._data,
_allowed_read=set(read_keys) if read_keys else set(),
_allowed_write=set(write_keys) if write_keys else set(),
_lock=self._lock, # Share lock for thread safety
_key_locks=self._key_locks, # Share key locks
)
@dataclass
class NodeContext:
    """
    Everything a node needs to execute.

    This is passed to every node and provides:
    - Access to the runtime (for decision logging)
    - Access to shared memory (for state)
    - Access to LLM (for generation)
    - Access to tools (for actions)
    - The goal context (for guidance)

    ``runtime``, ``node_id``, ``node_spec`` and ``memory`` are required;
    every other field has a default.
    """

    # Core runtime
    runtime: Runtime
    # Node identity
    node_id: str
    node_spec: NodeSpec
    # State
    memory: SharedMemory
    input_data: dict[str, Any] = field(default_factory=dict)
    # LLM access (if applicable)
    llm: LLMProvider | None = None
    available_tools: list[Tool] = field(default_factory=list)
    # Goal context
    goal_context: str = ""
    goal: Any = None  # Goal object for LLM-powered routers
    # LLM configuration
    max_tokens: int = 4096  # Maximum tokens for LLM responses
    # Execution metadata
    attempt: int = 1
    max_attempts: int = 3
    # Runtime logging (optional)
    runtime_logger: Any = None  # RuntimeLogger | None — uses Any to avoid import
    # Pause control (optional) - asyncio.Event for pause requests
    pause_event: Any = None  # asyncio.Event | None
    # Continuous conversation mode
    continuous_mode: bool = False  # True when graph has conversation_mode="continuous"
    inherited_conversation: Any = None  # NodeConversation | None (from prior node)
    cumulative_output_keys: list[str] = field(default_factory=list)  # All output keys from path
    # Event-triggered execution (no interactive user attached)
    event_triggered: bool = False
@dataclass
class NodeResult:
"""
The output of a node execution.
Contains:
- Success/failure status
- Output data
- State changes made
- Route decision (for routers)
"""
success: bool
output: dict[str, Any] = field(default_factory=dict)
error: str | None = None
# For routing decisions
next_node: str | None = None
route_reason: str | None = None
# Metadata
tokens_used: int = 0
latency_ms: int = 0
# Pydantic validation errors (if any)
validation_errors: list[str] = field(default_factory=list)
# Continuous conversation mode: return conversation for threading to next node
conversation: Any = None # NodeConversation | None
def to_summary(self, node_spec: Any = None) -> str:
"""
Generate a human-readable summary of this node's execution and output.
This is like toString() - it describes what the node produced in its current state.
Uses Haiku to intelligently summarize complex outputs.
"""
if not self.success:
return f"❌ Failed: {self.error}"
if not self.output:
return "✓ Completed (no output)"
# Use Haiku to generate intelligent summary
import os
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
# Fallback: simple key-value listing
parts = [f"✓ Completed with {len(self.output)} outputs:"]
for key, value in list(self.output.items())[:5]: # Limit to 5 keys
value_str = str(value)[:100]
if len(str(value)) > 100:
value_str += "..."
parts.append(f"{key}: {value_str}")
return "\n".join(parts)
# Use Haiku to generate intelligent summary
try:
import json
import anthropic
node_context = ""
if node_spec:
node_context = f"\nNode: {node_spec.name}\nPurpose: {node_spec.description}"
output_json = json.dumps(self.output, indent=2, default=str)[:2000]
prompt = (
f"Generate a 1-2 sentence human-readable summary of "
f"what this node produced.{node_context}\n\n"
f"Node output:\n{output_json}\n\n"
"Provide a concise, clear summary that a human can quickly "
"understand. Focus on the key information produced."
)
client = anthropic.Anthropic(api_key=api_key)
message = client.messages.create(
model="claude-3-5-haiku-20241022",
max_tokens=200,
messages=[{"role": "user", "content": prompt}],
)
summary = message.content[0].text.strip()
return f"{summary}"
except Exception:
# Fallback on error
parts = [f"✓ Completed with {len(self.output)} outputs:"]
for key, value in list(self.output.items())[:3]:
value_str = str(value)[:80]
if len(str(value)) > 80:
value_str += "..."
parts.append(f"{key}: {value_str}")
return "\n".join(parts)
class NodeProtocol(ABC):
    """
    Abstract base class that every node implementation derives from.

    To create a node:
    1. Subclass NodeProtocol
    2. Implement execute()
    3. Register with the executor

    Example:
        class CalculatorNode(NodeProtocol):
            async def execute(self, ctx: NodeContext) -> NodeResult:
                expression = ctx.input_data.get("expression")
                # Record decision
                decision_id = ctx.runtime.decide(
                    intent="Calculate expression",
                    options=[...],
                    chosen="evaluate",
                    reasoning="Direct evaluation"
                )
                # Do the work (use a safe evaluator; never eval() untrusted input)
                result = evaluate(expression)
                # Record outcome
                ctx.runtime.record_outcome(decision_id, success=True, result=result)
                return NodeResult(success=True, output={"result": result})
    """

    @abstractmethod
    async def execute(self, ctx: NodeContext) -> NodeResult:
        """
        Run this node's logic; subclasses must override.

        Args:
            ctx: NodeContext with everything needed

        Returns:
            NodeResult with output and status
        """

    def validate_input(self, ctx: NodeContext) -> list[str]:
        """
        Report which declared input keys are missing.

        A key from ``ctx.node_spec.input_keys`` counts as missing when it is
        absent from ``ctx.input_data`` and reading it from ``ctx.memory``
        yields None. Override to add custom validation.

        Returns:
            List of validation error messages (empty if valid)
        """
        return [
            f"Missing required input: {key}"
            for key in ctx.node_spec.input_keys
            if key not in ctx.input_data and ctx.memory.read(key) is None
        ]