Files
hive/core/framework/graph/edge.py
T
2026-01-23 17:21:59 -08:00

442 lines
14 KiB
Python

"""
Edge Protocol - How nodes connect in a graph.
Edges define:
1. Source and target nodes
2. Conditions for traversal
3. Data mapping between nodes
Unlike traditional graph frameworks where edges are programmatic,
our edges can be created dynamically by a Builder agent based on the goal.
Edge Types:
- always: Always traverse after source completes
- on_success: Traverse only if source succeeds
- on_failure: Traverse only if source fails
- conditional: Traverse based on expression evaluation
- llm_decide: Let LLM decide based on goal and context (goal-aware routing)
The llm_decide condition is particularly powerful for goal-driven agents,
allowing the LLM to evaluate whether proceeding along an edge makes sense
given the current goal, context, and execution state.
"""
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class EdgeCondition(str, Enum):
    """Enumerates the rules that decide when an edge is traversed."""

    ALWAYS = "always"            # Traverse unconditionally once the source completes
    ON_SUCCESS = "on_success"    # Traverse only when the source node succeeded
    ON_FAILURE = "on_failure"    # Traverse only when the source node failed
    CONDITIONAL = "conditional"  # Traverse when condition_expr evaluates truthy
    LLM_DECIDE = "llm_decide"    # Defer the routing decision to an LLM, given the goal
class EdgeSpec(BaseModel):
    """
    Specification for an edge between nodes.

    An edge connects a source node to a target node, declares when it should
    be traversed (see EdgeCondition), and optionally remaps the source's
    outputs onto the target's inputs.

    Examples:
        # Simple success-based routing
        EdgeSpec(
            id="calc-to-format",
            source="calculator",
            target="formatter",
            condition=EdgeCondition.ON_SUCCESS,
            input_mapping={"result": "value_to_format"},
        )

        # Conditional routing based on output
        EdgeSpec(
            id="validate-to-retry",
            source="validator",
            target="retry_handler",
            condition=EdgeCondition.CONDITIONAL,
            condition_expr="output.confidence < 0.8",
        )

        # LLM-powered routing (goal-aware)
        EdgeSpec(
            id="search-to-filter",
            source="search_results",
            target="filter_results",
            condition=EdgeCondition.LLM_DECIDE,
            description="Only filter if results need refinement to meet goal",
        )
    """

    id: str
    source: str = Field(description="Source node ID")
    target: str = Field(description="Target node ID")

    # When to traverse
    condition: EdgeCondition = EdgeCondition.ALWAYS
    condition_expr: str | None = Field(
        default=None,
        description="Expression for CONDITIONAL edges, e.g., 'output.confidence > 0.8'",
    )

    # Data flow
    input_mapping: dict[str, str] = Field(
        default_factory=dict,
        description="Map source outputs to target inputs: {target_key: source_key}",
    )

    # Priority for multiple outgoing edges
    priority: int = Field(default=0, description="Higher priority edges are evaluated first")

    # Metadata
    description: str = ""

    model_config = {"extra": "allow"}

    def should_traverse(
        self,
        source_success: bool,
        source_output: dict[str, Any],
        memory: dict[str, Any],
        llm: Any | None = None,
        goal: Any | None = None,
        source_node_name: str | None = None,
        target_node_name: str | None = None,
    ) -> bool:
        """
        Determine if this edge should be traversed.

        Args:
            source_success: Whether the source node succeeded
            source_output: Output from the source node
            memory: Current shared memory state
            llm: LLM provider for LLM_DECIDE edges
            goal: Goal object for LLM_DECIDE edges
            source_node_name: Name of source node (for LLM context)
            target_node_name: Name of target node (for LLM context)

        Returns:
            True if the edge should be traversed
        """
        if self.condition == EdgeCondition.ALWAYS:
            return True
        if self.condition == EdgeCondition.ON_SUCCESS:
            return source_success
        if self.condition == EdgeCondition.ON_FAILURE:
            return not source_success
        if self.condition == EdgeCondition.CONDITIONAL:
            return self._evaluate_condition(source_output, memory)
        if self.condition == EdgeCondition.LLM_DECIDE:
            if llm is None or goal is None:
                # Fallback to ON_SUCCESS semantics if LLM not available
                return source_success
            return self._llm_decide(
                llm=llm,
                goal=goal,
                source_success=source_success,
                source_output=source_output,
                memory=memory,
                source_node_name=source_node_name,
                target_node_name=target_node_name,
            )
        # Unknown condition value (possible via "extra": "allow" round-trips):
        # never traverse.
        return False

    def _evaluate_condition(
        self,
        output: dict[str, Any],
        memory: dict[str, Any],
    ) -> bool:
        """
        Evaluate a conditional expression against the source output and memory.

        An empty/missing expression is treated as always-true. Evaluation
        errors are logged and treated as False (edge not traversed).
        """
        if not self.condition_expr:
            return True

        # Build evaluation context. Memory keys are unpacked first so the
        # reserved names below ("output", "memory", "result", "true",
        # "false") always win — previously a memory key named e.g. "output"
        # would silently shadow the node output in the expression.
        context = {
            **memory,  # Unpack memory keys directly for easier access in conditions
            "output": output,
            "memory": memory,
            "result": output.get("result"),
            "true": True,  # Allow lowercase true/false in conditions
            "false": False,
        }

        try:
            # SECURITY: eval() is unsafe on untrusted expressions even with
            # empty __builtins__ (sandbox escapes via attribute access are
            # well known). condition_expr must come from a trusted builder;
            # replace with a proper expression evaluator before accepting
            # user-supplied expressions.
            return bool(eval(self.condition_expr, {"__builtins__": {}}, context))
        except Exception as e:
            # Log the error for debugging, then fail closed (do not traverse).
            import logging

            logger = logging.getLogger(__name__)
            logger.warning(f" ⚠ Condition evaluation failed: {self.condition_expr}")
            logger.warning(f"   Error: {e}")
            logger.warning(f"   Available context keys: {list(context.keys())}")
            return False

    def _llm_decide(
        self,
        llm: Any,
        goal: Any,
        source_success: bool,
        source_output: dict[str, Any],
        memory: dict[str, Any],
        source_node_name: str | None,
        target_node_name: str | None,
    ) -> bool:
        """
        Use LLM to decide if this edge should be traversed.

        The LLM evaluates whether proceeding to the target node is the best
        next step toward achieving the goal. Any failure (LLM error or an
        unparseable response) falls back to ON_SUCCESS semantics.
        """
        import json
        import logging
        import re

        logger = logging.getLogger(__name__)

        # Build context for LLM
        prompt = f"""You are evaluating whether to proceed along an edge in an agent workflow.
**Goal**: {goal.name}
{goal.description}
**Current State**:
- Just completed: {source_node_name or "unknown node"}
- Success: {source_success}
- Output: {json.dumps(source_output, default=str)}
**Decision**:
Should we proceed to: {target_node_name or self.target}?
Edge description: {self.description or "No description"}
**Context from memory**:
{json.dumps({k: str(v)[:100] for k, v in list(memory.items())[:5]}, indent=2)}
Evaluate whether proceeding to this next node is the right step toward achieving the goal.
Consider:
1. Does the current output suggest we should proceed?
2. Is this the logical next step given the goal?
3. Are there any issues that would make proceeding unwise?
Respond with ONLY a JSON object:
{{"proceed": true/false, "reasoning": "brief explanation"}}"""

        try:
            response = llm.complete(
                messages=[{"role": "user", "content": prompt}],
                system="You are a routing agent. Respond with JSON only.",
                max_tokens=150,
            )

            # Extract the first {...} object from the response text.
            json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL)
            if json_match:
                data = json.loads(json_match.group())
                proceed = data.get("proceed", False)
                reasoning = data.get("reasoning", "")
                logger.info(f" 🤔 LLM routing decision: {'PROCEED' if proceed else 'SKIP'}")
                logger.info(f"    Reason: {reasoning}")
                return proceed
            # No JSON object in the response: previously this fell through
            # silently — now it is logged before falling back to on_success.
            logger.warning(" ⚠ LLM routing response had no JSON, defaulting to on_success")
        except Exception as e:
            # Fallback: proceed on success
            logger.warning(f" ⚠ LLM routing failed, defaulting to on_success: {e}")

        return source_success

    def map_inputs(
        self,
        source_output: dict[str, Any],
        memory: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Map source outputs to target inputs.

        Args:
            source_output: Output from source node
            memory: Current shared memory

        Returns:
            Input dict for target node. With no input_mapping configured,
            a shallow copy of source_output is passed through unchanged.
        """
        if not self.input_mapping:
            # Default: pass through all outputs
            return dict(source_output)

        result = {}
        for target_key, source_key in self.input_mapping.items():
            # Try source output first, then memory; keys found in neither
            # are silently omitted from the target's inputs.
            if source_key in source_output:
                result[target_key] = source_output[source_key]
            elif source_key in memory:
                result[target_key] = memory[source_key]
        return result
class GraphSpec(BaseModel):
    """
    Complete specification of an agent graph.

    Contains all nodes, edges, and metadata needed to execute.

    Example:
        GraphSpec(
            id="calculator-graph",
            goal_id="calc-001",
            entry_node="input_parser",
            terminal_nodes=["output_formatter", "error_handler"],
            nodes=[...],
            edges=[...],
        )
    """

    id: str
    goal_id: str
    version: str = "1.0.0"

    # Graph structure
    entry_node: str = Field(description="ID of the first node to execute")
    entry_points: dict[str, str] = Field(
        default_factory=dict,
        description="Named entry points for resuming execution. Format: {name: node_id}",
    )
    terminal_nodes: list[str] = Field(default_factory=list, description="IDs of nodes that end execution")
    pause_nodes: list[str] = Field(default_factory=list, description="IDs of nodes that pause execution for HITL input")

    # Components
    nodes: list[Any] = Field(  # NodeSpec, but avoiding circular import
        default_factory=list, description="All node specifications"
    )
    edges: list[EdgeSpec] = Field(default_factory=list, description="All edge specifications")

    # Shared memory keys
    memory_keys: list[str] = Field(default_factory=list, description="Keys available in shared memory")

    # Default LLM settings
    default_model: str = "claude-haiku-4-5-20251001"
    max_tokens: int = 1024

    # Execution limits
    max_steps: int = Field(default=100, description="Maximum node executions before timeout")
    max_retries_per_node: int = 3

    # Metadata
    description: str = ""
    created_by: str = ""  # "human" or "builder_agent"

    model_config = {"extra": "allow"}

    def get_node(self, node_id: str) -> Any | None:
        """Get a node by ID, or None if no node has that ID."""
        for node in self.nodes:
            if node.id == node_id:
                return node
        return None

    def get_outgoing_edges(self, node_id: str) -> list[EdgeSpec]:
        """Get all edges leaving a node, sorted by descending priority."""
        edges = [e for e in self.edges if e.source == node_id]
        return sorted(edges, key=lambda e: -e.priority)

    def get_incoming_edges(self, node_id: str) -> list[EdgeSpec]:
        """Get all edges entering a node."""
        return [e for e in self.edges if e.target == node_id]

    def get_entry_point(self, session_state: dict | None = None) -> str:
        """
        Get the appropriate entry point based on session state.

        Args:
            session_state: Optional session state with 'paused_at' or 'resume_from' key

        Returns:
            Node ID to start execution from
        """
        if not session_state:
            return self.entry_node

        # Check if resuming from a pause node: a pause node named "foo"
        # resumes at the "foo_resume" entry point when one is defined.
        paused_at = session_state.get("paused_at")
        if paused_at and paused_at in self.pause_nodes:
            resume_key = f"{paused_at}_resume"
            if resume_key in self.entry_points:
                return self.entry_points[resume_key]

        # Check for explicit resume_from: named entry point first, then a
        # raw node ID.
        resume_from = session_state.get("resume_from")
        if resume_from:
            if resume_from in self.entry_points:
                return self.entry_points[resume_from]
            elif resume_from in [n.id for n in self.nodes]:
                return resume_from

        # Default to main entry
        return self.entry_node

    def validate(self) -> list[str]:
        """
        Validate the graph structure.

        Returns a list of human-readable error strings; an empty list means
        the graph passed all structural checks.

        NOTE(review): this shadows pydantic's BaseModel.validate classmethod;
        callers here use it as an instance method — confirm no code relies on
        the pydantic classmethod on GraphSpec.
        """
        errors: list[str] = []

        # Check entry node exists
        if not self.get_node(self.entry_node):
            errors.append(f"Entry node '{self.entry_node}' not found")

        # Check terminal nodes exist
        for term in self.terminal_nodes:
            if not self.get_node(term):
                errors.append(f"Terminal node '{term}' not found")

        # Check pause nodes and named entry points reference real nodes —
        # both are consumed by get_entry_point() and seeded into the
        # reachability walk below, so dangling IDs were previously silent.
        for pause in self.pause_nodes:
            if not self.get_node(pause):
                errors.append(f"Pause node '{pause}' not found")
        for name, node_id in self.entry_points.items():
            if not self.get_node(node_id):
                errors.append(f"Entry point '{name}' references missing node '{node_id}'")

        # Check edge references
        for edge in self.edges:
            if not self.get_node(edge.source):
                errors.append(f"Edge '{edge.id}' references missing source '{edge.source}'")
            if not self.get_node(edge.target):
                errors.append(f"Edge '{edge.id}' references missing target '{edge.target}'")

        # Check for unreachable nodes.
        # Start with main entry node and all entry points (for pause/resume
        # architecture: entry points are reachable by definition).
        reachable: set[str] = set()
        to_visit = [self.entry_node]
        for entry_point_node in self.entry_points.values():
            to_visit.append(entry_point_node)

        # Depth-first traversal from all starting points
        while to_visit:
            current = to_visit.pop()
            if current in reachable:
                continue
            reachable.add(current)
            for edge in self.get_outgoing_edges(current):
                to_visit.append(edge.target)

        for node in self.nodes:
            if node.id not in reachable:
                # Skip this error if the node is a pause node or an entry
                # point target (pause/resume architecture makes these
                # reachable via session state)
                if node.id in self.pause_nodes or node.id in self.entry_points.values():
                    continue
                errors.append(f"Node '{node.id}' is unreachable from entry")

        return errors