# File: hive/core/framework/graph/edge.py
# Last modified: 2026-01-27 10:45:49 -08:00
# Original listing: 613 lines, 21 KiB, Python

"""
Edge Protocol - How nodes connect in a graph.
Edges define:
1. Source and target nodes
2. Conditions for traversal
3. Data mapping between nodes
Unlike traditional graph frameworks where edges are programmatic,
our edges can be created dynamically by a Builder agent based on the goal.
Edge Types:
- always: Always traverse after source completes
- on_success: Traverse only if source succeeds
- on_failure: Traverse only if source fails
- conditional: Traverse based on expression evaluation (SAFE SUBSET ONLY)
- llm_decide: Let LLM decide based on goal and context (goal-aware routing)
The llm_decide condition is particularly powerful for goal-driven agents,
allowing the LLM to evaluate whether proceeding along an edge makes sense
given the current goal, context, and execution state.
"""
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
from framework.graph.safe_eval import safe_eval
class EdgeCondition(str, Enum):
    """Traversal policy attached to an edge.

    Members are also plain ``str`` instances, so they compare equal to
    their string values.
    """

    # Followed unconditionally once the source node has completed.
    ALWAYS = "always"

    # Followed only when the source node reports success.
    ON_SUCCESS = "on_success"

    # Followed only when the source node reports failure.
    ON_FAILURE = "on_failure"

    # Followed when the edge's expression evaluates truthy.
    CONDITIONAL = "conditional"

    # Followed when an LLM, given the goal and context, decides to proceed.
    LLM_DECIDE = "llm_decide"
class EdgeSpec(BaseModel):
    """
    Specification for an edge between nodes.

    An edge names a source and a target node, a condition that decides
    whether it is traversed, and an optional mapping that renames source
    outputs into target inputs.

    Examples:
        # Simple success-based routing
        EdgeSpec(
            id="calc-to-format",
            source="calculator",
            target="formatter",
            condition=EdgeCondition.ON_SUCCESS,
            input_mapping={"result": "value_to_format"}
        )

        # Conditional routing based on output
        EdgeSpec(
            id="validate-to-retry",
            source="validator",
            target="retry_handler",
            condition=EdgeCondition.CONDITIONAL,
            condition_expr="output.confidence < 0.8",
        )

        # LLM-powered routing (goal-aware)
        EdgeSpec(
            id="search-to-filter",
            source="search_results",
            target="filter_results",
            condition=EdgeCondition.LLM_DECIDE,
            description="Only filter if results need refinement to meet goal",
        )
    """

    id: str
    source: str = Field(description="Source node ID")
    target: str = Field(description="Target node ID")

    # When to traverse
    condition: EdgeCondition = EdgeCondition.ALWAYS
    condition_expr: str | None = Field(
        default=None,
        description="Expression for CONDITIONAL edges, e.g., 'output.confidence > 0.8'",
    )

    # Data flow
    input_mapping: dict[str, str] = Field(
        default_factory=dict,
        description="Map source outputs to target inputs: {target_key: source_key}",
    )

    # Priority for multiple outgoing edges
    priority: int = Field(default=0, description="Higher priority edges are evaluated first")

    # Metadata
    description: str = ""

    model_config = {"extra": "allow"}

    def should_traverse(
        self,
        source_success: bool,
        source_output: dict[str, Any],
        memory: dict[str, Any],
        llm: Any | None = None,
        goal: Any | None = None,
        source_node_name: str | None = None,
        target_node_name: str | None = None,
    ) -> bool:
        """
        Determine if this edge should be traversed.

        Args:
            source_success: Whether the source node succeeded
            source_output: Output from the source node
            memory: Current shared memory state
            llm: LLM provider for LLM_DECIDE edges
            goal: Goal object for LLM_DECIDE edges
            source_node_name: Name of source node (for LLM context)
            target_node_name: Name of target node (for LLM context)

        Returns:
            True if the edge should be traversed. An unrecognised
            condition yields False.
        """
        if self.condition == EdgeCondition.ALWAYS:
            return True

        if self.condition == EdgeCondition.ON_SUCCESS:
            return source_success

        if self.condition == EdgeCondition.ON_FAILURE:
            return not source_success

        if self.condition == EdgeCondition.CONDITIONAL:
            return self._evaluate_condition(source_output, memory)

        if self.condition == EdgeCondition.LLM_DECIDE:
            if llm is None or goal is None:
                # Fallback to ON_SUCCESS if LLM not available
                return source_success
            return self._llm_decide(
                llm=llm,
                goal=goal,
                source_success=source_success,
                source_output=source_output,
                memory=memory,
                source_node_name=source_node_name,
                target_node_name=target_node_name,
            )

        return False

    def _evaluate_condition(
        self,
        output: dict[str, Any],
        memory: dict[str, Any],
    ) -> bool:
        """
        Evaluate ``condition_expr`` against the source output and memory.

        The expression sees ``output``, ``memory``, ``result`` (shorthand
        for ``output.get("result")``), lowercase ``true``/``false``, and
        every memory key as a bare name.

        Args:
            output: Output from the source node
            memory: Current shared memory state

        Returns:
            Truthiness of the expression. True when no expression is set;
            False when evaluation raises (the failure is logged).
        """
        if not self.condition_expr:
            return True

        # Build evaluation context
        context: dict[str, Any] = {
            "output": output,
            "memory": memory,
            "result": output.get("result"),
            "true": True,  # Allow lowercase true/false in conditions
            "false": False,
        }
        # Expose memory keys as bare names for easier access in conditions,
        # but never let one of them replace a structural name: a memory key
        # called "output" would otherwise hijack `output.<field>` lookups.
        # "result" is deliberately left overridable, as it always has been.
        protected = {"output", "memory", "true", "false"}
        context.update((key, value) for key, value in memory.items() if key not in protected)

        try:
            # Safe evaluation using AST-based whitelist
            return bool(safe_eval(self.condition_expr, context))
        except Exception as e:
            # Best-effort: a broken expression must not crash graph execution,
            # so log enough to debug it and treat the edge as not taken.
            import logging

            logger = logging.getLogger(__name__)
            logger.warning(" ⚠ Condition evaluation failed: %s", self.condition_expr)
            logger.warning(" Error: %s", e)
            logger.warning(" Available context keys: %s", list(context.keys()))
            return False

    def _llm_decide(
        self,
        llm: Any,
        goal: Any,
        source_success: bool,
        source_output: dict[str, Any],
        memory: dict[str, Any],
        source_node_name: str | None,
        target_node_name: str | None,
    ) -> bool:
        """
        Use LLM to decide if this edge should be traversed.

        The LLM evaluates whether proceeding to the target node
        is the best next step toward achieving the goal.

        Args:
            llm: Provider exposing ``complete(messages=, system=, max_tokens=)``
                whose result has a ``content`` attribute
            goal: Object exposing ``name`` and ``description``
            source_success: Whether the source node succeeded
            source_output: Output from the source node
            memory: Current shared memory state (only the first five
                entries, each truncated to 100 characters, are sent)
            source_node_name: Name of source node (for LLM context)
            target_node_name: Name of target node (for LLM context)

        Returns:
            The LLM's ``proceed`` verdict as a real bool. Falls back to
            ``source_success`` when the call fails or the response
            contains no JSON object.
        """
        import json
        import logging
        import re

        logger = logging.getLogger(__name__)

        # Build context for LLM
        prompt = f"""You are evaluating whether to proceed along an edge in an agent workflow.
**Goal**: {goal.name}
{goal.description}
**Current State**:
- Just completed: {source_node_name or "unknown node"}
- Success: {source_success}
- Output: {json.dumps(source_output, default=str)}
**Decision**:
Should we proceed to: {target_node_name or self.target}?
Edge description: {self.description or "No description"}
**Context from memory**:
{json.dumps({k: str(v)[:100] for k, v in list(memory.items())[:5]}, indent=2)}
Evaluate whether proceeding to this next node is the right step toward achieving the goal.
Consider:
1. Does the current output suggest we should proceed?
2. Is this the logical next step given the goal?
3. Are there any issues that would make proceeding unwise?
Respond with ONLY a JSON object:
{{"proceed": true/false, "reasoning": "brief explanation"}}"""

        try:
            response = llm.complete(
                messages=[{"role": "user", "content": prompt}],
                system="You are a routing agent. Respond with JSON only.",
                max_tokens=150,
            )

            # Parse response: take the first flat {...} object in the reply
            json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL)
            if json_match:
                data = json.loads(json_match.group())
                raw_proceed = data.get("proceed", False)
                # Models sometimes quote the boolean; bool("false") is True,
                # so interpret strings explicitly instead of by truthiness.
                if isinstance(raw_proceed, str):
                    proceed = raw_proceed.strip().lower() in ("true", "yes", "1")
                else:
                    proceed = bool(raw_proceed)
                reasoning = data.get("reasoning", "")

                # Log the decision
                logger.info(" 🤔 LLM routing decision: %s", "PROCEED" if proceed else "SKIP")
                logger.info(" Reason: %s", reasoning)
                return proceed

            logger.warning(" ⚠ LLM routing response had no JSON object, defaulting to on_success")
        except Exception as e:
            # Fallback: proceed on success. Deliberately broad, because the
            # provider and the JSON parsing can fail in ways not visible here.
            logger.warning(" ⚠ LLM routing failed, defaulting to on_success: %s", e)
            return source_success

        return source_success

    def map_inputs(
        self,
        source_output: dict[str, Any],
        memory: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Map source outputs to target inputs.

        Args:
            source_output: Output from source node
            memory: Current shared memory

        Returns:
            Input dict for target node. With no ``input_mapping`` this is
            a shallow copy of ``source_output``. Otherwise each target key
            is filled from the source output, then from memory; keys found
            in neither are omitted.
        """
        if not self.input_mapping:
            # Default: pass through all outputs
            return dict(source_output)

        result: dict[str, Any] = {}
        for target_key, source_key in self.input_mapping.items():
            # Try source output first, then memory
            if source_key in source_output:
                result[target_key] = source_output[source_key]
            elif source_key in memory:
                result[target_key] = memory[source_key]
        return result
class AsyncEntryPointSpec(BaseModel):
    """
    Declarative description of one asynchronous entry point.

    Intended for AgentRuntime-driven agents that expose several entry
    points and handle concurrent execution streams (e.g., webhook + API
    handlers).

    Example:
        AsyncEntryPointSpec(
            id="webhook",
            name="Zendesk Webhook Handler",
            entry_node="process-webhook",
            trigger_type="webhook",
            isolation_level="shared",
        )
    """

    # Identity
    id: str = Field(
        description="Unique identifier for this entry point",
    )
    name: str = Field(
        description="Human-readable name",
    )

    # Where execution begins
    entry_node: str = Field(
        description="Node ID to start execution from",
    )

    # Triggering
    trigger_type: str = Field(
        default="manual",
        description="How this entry point is triggered: webhook, api, timer, event, manual",
    )
    trigger_config: dict[str, Any] = Field(
        default_factory=dict,
        description="Trigger-specific configuration (e.g., webhook URL, timer interval)",
    )

    # Runtime behaviour
    isolation_level: str = Field(
        default="shared",
        description="State isolation: isolated, shared, or synchronized",
    )
    priority: int = Field(
        default=0,
        description="Execution priority (higher = more priority)",
    )
    max_concurrent: int = Field(
        default=10,
        description="Maximum concurrent executions for this entry point",
    )

    # Unknown extra fields are accepted rather than rejected.
    model_config = {"extra": "allow"}
class GraphSpec(BaseModel):
    """
    Complete specification of an agent graph.

    Contains all nodes, edges, and metadata needed to execute.

    For single-entry-point agents (traditional pattern):
        GraphSpec(
            id="calculator-graph",
            goal_id="calc-001",
            entry_node="input_parser",
            terminal_nodes=["output_formatter", "error_handler"],
            nodes=[...],
            edges=[...],
        )

    For multi-entry-point agents (concurrent streams):
        GraphSpec(
            id="support-agent-graph",
            goal_id="support-001",
            entry_node="process-webhook",  # Default entry
            async_entry_points=[
                AsyncEntryPointSpec(
                    id="webhook",
                    name="Zendesk Webhook",
                    entry_node="process-webhook",
                    trigger_type="webhook",
                ),
                AsyncEntryPointSpec(
                    id="api",
                    name="API Handler",
                    entry_node="process-request",
                    trigger_type="api",
                ),
            ],
            nodes=[...],
            edges=[...],
        )
    """

    id: str
    goal_id: str
    version: str = "1.0.0"

    # Graph structure
    entry_node: str = Field(description="ID of the first node to execute")
    entry_points: dict[str, str] = Field(
        default_factory=dict,
        description="Named entry points for resuming execution. Format: {name: node_id}",
    )
    async_entry_points: list[AsyncEntryPointSpec] = Field(
        default_factory=list,
        description=(
            "Asynchronous entry points for concurrent execution streams (used with AgentRuntime)"
        ),
    )
    terminal_nodes: list[str] = Field(
        default_factory=list, description="IDs of nodes that end execution"
    )
    pause_nodes: list[str] = Field(
        default_factory=list, description="IDs of nodes that pause execution for HITL input"
    )

    # Components
    nodes: list[Any] = Field(  # NodeSpec, but avoiding circular import
        default_factory=list, description="All node specifications"
    )
    edges: list[EdgeSpec] = Field(default_factory=list, description="All edge specifications")

    # Shared memory keys
    memory_keys: list[str] = Field(
        default_factory=list, description="Keys available in shared memory"
    )

    # Default LLM settings
    default_model: str = "claude-haiku-4-5-20251001"
    max_tokens: int = 1024

    # Cleanup LLM for JSON extraction fallback (fast/cheap model preferred)
    # If not set, uses CEREBRAS_API_KEY -> cerebras/llama-3.3-70b or
    # ANTHROPIC_API_KEY -> claude-3-5-haiku as fallback
    cleanup_llm_model: str | None = None

    # Execution limits
    max_steps: int = Field(default=100, description="Maximum node executions before timeout")
    max_retries_per_node: int = 3

    # Metadata
    description: str = ""
    created_by: str = ""  # "human" or "builder_agent"

    model_config = {"extra": "allow"}

    def get_node(self, node_id: str) -> Any | None:
        """Get a node by ID, or None if no node has that ID."""
        for node in self.nodes:
            if node.id == node_id:
                return node
        return None

    def has_async_entry_points(self) -> bool:
        """Check if this graph uses async entry points (multi-stream execution)."""
        return len(self.async_entry_points) > 0

    def get_async_entry_point(self, entry_point_id: str) -> AsyncEntryPointSpec | None:
        """Get an async entry point by ID, or None if there is no match."""
        for ep in self.async_entry_points:
            if ep.id == entry_point_id:
                return ep
        return None

    def get_outgoing_edges(self, node_id: str) -> list[EdgeSpec]:
        """Get all edges leaving a node, highest priority first.

        The sort is stable, so edges of equal priority keep their
        declaration order.
        """
        edges = [e for e in self.edges if e.source == node_id]
        return sorted(edges, key=lambda e: -e.priority)

    def get_incoming_edges(self, node_id: str) -> list[EdgeSpec]:
        """Get all edges entering a node, in declaration order."""
        return [e for e in self.edges if e.target == node_id]

    def detect_fan_out_nodes(self) -> dict[str, list[str]]:
        """
        Detect nodes that fan-out to multiple targets.

        A fan-out occurs when a node has multiple outgoing edges with the same
        condition (typically ON_SUCCESS) that should execute in parallel.
        Only ON_SUCCESS edges are considered here.

        Returns:
            Dict mapping source_node_id -> list of parallel target_node_ids
        """
        fan_outs: dict[str, list[str]] = {}
        for node in self.nodes:
            outgoing = self.get_outgoing_edges(node.id)
            # Fan-out: multiple edges with ON_SUCCESS condition
            success_edges = [e for e in outgoing if e.condition == EdgeCondition.ON_SUCCESS]
            if len(success_edges) > 1:
                fan_outs[node.id] = [e.target for e in success_edges]
        return fan_outs

    def detect_fan_in_nodes(self) -> dict[str, list[str]]:
        """
        Detect nodes that receive from multiple sources (fan-in / convergence).

        A fan-in occurs when a node has multiple incoming edges, meaning
        it should wait for all predecessor branches to complete.

        Returns:
            Dict mapping target_node_id -> list of source_node_ids
        """
        fan_ins: dict[str, list[str]] = {}
        for node in self.nodes:
            incoming = self.get_incoming_edges(node.id)
            if len(incoming) > 1:
                fan_ins[node.id] = [e.source for e in incoming]
        return fan_ins

    def get_entry_point(self, session_state: dict | None = None) -> str:
        """
        Get the appropriate entry point based on session state.

        Resolution order: a ``"<paused_at>_resume"`` named entry point when
        ``paused_at`` is a pause node; then ``resume_from`` as a named entry
        point or a literal node ID; otherwise the main entry node.

        Args:
            session_state: Optional session state with 'paused_at' or 'resume_from' key

        Returns:
            Node ID to start execution from
        """
        if not session_state:
            return self.entry_node

        # Check if resuming from a pause node
        paused_at = session_state.get("paused_at")
        if paused_at and paused_at in self.pause_nodes:
            # Look for a resume entry point
            resume_key = f"{paused_at}_resume"
            if resume_key in self.entry_points:
                return self.entry_points[resume_key]

        # Check for explicit resume_from
        resume_from = session_state.get("resume_from")
        if resume_from:
            if resume_from in self.entry_points:
                return self.entry_points[resume_from]
            if any(n.id == resume_from for n in self.nodes):
                return resume_from

        # Default to main entry
        return self.entry_node

    # NOTE(review): pydantic's BaseModel also exposes a (deprecated) `validate`
    # classmethod; this instance method appears to shadow it - confirm nothing
    # relies on the pydantic one.
    def validate(self) -> list[str]:
        """
        Validate the graph structure.

        Checks that the entry node, async entry points, terminal nodes,
        pause nodes, named entry points and edge endpoints all reference
        existing nodes; that async entry points have unique IDs and known
        isolation levels / trigger types; and that every non-pause node is
        reachable from some entry point.

        Returns:
            Human-readable error messages; empty when the graph is valid.
        """
        errors: list[str] = []

        # One pass over the nodes instead of a linear scan per lookup.
        node_ids = {node.id for node in self.nodes}

        # Check entry node exists
        if self.entry_node not in node_ids:
            errors.append(f"Entry node '{self.entry_node}' not found")

        # Allowed values are sorted when reported so messages are identical
        # from run to run (set iteration order is not).
        valid_isolation = {"isolated", "shared", "synchronized"}
        valid_triggers = {"webhook", "api", "timer", "event", "manual"}

        # Check async entry points
        seen_entry_ids: set[str] = set()
        for entry_point in self.async_entry_points:
            # Check for duplicate IDs
            if entry_point.id in seen_entry_ids:
                errors.append(f"Duplicate async entry point ID: '{entry_point.id}'")
            seen_entry_ids.add(entry_point.id)

            # Check entry node exists
            if entry_point.entry_node not in node_ids:
                errors.append(
                    f"Async entry point '{entry_point.id}' references "
                    f"missing node '{entry_point.entry_node}'"
                )

            # Validate isolation level
            if entry_point.isolation_level not in valid_isolation:
                errors.append(
                    f"Async entry point '{entry_point.id}' has invalid isolation_level "
                    f"'{entry_point.isolation_level}'. Valid: {sorted(valid_isolation)}"
                )

            # Validate trigger type
            if entry_point.trigger_type not in valid_triggers:
                errors.append(
                    f"Async entry point '{entry_point.id}' has invalid trigger_type "
                    f"'{entry_point.trigger_type}'. Valid: {sorted(valid_triggers)}"
                )

        # Check terminal nodes exist
        for term in self.terminal_nodes:
            if term not in node_ids:
                errors.append(f"Terminal node '{term}' not found")

        # Check pause nodes exist (same dangling-reference rule as terminals)
        for pause in self.pause_nodes:
            if pause not in node_ids:
                errors.append(f"Pause node '{pause}' not found")

        # Check named entry points target real nodes; get_entry_point() would
        # otherwise hand back an ID that cannot be executed.
        for entry_name, entry_target in self.entry_points.items():
            if entry_target not in node_ids:
                errors.append(
                    f"Entry point '{entry_name}' references missing node '{entry_target}'"
                )

        # Check edge references
        for edge in self.edges:
            if edge.source not in node_ids:
                errors.append(f"Edge '{edge.id}' references missing source '{edge.source}'")
            if edge.target not in node_ids:
                errors.append(f"Edge '{edge.id}' references missing target '{edge.target}'")

        # Check for unreachable nodes
        # Start with main entry node and all entry points (for pause/resume
        # architecture); async entry points are valid starting points too.
        to_visit = [self.entry_node]
        to_visit.extend(self.entry_points.values())
        to_visit.extend(ep.entry_node for ep in self.async_entry_points)

        # Edge priority is irrelevant for reachability, so use a plain
        # adjacency map rather than re-sorting outgoing edges per node.
        adjacency: dict[str, list[str]] = {}
        for edge in self.edges:
            adjacency.setdefault(edge.source, []).append(edge.target)

        # Traverse from all entry points
        reachable: set[str] = set()
        while to_visit:
            current = to_visit.pop()
            if current in reachable:
                continue
            reachable.add(current)
            to_visit.extend(adjacency.get(current, ()))

        # Pause nodes are exempt (pause/resume architecture). Entry-point
        # targets need no exemption: they were seeded into the traversal
        # above and are therefore always in `reachable`.
        pause_node_ids = set(self.pause_nodes)
        for node in self.nodes:
            if node.id in reachable or node.id in pause_node_ids:
                continue
            errors.append(f"Node '{node.id}' is unreachable from entry")

        return errors