Merge pull request #1428 from TimothyZhang7/feature/parallel-fanout
feat: parallel execution framework
@@ -412,6 +412,11 @@ class GraphSpec(BaseModel):
     default_model: str = "claude-haiku-4-5-20251001"
     max_tokens: int = 1024
 
+    # Cleanup LLM for JSON extraction fallback (fast/cheap model preferred)
+    # If not set, uses CEREBRAS_API_KEY -> cerebras/llama-3.3-70b or
+    # ANTHROPIC_API_KEY -> claude-3-5-haiku as fallback
+    cleanup_llm_model: str | None = None
+
     # Execution limits
     max_steps: int = Field(default=100, description="Maximum node executions before timeout")
     max_retries_per_node: int = 3
@@ -449,6 +454,44 @@ class GraphSpec(BaseModel):
         """Get all edges entering a node."""
         return [e for e in self.edges if e.target == node_id]
 
+    def detect_fan_out_nodes(self) -> dict[str, list[str]]:
+        """
+        Detect nodes that fan-out to multiple targets.
+
+        A fan-out occurs when a node has multiple outgoing edges with the same
+        condition (typically ON_SUCCESS) that should execute in parallel.
+
+        Returns:
+            Dict mapping source_node_id -> list of parallel target_node_ids
+        """
+        fan_outs: dict[str, list[str]] = {}
+        for node in self.nodes:
+            outgoing = self.get_outgoing_edges(node.id)
+            # Fan-out: multiple edges with ON_SUCCESS condition
+            success_edges = [
+                e for e in outgoing if e.condition == EdgeCondition.ON_SUCCESS
+            ]
+            if len(success_edges) > 1:
+                fan_outs[node.id] = [e.target for e in success_edges]
+        return fan_outs
+
+    def detect_fan_in_nodes(self) -> dict[str, list[str]]:
+        """
+        Detect nodes that receive from multiple sources (fan-in / convergence).
+
+        A fan-in occurs when a node has multiple incoming edges, meaning
+        it should wait for all predecessor branches to complete.
+
+        Returns:
+            Dict mapping target_node_id -> list of source_node_ids
+        """
+        fan_ins: dict[str, list[str]] = {}
+        for node in self.nodes:
+            incoming = self.get_incoming_edges(node.id)
+            if len(incoming) > 1:
+                fan_ins[node.id] = [e.source for e in incoming]
+        return fan_ins
+
     def get_entry_point(self, session_state: dict | None = None) -> str:
         """
         Get the appropriate entry point based on session state.
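A quick sketch of how the two detectors behave on a diamond-shaped graph. The node/edge construction below is illustrative (`nodes=[...]` elided; exact `EdgeSpec` field names and the `EdgeCondition` import location are assumptions):

```python
# Hypothetical diamond graph: plan fans out to search and summarize,
# and both converge on merge.
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec

spec = GraphSpec(
    nodes=[...],  # plan, search, summarize, merge (elided)
    edges=[
        EdgeSpec(source="plan", target="search", condition=EdgeCondition.ON_SUCCESS),
        EdgeSpec(source="plan", target="summarize", condition=EdgeCondition.ON_SUCCESS),
        EdgeSpec(source="search", target="merge", condition=EdgeCondition.ON_SUCCESS),
        EdgeSpec(source="summarize", target="merge", condition=EdgeCondition.ON_SUCCESS),
    ],
)

# plan has two ON_SUCCESS edges -> fan-out; merge has two incoming edges -> fan-in
assert spec.detect_fan_out_nodes() == {"plan": ["search", "summarize"]}
assert spec.detect_fan_in_nodes() == {"merge": ["search", "summarize"]}
```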
@@ -15,7 +15,7 @@ from collections.abc import Callable
 from dataclasses import dataclass, field
 from typing import Any
 
-from framework.graph.edge import GraphSpec
+from framework.graph.edge import EdgeSpec, GraphSpec
 from framework.graph.goal import Goal
 from framework.graph.node import (
     FunctionNode,
@@ -48,6 +48,35 @@ class ExecutionResult:
     session_state: dict[str, Any] = field(default_factory=dict)  # State to resume from
 
 
+@dataclass
+class ParallelBranch:
+    """Tracks a single branch in parallel fan-out execution."""
+
+    branch_id: str
+    node_id: str
+    edge: EdgeSpec
+    result: "NodeResult | None" = None
+    status: str = "pending"  # pending, running, completed, failed
+    retry_count: int = 0
+    error: str | None = None
+
+
+@dataclass
+class ParallelExecutionConfig:
+    """Configuration for parallel execution behavior."""
+
+    # Error handling: "fail_all" cancels all on first failure,
+    # "continue_others" lets remaining branches complete,
+    # "wait_all" waits for all and reports all failures
+    on_branch_failure: str = "fail_all"
+
+    # Memory conflict handling when branches write same key
+    memory_conflict_strategy: str = "last_wins"  # "last_wins", "first_wins", "error"
+
+    # Timeout per branch in seconds
+    branch_timeout_seconds: float = 300.0
+
+
 class GraphExecutor:
     """
     Executes agent graphs.
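For reference, a configuration that tolerates partial failure might look like this (a sketch using only the three fields defined on the dataclass above):

```python
config = ParallelExecutionConfig(
    on_branch_failure="continue_others",   # keep going if one branch fails
    memory_conflict_strategy="last_wins",  # latest writer wins on shared keys
    branch_timeout_seconds=120.0,          # tighter per-branch budget
)
```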
@@ -76,6 +105,8 @@ class GraphExecutor:
         node_registry: dict[str, NodeProtocol] | None = None,
         approval_callback: Callable | None = None,
         cleansing_config: CleansingConfig | None = None,
+        enable_parallel_execution: bool = True,
+        parallel_config: ParallelExecutionConfig | None = None,
     ):
         """
         Initialize the executor.
@@ -88,6 +119,8 @@ class GraphExecutor:
             node_registry: Custom node implementations by ID
             approval_callback: Optional callback for human-in-the-loop approval
             cleansing_config: Optional output cleansing configuration
+            enable_parallel_execution: Enable parallel fan-out execution (default True)
+            parallel_config: Configuration for parallel execution behavior
         """
         self.runtime = runtime
         self.llm = llm
@@ -105,6 +138,10 @@ class GraphExecutor:
             llm_provider=llm,
         )
 
+        # Parallel execution settings
+        self.enable_parallel_execution = enable_parallel_execution
+        self._parallel_config = parallel_config or ParallelExecutionConfig()
+
     def _validate_tools(self, graph: GraphSpec) -> list[str]:
         """
         Validate that all tools declared by nodes are available.
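Wiring the new options into an executor (illustrative; `runtime` and `llm` stand in for whatever providers the caller already constructs):

```python
executor = GraphExecutor(
    runtime=runtime,
    llm=llm,
    enable_parallel_execution=True,  # the default; shown for clarity
    parallel_config=ParallelExecutionConfig(on_branch_failure="wait_all"),
)
```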
@@ -246,6 +283,7 @@ class GraphExecutor:
             memory=memory,
             goal=goal,
             input_data=input_data or {},
+            max_tokens=graph.max_tokens,
         )
 
         # Log actual input data being read
@@ -261,7 +299,7 @@ class GraphExecutor:
                 self.logger.info(f"  {key}: {value_str}")
 
         # Get or create node implementation
-        node_impl = self._get_node_implementation(node_spec)
+        node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
 
         # Validate inputs
         validation_errors = node_impl.validate_input(ctx)
@@ -419,8 +457,8 @@ class GraphExecutor:
                     self.logger.info(f"  → Router directing to: {result.next_node}")
                     current_node_id = result.next_node
                 else:
-                    # Follow edges
-                    next_node = self._follow_edges(
+                    # Get all traversable edges for fan-out detection
+                    traversable_edges = self._get_all_traversable_edges(
                         graph=graph,
                         goal=goal,
                         current_node_id=current_node_id,
@@ -428,12 +466,55 @@ class GraphExecutor:
                         result=result,
                         memory=memory,
                     )
-                    if next_node is None:
+
+                    if not traversable_edges:
                         self.logger.info("  → No more edges, ending execution")
                         break  # No valid edge, end execution
-                    next_spec = graph.get_node(next_node)
-                    self.logger.info(f"  → Next: {next_spec.name if next_spec else next_node}")
-                    current_node_id = next_node
+
+                    # Check for fan-out (multiple traversable edges)
+                    if self.enable_parallel_execution and len(traversable_edges) > 1:
+                        # Find convergence point (fan-in node)
+                        targets = [e.target for e in traversable_edges]
+                        fan_in_node = self._find_convergence_node(graph, targets)
+
+                        # Execute branches in parallel
+                        _branch_results, branch_tokens, branch_latency = await self._execute_parallel_branches(
+                            graph=graph,
+                            goal=goal,
+                            edges=traversable_edges,
+                            memory=memory,
+                            source_result=result,
+                            source_node_spec=node_spec,
+                            path=path,
+                        )
+
+                        total_tokens += branch_tokens
+                        total_latency += branch_latency
+
+                        # Continue from fan-in node
+                        if fan_in_node:
+                            self.logger.info(f"  ⑃ Fan-in: converging at {fan_in_node}")
+                            current_node_id = fan_in_node
+                        else:
+                            # No convergence point - branches are terminal
+                            self.logger.info("  → Parallel branches completed (no convergence)")
+                            break
+                    else:
+                        # Sequential: follow single edge (existing logic via _follow_edges)
+                        next_node = self._follow_edges(
+                            graph=graph,
+                            goal=goal,
+                            current_node_id=current_node_id,
+                            current_node_spec=node_spec,
+                            result=result,
+                            memory=memory,
+                        )
+                        if next_node is None:
+                            self.logger.info("  → No more edges, ending execution")
+                            break
+                        next_spec = graph.get_node(next_node)
+                        self.logger.info(f"  → Next: {next_spec.name if next_spec else next_node}")
+                        current_node_id = next_node
 
                 # Update input_data for next node
                 input_data = result.output
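On the diamond graph sketched earlier, this branch fires when both ON_SUCCESS edges out of `plan` are traversable: `search` and `summarize` run concurrently and execution resumes at `merge`. A hypothetical driver (the run method's name and signature are assumptions, not part of this diff):

```python
# search and summarize execute in parallel; both appear in the path
# before execution converges at merge.
result = await executor.execute(graph=spec, goal=goal, input_data={"query": "..."})
```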
@@ -484,6 +565,7 @@ class GraphExecutor:
         memory: SharedMemory,
         goal: Goal,
         input_data: dict[str, Any],
+        max_tokens: int = 4096,
     ) -> NodeContext:
         """Build execution context for a node."""
         # Filter tools to those available to this node
@@ -507,12 +589,15 @@ class GraphExecutor:
             available_tools=available_tools,
             goal_context=goal.to_prompt_context(),
             goal=goal,  # Pass Goal object for LLM-powered routers
+            max_tokens=max_tokens,
         )
 
     # Valid node types - no ambiguous "llm" type allowed
     VALID_NODE_TYPES = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
 
-    def _get_node_implementation(self, node_spec: NodeSpec) -> NodeProtocol:
+    def _get_node_implementation(
+        self, node_spec: NodeSpec, cleanup_llm_model: str | None = None
+    ) -> NodeProtocol:
         """Get or create a node implementation."""
         # Check registry first
         if node_spec.id in self.node_registry:
@@ -533,10 +618,18 @@ class GraphExecutor:
                     f"Node '{node_spec.id}' is type 'llm_tool_use' but declares no tools. "
                     "Either add tools to the node or change type to 'llm_generate'."
                 )
-            return LLMNode(tool_executor=self.tool_executor, require_tools=True)
+            return LLMNode(
+                tool_executor=self.tool_executor,
+                require_tools=True,
+                cleanup_llm_model=cleanup_llm_model,
+            )
 
         if node_spec.node_type == "llm_generate":
-            return LLMNode(tool_executor=None, require_tools=False)
+            return LLMNode(
+                tool_executor=None,
+                require_tools=False,
+                cleanup_llm_model=cleanup_llm_model,
+            )
 
         if node_spec.node_type == "router":
             return RouterNode()
@@ -549,7 +642,11 @@ class GraphExecutor:
 
         if node_spec.node_type == "human_input":
             # Human input nodes are handled specially by HITL mechanism
-            return LLMNode(tool_executor=None, require_tools=False)
+            return LLMNode(
+                tool_executor=None,
+                require_tools=False,
+                cleanup_llm_model=cleanup_llm_model,
+            )
 
         # Should never reach here due to validation above
         raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")
@@ -608,9 +705,9 @@ class GraphExecutor:
                 # Update result with cleaned output
                 result.output = cleaned_output
 
-                # Write cleaned output back to memory
+                # Write cleaned output back to memory (skip validation for LLM output)
                 for key, value in cleaned_output.items():
-                    memory.write(key, value)
+                    memory.write(key, value, validate=False)
 
                 # Revalidate
                 revalidation = self.output_cleaner.validate_output(
@@ -629,15 +726,234 @@ class GraphExecutor:
                 )
                 # Continue anyway if fallback_to_raw is True
 
-            # Map inputsss
+            # Map inputs (skip validation for processed LLM output)
             mapped = edge.map_inputs(result.output, memory.read_all())
             for key, value in mapped.items():
-                memory.write(key, value)
+                memory.write(key, value, validate=False)
 
             return edge.target
 
         return None
 
+    def _get_all_traversable_edges(
+        self,
+        graph: GraphSpec,
+        goal: Goal,
+        current_node_id: str,
+        current_node_spec: Any,
+        result: NodeResult,
+        memory: SharedMemory,
+    ) -> list[EdgeSpec]:
+        """
+        Get ALL edges that should be traversed (for fan-out detection).
+
+        Unlike _follow_edges which returns the first match, this returns
+        all matching edges to enable parallel execution.
+        """
+        edges = graph.get_outgoing_edges(current_node_id)
+        traversable = []
+
+        for edge in edges:
+            target_node_spec = graph.get_node(edge.target)
+            if edge.should_traverse(
+                source_success=result.success,
+                source_output=result.output,
+                memory=memory.read_all(),
+                llm=self.llm,
+                goal=goal,
+                source_node_name=current_node_spec.name if current_node_spec else current_node_id,
+                target_node_name=target_node_spec.name if target_node_spec else edge.target,
+            ):
+                traversable.append(edge)
+
+        return traversable
+
+    def _find_convergence_node(
+        self,
+        graph: GraphSpec,
+        parallel_targets: list[str],
+    ) -> str | None:
+        """
+        Find the common target node where parallel branches converge (fan-in).
+
+        Args:
+            graph: The graph specification
+            parallel_targets: List of node IDs that are running in parallel
+
+        Returns:
+            Node ID where all branches converge, or None if no convergence
+        """
+        # Get all nodes that parallel branches lead to
+        next_nodes: dict[str, int] = {}  # node_id -> count of branches leading to it
+
+        for target in parallel_targets:
+            outgoing = graph.get_outgoing_edges(target)
+            for edge in outgoing:
+                next_nodes[edge.target] = next_nodes.get(edge.target, 0) + 1
+
+        # Convergence node is where ALL branches lead
+        for node_id, count in next_nodes.items():
+            if count == len(parallel_targets):
+                return node_id
+
+        # Fallback: return most common target if any
+        if next_nodes:
+            return max(next_nodes.keys(), key=lambda k: next_nodes[k])
+
+        return None
+
+    async def _execute_parallel_branches(
+        self,
+        graph: GraphSpec,
+        goal: Goal,
+        edges: list[EdgeSpec],
+        memory: SharedMemory,
+        source_result: NodeResult,
+        source_node_spec: Any,
+        path: list[str],
+    ) -> tuple[dict[str, NodeResult], int, int]:
+        """
+        Execute multiple branches in parallel using asyncio.gather.
+
+        Args:
+            graph: The graph specification
+            goal: The execution goal
+            edges: List of edges to follow in parallel
+            memory: Shared memory instance
+            source_result: Result from the source node
+            source_node_spec: Spec of the source node
+            path: Execution path list to update
+
+        Returns:
+            Tuple of (branch_results dict, total_tokens, total_latency)
+        """
+        branches: dict[str, ParallelBranch] = {}
+
+        # Create branches for each edge
+        for edge in edges:
+            branch_id = f"{edge.source}_to_{edge.target}"
+            branches[branch_id] = ParallelBranch(
+                branch_id=branch_id,
+                node_id=edge.target,
+                edge=edge,
+            )
+
+        self.logger.info(f"  ⑂ Fan-out: executing {len(branches)} branches in parallel")
+        for branch in branches.values():
+            target_spec = graph.get_node(branch.node_id)
+            self.logger.info(f"    • {target_spec.name if target_spec else branch.node_id}")
+
+        async def execute_single_branch(branch: ParallelBranch) -> tuple[ParallelBranch, NodeResult | Exception]:
+            """Execute a single branch with retry logic."""
+            node_spec = graph.get_node(branch.node_id)
+            branch.status = "running"
+
+            try:
+                # Validate and clean output before mapping inputs (same as _follow_edges)
+                if self.cleansing_config.enabled and node_spec:
+                    validation = self.output_cleaner.validate_output(
+                        output=source_result.output,
+                        source_node_id=source_node_spec.id if source_node_spec else "unknown",
+                        target_node_spec=node_spec,
+                    )
+
+                    if not validation.valid:
+                        self.logger.warning(
+                            f"⚠ Output validation failed for branch {branch.node_id}: {validation.errors}"
+                        )
+                        cleaned_output = self.output_cleaner.clean_output(
+                            output=source_result.output,
+                            source_node_id=source_node_spec.id if source_node_spec else "unknown",
+                            target_node_spec=node_spec,
+                            validation_errors=validation.errors,
+                        )
+                        # Write cleaned output to memory
+                        for key, value in cleaned_output.items():
+                            await memory.write_async(key, value)
+
+                # Map inputs via edge
+                mapped = branch.edge.map_inputs(source_result.output, memory.read_all())
+                for key, value in mapped.items():
+                    await memory.write_async(key, value)
+
+                # Execute with retries
+                last_result = None
+                for attempt in range(node_spec.max_retries):
+                    branch.retry_count = attempt
+
+                    # Build context for this branch
+                    ctx = self._build_context(node_spec, memory, goal, mapped, graph.max_tokens)
+                    node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
+
+                    self.logger.info(f"  ▶ Branch {node_spec.name}: executing (attempt {attempt + 1})")
+                    result = await node_impl.execute(ctx)
+                    last_result = result
+
+                    if result.success:
+                        # Write outputs to shared memory using async write
+                        for key, value in result.output.items():
+                            await memory.write_async(key, value)
+
+                        branch.result = result
+                        branch.status = "completed"
+                        self.logger.info(
+                            f"  ✓ Branch {node_spec.name}: success "
+                            f"(tokens: {result.tokens_used}, latency: {result.latency_ms}ms)"
+                        )
+                        return branch, result
+
+                    self.logger.warning(
+                        f"  ↻ Branch {node_spec.name}: retry {attempt + 1}/{node_spec.max_retries}"
+                    )
+
+                # All retries exhausted
+                branch.status = "failed"
+                branch.error = last_result.error if last_result else "Unknown error"
+                branch.result = last_result
+                self.logger.error(f"  ✗ Branch {node_spec.name}: failed after {node_spec.max_retries} attempts")
+                return branch, last_result
+
+            except Exception as e:
+                branch.status = "failed"
+                branch.error = str(e)
+                self.logger.error(f"  ✗ Branch {branch.node_id}: exception - {e}")
+                return branch, e
+
+        # Execute all branches concurrently
+        tasks = [execute_single_branch(b) for b in branches.values()]
+        results = await asyncio.gather(*tasks, return_exceptions=False)
+
+        # Process results
+        total_tokens = 0
+        total_latency = 0
+        branch_results: dict[str, NodeResult] = {}
+        failed_branches: list[ParallelBranch] = []
+
+        for branch, result in results:
+            path.append(branch.node_id)
+
+            if isinstance(result, Exception):
+                failed_branches.append(branch)
+            elif result is None or not result.success:
+                failed_branches.append(branch)
+            else:
+                total_tokens += result.tokens_used
+                total_latency += result.latency_ms
+                branch_results[branch.branch_id] = result
+
+        # Handle failures based on config
+        if failed_branches:
+            failed_names = [graph.get_node(b.node_id).name for b in failed_branches]
+            if self._parallel_config.on_branch_failure == "fail_all":
+                raise RuntimeError(
+                    f"Parallel execution failed: branches {failed_names} failed"
+                )
+            elif self._parallel_config.on_branch_failure == "continue_others":
+                self.logger.warning(f"⚠ Some branches failed ({failed_names}), continuing with successful ones")
+
+        self.logger.info(f"  ⑃ Fan-out complete: {len(branch_results)}/{len(branches)} branches succeeded")
+        return branch_results, total_tokens, total_latency
+
     def register_node(self, node_id: str, implementation: NodeProtocol) -> None:
         """Register a custom node implementation."""
         self.node_registry[node_id] = implementation
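The concurrency core of `_execute_parallel_branches` reduces to the standard gather pattern below (a self-contained sketch, not the framework API):

```python
import asyncio

async def run_branch(name: str) -> tuple[str, str]:
    # stand-in for one node execution with its own internal retry loop
    await asyncio.sleep(0.1)
    return name, "completed"

async def fan_out(names: list[str]) -> dict[str, str]:
    # gather preserves input order; by default it propagates the first
    # exception, which is why the real code catches errors inside each
    # branch and returns them as values instead
    results = await asyncio.gather(*(run_branch(n) for n in names))
    return dict(results)

print(asyncio.run(fan_out(["search", "summarize"])))
```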
+439 −148
@@ -15,6 +15,7 @@ Protocol:
 The framework provides NodeContext with everything the node needs.
 """
 
+import asyncio
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Callable
@@ -29,6 +30,62 @@ from framework.runtime.core import Runtime
 logger = logging.getLogger(__name__)
 
 
+def _fix_unescaped_newlines_in_json(json_str: str) -> str:
+    """Fix unescaped newlines inside JSON string values.
+
+    LLMs sometimes output actual newlines inside JSON strings instead of \\n.
+    This function fixes that by properly escaping newlines within string values.
+    """
+    result = []
+    in_string = False
+    escape_next = False
+    i = 0
+
+    while i < len(json_str):
+        char = json_str[i]
+
+        if escape_next:
+            result.append(char)
+            escape_next = False
+            i += 1
+            continue
+
+        if char == "\\" and in_string:
+            escape_next = True
+            result.append(char)
+            i += 1
+            continue
+
+        if char == '"' and not escape_next:
+            in_string = not in_string
+            result.append(char)
+            i += 1
+            continue
+
+        # Fix unescaped newlines inside strings
+        if in_string and char == "\n":
+            result.append("\\n")
+            i += 1
+            continue
+
+        # Fix unescaped carriage returns inside strings
+        if in_string and char == "\r":
+            result.append("\\r")
+            i += 1
+            continue
+
+        # Fix unescaped tabs inside strings
+        if in_string and char == "\t":
+            result.append("\\t")
+            i += 1
+            continue
+
+        result.append(char)
+        i += 1
+
+    return "".join(result)
+
+
 def find_json_object(text: str) -> str | None:
     """Find the first valid JSON object in text using balanced brace matching.
 
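A behavior check for the helper above (runnable as-is once the function is imported):

```python
import json

raw = '{"summary": "line one\nline two"}'  # real newline inside a JSON string value

try:
    json.loads(raw)
except json.JSONDecodeError:
    pass  # the unescaped newline makes this invalid JSON

fixed = _fix_unescaped_newlines_in_json(raw)
# after escaping, the value round-trips to a string containing a newline
assert json.loads(fixed) == {"summary": "line one\nline two"}
```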
@@ -173,11 +230,22 @@ class SharedMemory:
 
     Nodes read and write to shared memory using typed keys.
     The memory is scoped to a single run.
+
+    For parallel execution, use write_async() which provides per-key locking
+    to prevent race conditions when multiple nodes write concurrently.
     """
 
     _data: dict[str, Any] = field(default_factory=dict)
     _allowed_read: set[str] = field(default_factory=set)
     _allowed_write: set[str] = field(default_factory=set)
+    # Locks for thread-safe parallel execution
+    _lock: asyncio.Lock | None = field(default=None, repr=False)
+    _key_locks: dict[str, asyncio.Lock] = field(default_factory=dict, repr=False)
+
+    def __post_init__(self) -> None:
+        """Initialize the main lock if not provided."""
+        if self._lock is None:
+            self._lock = asyncio.Lock()
 
     def read(self, key: str) -> Any:
         """Read a value from shared memory."""
@@ -218,6 +286,48 @@ class SharedMemory:
 
         self._data[key] = value
 
+    async def write_async(self, key: str, value: Any, validate: bool = True) -> None:
+        """
+        Thread-safe async write with per-key locking.
+
+        Use this method when multiple nodes may write concurrently during
+        parallel execution. Each key has its own lock to minimize contention.
+
+        Args:
+            key: The memory key to write to
+            value: The value to write
+            validate: If True, check for suspicious content (default True)
+
+        Raises:
+            PermissionError: If node doesn't have write permission
+            MemoryWriteError: If value appears to be hallucinated content
+        """
+        # Check permissions first (no lock needed)
+        if self._allowed_write and key not in self._allowed_write:
+            raise PermissionError(f"Node not allowed to write key: {key}")
+
+        # Ensure key has a lock (double-checked locking pattern)
+        if key not in self._key_locks:
+            async with self._lock:
+                if key not in self._key_locks:
+                    self._key_locks[key] = asyncio.Lock()
+
+        # Acquire per-key lock and write
+        async with self._key_locks[key]:
+            if validate and isinstance(value, str):
+                if len(value) > 5000:
+                    if self._contains_code_indicators(value):
+                        logger.warning(
+                            f"⚠ Suspicious write to key '{key}': appears to be code "
+                            f"({len(value)} chars). Consider using validate=False if intended."
+                        )
+                        raise MemoryWriteError(
+                            f"Rejected suspicious content for key '{key}': "
+                            f"appears to be hallucinated code ({len(value)} chars). "
+                            "If this is intentional, use validate=False."
+                        )
+            self._data[key] = value
+
     def _contains_code_indicators(self, value: str) -> bool:
         """
         Check for code patterns in a string using sampling for efficiency.
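With per-key locks, concurrent writers serialize per key rather than globally (a sketch; `memory` is an existing `SharedMemory` instance):

```python
import asyncio

async def demo(memory: "SharedMemory") -> None:
    # Writes to different keys proceed in parallel; writes to the same key
    # queue up behind that key's lock, so "summary" ends up with whichever
    # write acquired the lock last.
    await asyncio.gather(
        memory.write_async("search_results", ["a", "b"]),
        memory.write_async("summary", "short text"),
        memory.write_async("summary", "revised text"),
    )
```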
@@ -290,11 +400,17 @@ class SharedMemory:
         read_keys: list[str],
         write_keys: list[str],
     ) -> "SharedMemory":
-        """Create a view with restricted permissions for a specific node."""
+        """Create a view with restricted permissions for a specific node.
+
+        The scoped view shares the same underlying data and locks,
+        enabling thread-safe parallel execution across scoped views.
+        """
         return SharedMemory(
             _data=self._data,
             _allowed_read=set(read_keys) if read_keys else set(),
             _allowed_write=set(write_keys) if write_keys else set(),
+            _lock=self._lock,  # Share lock for thread safety
+            _key_locks=self._key_locks,  # Share key locks
         )
 
@@ -330,6 +446,9 @@ class NodeContext:
     goal_context: str = ""
     goal: Any = None  # Goal object for LLM-powered routers
 
+    # LLM configuration
+    max_tokens: int = 4096  # Maximum tokens for LLM responses
+
     # Execution metadata
     attempt: int = 1
     max_attempts: int = 3
@@ -503,9 +622,33 @@ class LLMNode(NodeProtocol):
     The LLM decides how to achieve the goal within constraints.
     """
 
-    def __init__(self, tool_executor: Callable | None = None, require_tools: bool = False):
+    # Stop reasons indicating truncation (varies by provider)
+    TRUNCATION_STOP_REASONS = {"length", "max_tokens", "token_limit"}
+
+    # Compaction instruction added when response is truncated
+    COMPACTION_INSTRUCTION = """
+IMPORTANT: Your previous response was truncated because it exceeded the token limit.
+Please provide a MORE CONCISE response that fits within the limit.
+Focus on the essential information and omit verbose details.
+Keep the same JSON structure but with shorter content values.
+"""
+
+    def __init__(
+        self,
+        tool_executor: Callable | None = None,
+        require_tools: bool = False,
+        cleanup_llm_model: str | None = None,
+        max_compaction_retries: int = 2,
+    ):
         self.tool_executor = tool_executor
         self.require_tools = require_tools
+        self.cleanup_llm_model = cleanup_llm_model
+        self.max_compaction_retries = max_compaction_retries
+
+    def _is_truncated(self, response) -> bool:
+        """Check if LLM response was truncated due to token limit."""
+        stop_reason = getattr(response, "stop_reason", "").lower()
+        return stop_reason in self.TRUNCATION_STOP_REASONS
 
     def _strip_code_blocks(self, content: str) -> str:
         """Strip markdown code block wrappers from content.
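A quick check of the truncation predicate (sketch; `FakeResponse` is a test stand-in, not a framework type):

```python
class FakeResponse:
    stop_reason = "max_tokens"   # a provider stop reason in TRUNCATION_STOP_REASONS
    content = '{"summary": "...'

node = LLMNode()
assert node._is_truncated(FakeResponse()) is True
```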
@@ -599,6 +742,7 @@ class LLMNode(NodeProtocol):
                 system=system,
                 tools=ctx.available_tools,
                 tool_executor=executor,
+                max_tokens=ctx.max_tokens,
             )
         else:
             # Use JSON mode for llm_generate nodes with output_keys
@@ -613,128 +757,172 @@ class LLMNode(NodeProtocol):
                     f" 📋 Expecting JSON output with keys: {ctx.node_spec.output_keys}"
                 )
 
-            # Phase 3: Auto-generate JSON schema from Pydantic model
-            response_format = None
-            if ctx.node_spec.output_model is not None:
-                json_schema = ctx.node_spec.output_model.model_json_schema()
-                response_format = {
-                    "type": "json_schema",
-                    "json_schema": {
-                        "name": ctx.node_spec.output_model.__name__,
-                        "schema": json_schema,
-                        "strict": True,
-                    }
-                }
-                model_name = ctx.node_spec.output_model.__name__
-                logger.info(f" 📐 Using JSON schema from Pydantic model: {model_name}")
-
-            # Phase 2: Retry loop for Pydantic validation
-            max_retries = ctx.node_spec.max_validation_retries
-            max_validation_retries = max_retries if ctx.node_spec.output_model else 0
-            validation_attempt = 0
-            total_input_tokens = 0
-            total_output_tokens = 0
-            current_messages = messages.copy()
-
-            while True:
-                response = ctx.llm.complete(
-                    messages=current_messages,
-                    system=system,
-                    json_mode=use_json_mode,
-                    response_format=response_format,
-                )
-
-                total_input_tokens += response.input_tokens
-                total_output_tokens += response.output_tokens
-
-                # Log the response
-                response_preview = (
-                    response.content[:200] if len(response.content) > 200 else response.content
-                )
-                if len(response.content) > 200:
-                    response_preview += "..."
-                logger.info(f" ← Response: {response_preview}")
-
-                # If no output_model, break immediately (no validation needed)
-                if ctx.node_spec.output_model is None:
-                    break
-
-                # Try to parse and validate the response
-                try:
-                    import json
-                    parsed = self._extract_json(response.content, ctx.node_spec.output_keys)
-
-                    if isinstance(parsed, dict):
-                        from framework.graph.validator import OutputValidator
-                        validator = OutputValidator()
-                        validation_result, validated_model = validator.validate_with_pydantic(
-                            parsed, ctx.node_spec.output_model
-                        )
-
-                        if validation_result.success:
-                            # Validation passed, break out of retry loop
-                            model_name = ctx.node_spec.output_model.__name__
-                            logger.info(f" ✓ Pydantic validation passed for {model_name}")
-                            break
-                        else:
-                            # Validation failed
-                            validation_attempt += 1
-
-                            if validation_attempt <= max_validation_retries:
-                                # Add validation feedback to messages and retry
-                                feedback = validator.format_validation_feedback(
-                                    validation_result, ctx.node_spec.output_model
-                                )
-                                logger.warning(
-                                    f" ⚠ Pydantic validation failed "
-                                    f"(attempt {validation_attempt}/{max_validation_retries}): "
-                                    f"{validation_result.error}"
-                                )
-                                logger.info(" 🔄 Retrying with validation feedback...")
-
-                                # Add the assistant's failed response and feedback
-                                current_messages.append({
-                                    "role": "assistant",
-                                    "content": response.content
-                                })
-                                current_messages.append({
-                                    "role": "user",
-                                    "content": feedback
-                                })
-                                continue  # Retry the LLM call
-                            else:
-                                # Max retries exceeded
-                                latency_ms = int((time.time() - start) * 1000)
-                                err = validation_result.error
-                                logger.error(
-                                    f" ✗ Pydantic validation failed after "
-                                    f"{max_validation_retries} retries: {err}"
-                                )
-                                ctx.runtime.record_outcome(
-                                    decision_id=decision_id,
-                                    success=False,
-                                    error=f"Validation failed: {validation_result.error}",
-                                    tokens_used=total_input_tokens + total_output_tokens,
-                                    latency_ms=latency_ms,
-                                )
-                                error_msg = (
-                                    f"Pydantic validation failed after "
-                                    f"{max_validation_retries} retries: {err}"
-                                )
-                                return NodeResult(
-                                    success=False,
-                                    error=error_msg,
-                                    output=parsed,
-                                    tokens_used=total_input_tokens + total_output_tokens,
-                                    latency_ms=latency_ms,
-                                    validation_errors=validation_result.errors,
-                                )
-                    else:
-                        # Not a dict, can't validate - break and let downstream handle
-                        break
-                except Exception:
-                    # JSON extraction failed - break and let downstream handle
-                    break
+            response = ctx.llm.complete(
+                messages=messages,
+                system=system,
+                json_mode=use_json_mode,
+                max_tokens=ctx.max_tokens,
+            )
+
+        # Check for truncation and retry with compaction if needed
+        expects_json = (
+            ctx.node_spec.node_type in ("llm_generate", "llm_tool_use")
+            and ctx.node_spec.output_keys
+            and len(ctx.node_spec.output_keys) >= 1
+        )
+
+        compaction_attempt = 0
+        while self._is_truncated(response) and expects_json and compaction_attempt < self.max_compaction_retries:
+            compaction_attempt += 1
+            logger.warning(
+                f" ⚠ Response truncated (stop_reason: {response.stop_reason}), "
+                f"retrying with compaction ({compaction_attempt}/{self.max_compaction_retries})"
+            )
+
+            # Add compaction instruction to messages
+            compaction_messages = messages + [
+                {"role": "assistant", "content": response.content},
+                {"role": "user", "content": self.COMPACTION_INSTRUCTION},
+            ]
+
+            # Retry the call with compaction instruction
+            if ctx.available_tools and self.tool_executor:
+                response = ctx.llm.complete_with_tools(
+                    messages=compaction_messages,
+                    system=system,
+                    tools=ctx.available_tools,
+                    tool_executor=executor,
+                    max_tokens=ctx.max_tokens,
+                )
+            else:
+                response = ctx.llm.complete(
+                    messages=compaction_messages,
+                    system=system,
+                    json_mode=use_json_mode,
+                    max_tokens=ctx.max_tokens,
+                )
+
+        if self._is_truncated(response) and expects_json:
+            logger.warning(
+                f" ⚠ Response still truncated after {compaction_attempt} compaction attempts"
+            )
+
+        # Phase 2: Validation retry loop for Pydantic models
+        max_validation_retries = ctx.node_spec.max_validation_retries if ctx.node_spec.output_model else 0
+        validation_attempt = 0
+        total_input_tokens = 0
+        total_output_tokens = 0
+        current_messages = messages.copy()
+
+        while True:
+            total_input_tokens += response.input_tokens
+            total_output_tokens += response.output_tokens
+
+            # Log the response
+            response_preview = (
+                response.content[:200] if len(response.content) > 200 else response.content
+            )
+            if len(response.content) > 200:
+                response_preview += "..."
+            logger.info(f" ← Response: {response_preview}")
+
+            # If no output_model, break immediately (no validation needed)
+            if ctx.node_spec.output_model is None:
+                break
+
+            # Try to parse and validate the response
+            try:
+                import json
+                parsed = self._extract_json(response.content, ctx.node_spec.output_keys)
+
+                if isinstance(parsed, dict):
+                    from framework.graph.validator import OutputValidator
+                    validator = OutputValidator()
+                    validation_result, validated_model = validator.validate_with_pydantic(
+                        parsed, ctx.node_spec.output_model
+                    )
+
+                    if validation_result.success:
+                        # Validation passed, break out of retry loop
+                        model_name = ctx.node_spec.output_model.__name__
+                        logger.info(f" ✓ Pydantic validation passed for {model_name}")
+                        break
+                    else:
+                        # Validation failed
+                        validation_attempt += 1
+
+                        if validation_attempt <= max_validation_retries:
+                            # Add validation feedback to messages and retry
+                            feedback = validator.format_validation_feedback(
+                                validation_result, ctx.node_spec.output_model
+                            )
+                            logger.warning(
+                                f" ⚠ Pydantic validation failed "
+                                f"(attempt {validation_attempt}/{max_validation_retries}): "
+                                f"{validation_result.error}"
+                            )
+                            logger.info(" 🔄 Retrying with validation feedback...")
+
+                            # Add the assistant's failed response and feedback
+                            current_messages.append({
+                                "role": "assistant",
+                                "content": response.content
+                            })
+                            current_messages.append({
+                                "role": "user",
+                                "content": feedback
+                            })
+
+                            # Re-call LLM with feedback
+                            if ctx.available_tools and self.tool_executor:
+                                response = ctx.llm.complete_with_tools(
+                                    messages=current_messages,
+                                    system=system,
+                                    tools=ctx.available_tools,
+                                    tool_executor=executor,
+                                    max_tokens=ctx.max_tokens,
+                                )
+                            else:
+                                response = ctx.llm.complete(
+                                    messages=current_messages,
+                                    system=system,
+                                    json_mode=use_json_mode,
+                                    max_tokens=ctx.max_tokens,
+                                )
+                            continue  # Retry validation
+                        else:
+                            # Max retries exceeded
+                            latency_ms = int((time.time() - start) * 1000)
+                            err = validation_result.error
+                            logger.error(
+                                f" ✗ Pydantic validation failed after "
+                                f"{max_validation_retries} retries: {err}"
+                            )
+                            ctx.runtime.record_outcome(
+                                decision_id=decision_id,
+                                success=False,
+                                error=f"Validation failed: {validation_result.error}",
+                                tokens_used=total_input_tokens + total_output_tokens,
+                                latency_ms=latency_ms,
+                            )
+                            error_msg = (
+                                f"Pydantic validation failed after "
+                                f"{max_validation_retries} retries: {err}"
+                            )
+                            return NodeResult(
+                                success=False,
+                                error=error_msg,
+                                output=parsed,
+                                tokens_used=total_input_tokens + total_output_tokens,
+                                latency_ms=latency_ms,
+                                validation_errors=validation_result.errors,
+                            )
+                else:
+                    # Not a dict, can't validate - break and let downstream handle
+                    break
+            except Exception:
+                # JSON extraction failed - break and let downstream handle
+                break
 
         latency_ms = int((time.time() - start) * 1000)
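The compaction retry distills to the control flow below (a sketch with the provider call abstracted behind a callable; not the framework API):

```python
def complete_with_budget(call, is_truncated, max_retries: int = 2):
    # call(extra_messages) -> response; mirrors the loop above: re-ask with
    # a "be more concise" instruction until the reply fits or retries run out
    response = call([])
    attempt = 0
    while is_truncated(response) and attempt < max_retries:
        attempt += 1
        response = call([
            {"role": "assistant", "content": response.content},
            {"role": "user", "content": "Previous response was truncated; be more concise."},
        ])
    return response
```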
@@ -758,9 +946,13 @@ class LLMNode(NodeProtocol):
             import json
 
             # Try to extract JSON from response
-            parsed = self._extract_json(response.content, ctx.node_spec.output_keys)
+            parsed = self._extract_json(
+                response.content, ctx.node_spec.output_keys, self.cleanup_llm_model
+            )
 
-            # If parsed successfully, validate against Pydantic model if specified
+            # If parsed successfully, write each field to its corresponding output key
+            # Use validate=False since LLM output legitimately contains text that
+            # may trigger false positives (e.g., "from OpenAI" matches "from ")
             if isinstance(parsed, dict):
                 # If we have output_model, the validation already happened in the retry loop
                 if ctx.node_spec.output_model is not None:
@@ -779,22 +971,22 @@ class LLMNode(NodeProtocol):
                     # Strip code block wrappers from string values
                     if isinstance(value, str):
                         value = self._strip_code_blocks(value)
-                    ctx.memory.write(key, value)
+                    ctx.memory.write(key, value, validate=False)
                     output[key] = value
                 elif key in ctx.input_data:
                     # Key not in JSON but exists in input - pass through
-                    ctx.memory.write(key, ctx.input_data[key])
+                    ctx.memory.write(key, ctx.input_data[key], validate=False)
                     output[key] = ctx.input_data[key]
                 else:
                     # Key not in JSON or input, write whole response (stripped)
                     stripped_content = self._strip_code_blocks(response.content)
-                    ctx.memory.write(key, stripped_content)
+                    ctx.memory.write(key, stripped_content, validate=False)
                     output[key] = stripped_content
             else:
                 # Not a dict, fall back to writing entire response to all keys (stripped)
                 stripped_content = self._strip_code_blocks(response.content)
                 for key in ctx.node_spec.output_keys:
-                    ctx.memory.write(key, stripped_content)
+                    ctx.memory.write(key, stripped_content, validate=False)
                     output[key] = stripped_content
 
         except (json.JSONDecodeError, Exception) as e:
@@ -825,7 +1017,7 @@ class LLMNode(NodeProtocol):
             # For non-llm_generate or single output nodes, write entire response (stripped)
             stripped_content = self._strip_code_blocks(response.content)
             for key in ctx.node_spec.output_keys:
-                ctx.memory.write(key, stripped_content)
+                ctx.memory.write(key, stripped_content, validate=False)
                 output[key] = stripped_content
 
         return NodeResult(
@@ -855,14 +1047,21 @@ class LLMNode(NodeProtocol):
         # Default output
         return {"result": content}
 
-    def _extract_json(self, raw_response: str, output_keys: list[str]) -> dict[str, Any]:
+    def _extract_json(
+        self, raw_response: str, output_keys: list[str], cleanup_llm_model: str | None = None
+    ) -> dict[str, Any]:
         """Extract clean JSON from potentially verbose LLM response.
 
         Tries multiple extraction strategies in order:
         1. Direct JSON parse
         2. Markdown code block extraction
         3. Balanced brace matching
-        4. Haiku LLM fallback (last resort)
+        4. Configured LLM fallback (last resort)
+
+        Args:
+            raw_response: The raw LLM response text
+            output_keys: Expected output keys for the JSON
+            cleanup_llm_model: Optional model to use for LLM cleanup fallback
         """
         import json
         import re
@@ -889,55 +1088,116 @@ class LLMNode(NodeProtocol):
|
|||||||
parsed = json.loads(content)
|
parsed = json.loads(content)
|
||||||
if isinstance(parsed, dict):
|
if isinstance(parsed, dict):
|
||||||
return parsed
|
return parsed
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError as e:
|
||||||
pass
|
logger.info(f" Direct JSON parse failed: {e}")
|
||||||
|
logger.info(f" Content first 200 chars repr: {repr(content[:200])}")
|
||||||
|
# Try fixing unescaped newlines in string values
|
||||||
|
try:
|
||||||
|
fixed = _fix_unescaped_newlines_in_json(content)
|
||||||
|
logger.info(f" Fixed content first 200 chars repr: {repr(fixed[:200])}")
|
||||||
|
parsed = json.loads(fixed)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
logger.info(" ✓ Parsed JSON after fixing unescaped newlines")
|
||||||
|
return parsed
|
||||||
|
except json.JSONDecodeError as e2:
|
||||||
|
logger.info(f" Newline fix also failed: {e2}")
|
||||||
|
|
||||||
# Try to extract JSON from markdown code blocks (greedy match to handle nested blocks)
|
# Try to extract JSON from markdown code blocks (greedy match to handle nested blocks)
|
||||||
# Use anchored match to capture from first ``` to last ```
|
# Multiple patterns to handle different LLM formatting styles
|
||||||
code_block_match = re.match(r"^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$", content, re.DOTALL)
|
code_block_patterns = [
|
||||||
if code_block_match:
|
# Anchored match from first ``` to last ```
|
||||||
try:
|
r"^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$",
|
||||||
parsed = json.loads(code_block_match.group(1).strip())
|
# Non-anchored: find ```json anywhere and extract to closing ```
|
||||||
if isinstance(parsed, dict):
|
r"```(?:json|JSON)?\s*\n([\s\S]*?)\n```",
|
||||||
return parsed
|
# Handle case where closing ``` might have trailing content
|
||||||
except json.JSONDecodeError:
|
r"```(?:json|JSON)?\s*\n([\s\S]*?)\n```",
|
||||||
pass
|
]
|
||||||
|
for pattern in code_block_patterns:
|
||||||
|
code_block_match = re.search(pattern, content, re.DOTALL)
|
||||||
|
if code_block_match:
|
||||||
|
try:
|
||||||
|
extracted = code_block_match.group(1).strip()
|
||||||
|
if extracted: # Skip empty matches
|
||||||
|
# Try direct parse first, then with newline fix
|
||||||
|
try:
|
||||||
|
parsed = json.loads(extracted)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
parsed = json.loads(_fix_unescaped_newlines_in_json(extracted))
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return parsed
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Try to find JSON object by matching balanced braces (use module-level helper)
|
# Try to find JSON object by matching balanced braces (use module-level helper)
|
||||||
json_str = find_json_object(content)
|
json_str = find_json_object(content)
|
||||||
if json_str:
|
if json_str:
|
||||||
try:
|
try:
|
||||||
parsed = json.loads(json_str)
|
# Try direct parse first, then with newline fix
|
||||||
|
try:
|
||||||
|
parsed = json.loads(json_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
parsed = json.loads(_fix_unescaped_newlines_in_json(json_str))
|
||||||
if isinstance(parsed, dict):
|
if isinstance(parsed, dict):
|
||||||
return parsed
|
return parsed
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Try stripping markdown prefix and finding JSON from there
|
||||||
|
# This handles cases like "```json\n{...}" where regex might fail
|
||||||
|
if "```" in content:
|
||||||
|
# Find position after ```json or ``` marker
|
||||||
|
json_start = content.find("{")
|
||||||
|
if json_start > 0:
|
||||||
|
# Extract from first { to end, then find balanced JSON
|
||||||
|
json_str = find_json_object(content[json_start:])
|
||||||
|
if json_str:
|
||||||
|
try:
|
||||||
|
# Try direct parse first, then with newline fix
|
||||||
|
try:
|
||||||
|
parsed = json.loads(json_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
parsed = json.loads(_fix_unescaped_newlines_in_json(json_str))
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
logger.info(" ✓ Extracted JSON via brace matching after markdown strip")
|
||||||
|
return parsed
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
     # All local extraction failed - use LLM as last resort
-    # Prefer Cerebras (faster/cheaper), fallback to Haiku
     import os
 
-    api_key = os.environ.get("CEREBRAS_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
-    if not api_key:
-        raise ValueError(
-            "Cannot parse JSON and no API key for LLM cleanup "
-            "(set CEREBRAS_API_KEY or ANTHROPIC_API_KEY)"
-        )
-
-    # Use fast LLM to clean the response (Cerebras llama-3.3-70b preferred)
     from framework.llm.litellm import LiteLLMProvider
 
-    if os.environ.get("CEREBRAS_API_KEY"):
+    logger.info(f" cleanup_llm_model param: {cleanup_llm_model}")
+
+    # Use configured cleanup model, or fall back to defaults
+    if cleanup_llm_model:
+        # Use the configured cleanup model (LiteLLM handles API keys via env vars)
         cleaner_llm = LiteLLMProvider(
-            api_key=os.environ.get("CEREBRAS_API_KEY"),
-            model="cerebras/llama-3.3-70b",
+            model=cleanup_llm_model,
             temperature=0.0,
         )
+        logger.info(f" Using configured cleanup LLM: {cleanup_llm_model}")
     else:
-        # Fallback to Anthropic Haiku via LiteLLM for consistency
-        cleaner_llm = LiteLLMProvider(
-            api_key=api_key, model="claude-3-5-haiku-20241022", temperature=0.0
-        )
+        # Fall back to default logic: Cerebras preferred, then Haiku
+        api_key = os.environ.get("CEREBRAS_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
+        if not api_key:
+            raise ValueError(
+                "Cannot parse JSON and no API key for LLM cleanup "
+                "(set CEREBRAS_API_KEY or ANTHROPIC_API_KEY, or configure cleanup_llm_model)"
+            )
+
+        if os.environ.get("CEREBRAS_API_KEY"):
+            cleaner_llm = LiteLLMProvider(
+                api_key=os.environ.get("CEREBRAS_API_KEY"),
+                model="cerebras/llama-3.3-70b",
+                temperature=0.0,
+            )
+        else:
+            cleaner_llm = LiteLLMProvider(
+                api_key=api_key,
+                model="claude-3-5-haiku-20241022",
+                temperature=0.0,
+            )
 
     prompt = f"""Extract the JSON object from this LLM response.
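The other helper used above, _fix_unescaped_newlines_in_json, is also defined outside this hunk. Judging from its call sites, it repairs the common failure where a model emits raw newlines inside JSON string values; a plausible sketch of such a repair, not the actual implementation:

def _fix_unescaped_newlines_in_json(text: str) -> str:
    """Escape raw newlines inside JSON string literals (illustrative sketch)."""
    out: list[str] = []
    in_string = False
    escaped = False
    for ch in text:
        if in_string and ch == "\n":
            out.append("\\n")  # json.loads rejects a literal newline inside a string
            continue
        if escaped:
            escaped = False
        elif ch == "\\":
            escaped = True
        elif ch == '"':
            in_string = not in_string
        out.append(ch)
    return "".join(out)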
@@ -955,7 +1215,16 @@ Output ONLY the JSON object, nothing else."""
             json_mode=True,
         )
 
-        cleaned = result.content.strip()
+        cleaned = result.content.strip() if result.content else ""
+
+        # Check for empty response
+        if not cleaned:
+            logger.warning(" ⚠ LLM cleanup returned empty response")
+            raise ValueError(
+                f"LLM cleanup returned empty response. "
+                f"Raw response starts with: {raw_response[:200]}..."
+            )
 
         # Remove markdown if LLM added it
         if cleaned.startswith("```"):
             match = re.search(r"^```(?:json)?\s*\n([\s\S]*?)\n```\s*$", cleaned)
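The guard above matters because the provider response's content field can be None or empty (for example, when the model returns nothing usable), so the old result.content.strip() could raise AttributeError instead of a useful error. A tiny illustration of the idiom:

content: str | None = None  # what result.content can look like in the failure case
cleaned = content.strip() if content else ""
assert cleaned == ""  # now routed into the explicit empty-response ValueError path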
@@ -967,10 +1236,32 @@ Output ONLY the JSON object, nothing else."""
             if lines[0].startswith("```") and lines[-1].strip() == "```":
                 cleaned = "\n".join(lines[1:-1]).strip()
 
-        parsed = json.loads(cleaned)
+        # Try balanced brace extraction if still not valid JSON
+        if not cleaned.startswith("{"):
+            json_str = find_json_object(cleaned)
+            if json_str:
+                cleaned = json_str
+
+        if not cleaned:
+            raise ValueError(
+                f"Could not extract JSON from LLM cleanup response. "
+                f"Raw response starts with: {raw_response[:200]}..."
+            )
+
+        # Try direct parse first, then with newline fix
+        try:
+            parsed = json.loads(cleaned)
+        except json.JSONDecodeError:
+            parsed = json.loads(_fix_unescaped_newlines_in_json(cleaned))
         logger.info(" ✓ LLM cleaned JSON output")
         return parsed
+
+    except json.JSONDecodeError as e:
+        logger.warning(f" ⚠ LLM cleanup response not valid JSON: {e}")
+        raise ValueError(
+            f"LLM cleanup response not valid JSON: {e}. "
+            f"Expected keys: {output_keys}"
+        )
     except ValueError:
         raise  # Re-raise our descriptive error
     except Exception as e:
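Taken together, the hunks above give the parser a layered fallback: direct parse, newline repair, balanced-brace scan, markdown strip plus rescan, and only then the LLM cleanup call. A condensed sketch of the local (non-LLM) stages, reusing the two helper sketches above; the function name and exact structure are assumptions, not code from this PR:

import json

def extract_json_locally(content: str) -> dict | None:
    """Run the non-LLM extraction stages in order (illustrative sketch)."""
    candidates = [content]
    # Stage 2: first balanced {...} span anywhere in the text
    brace_span = find_json_object(content)
    if brace_span:
        candidates.append(brace_span)
    # Stage 3: strip a markdown prefix, then rescan from the first brace
    if "```" in content and (start := content.find("{")) > 0:
        stripped = find_json_object(content[start:])
        if stripped:
            candidates.append(stripped)
    for candidate in candidates:
        # Each candidate is tried raw, then with raw newlines escaped
        for text in (candidate, _fix_unescaped_newlines_in_json(candidate)):
            try:
                parsed = json.loads(text)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
    return None  # caller falls through to the LLM cleanup path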
@@ -159,6 +159,7 @@ class LiteLLMProvider(LLMProvider):
         tools: list[Tool],
         tool_executor: Callable[[ToolUse], ToolResult],
         max_iterations: int = 10,
+        max_tokens: int = 4096,
     ) -> LLMResponse:
         """Run a tool-use loop until the LLM produces a final response."""
         # Prepare messages with system prompt
@@ -178,7 +179,7 @@ class LiteLLMProvider(LLMProvider):
         kwargs: dict[str, Any] = {
             "model": self.model,
             "messages": current_messages,
-            "max_tokens": 1024,
+            "max_tokens": max_tokens,
             "tools": openai_tools,
             **self.extra_kwargs,
         }
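The second hunk threads the new parameter into the request kwargs, replacing the old hard-coded 1024-token completion cap. A self-contained illustration of the pattern (the builder function, model string, and message are invented for the example; only the kwargs shape and the 4096 default come from the diff):

from typing import Any

def build_completion_kwargs(model: str, messages: list[dict], max_tokens: int = 4096,
                            **extra_kwargs: Any) -> dict[str, Any]:
    # Mirrors the diff: the caller-supplied budget replaces the old literal 1024
    return {"model": model, "messages": messages, "max_tokens": max_tokens, **extra_kwargs}

kwargs = build_completion_kwargs(
    "cerebras/llama-3.3-70b",
    [{"role": "user", "content": "hi"}],
    max_tokens=8192,  # callers that hit the old 1024 cap can now raise it
)
assert kwargs["max_tokens"] == 8192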
@@ -421,6 +421,7 @@ class ExecutionStream:
             default_model=self.graph.default_model,
             max_tokens=self.graph.max_tokens,
             max_steps=self.graph.max_steps,
+            cleanup_llm_model=self.graph.cleanup_llm_model,
         )
 
     async def wait_for_completion(
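With this wiring, a graph author sets cleanup_llm_model once on the spec and every executor the stream creates inherits it. A minimal spec showing the field; the constructor arguments mirror the ones used in the test file below, and the model string is just an example (leaving the field unset falls back to the env-key defaults described above):

from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.node import NodeSpec

n1 = NodeSpec(id="n1", name="N1", description="entry", node_type="function", output_keys=["out1"])
n2 = NodeSpec(id="n2", name="N2", description="next", node_type="function", input_keys=["out1"], output_keys=["out2"])
spec = GraphSpec(
    id="g_cleanup",
    goal_id="g1",
    name="Cleanup Example",
    entry_node="n1",
    nodes=[n1, n2],
    edges=[EdgeSpec(id="e1", source="n1", target="n2", condition=EdgeCondition.ON_SUCCESS)],
    terminal_nodes=["n2"],
    cleanup_llm_model="cerebras/llama-3.3-70b",  # optional override for JSON cleanup
)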
@@ -0,0 +1,413 @@
"""
Tests for fan-out / fan-in parallel execution in GraphExecutor.

Covers:
- Fan-out triggers with multiple ON_SUCCESS edges
- Concurrent branch execution
- Convergence at fan-in node
- fail_all / continue_others / wait_all strategies
- Branch timeout
- Memory conflict strategies
- Per-branch retry
- Single-edge paths unaffected
"""

from unittest.mock import MagicMock

import pytest

from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.executor import GraphExecutor, ParallelExecutionConfig
from framework.graph.goal import Goal
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
from framework.runtime.core import Runtime


# --- Test node implementations ---


class SuccessNode(NodeProtocol):
    """Always succeeds with configurable output."""

    def __init__(self, output: dict | None = None):
        self._output = output or {"result": "ok"}
        self.executed = False

    async def execute(self, ctx: NodeContext) -> NodeResult:
        self.executed = True
        return NodeResult(success=True, output=self._output, tokens_used=10, latency_ms=5)


class FailNode(NodeProtocol):
    """Always fails."""

    def __init__(self):
        self.attempt_count = 0

    async def execute(self, ctx: NodeContext) -> NodeResult:
        self.attempt_count += 1
        return NodeResult(success=False, error="branch failed")


class FlakyNode(NodeProtocol):
    """Fails N times, then succeeds."""

    def __init__(self, fail_times: int = 1, output: dict | None = None):
        self.fail_times = fail_times
        self.attempt_count = 0
        self._output = output or {"result": "recovered"}

    async def execute(self, ctx: NodeContext) -> NodeResult:
        self.attempt_count += 1
        if self.attempt_count <= self.fail_times:
            return NodeResult(success=False, error=f"fail #{self.attempt_count}")
        return NodeResult(success=True, output=self._output, tokens_used=10, latency_ms=5)


class TimingNode(NodeProtocol):
    """Records execution order to a shared list."""

    def __init__(self, label: str, order_tracker: list):
        self.label = label
        self.order_tracker = order_tracker

    async def execute(self, ctx: NodeContext) -> NodeResult:
        self.order_tracker.append(self.label)
        return NodeResult(success=True, output={f"{self.label}_done": True}, tokens_used=1, latency_ms=1)


# --- Fixtures ---


@pytest.fixture
def runtime():
    rt = MagicMock(spec=Runtime)
    rt.start_run = MagicMock(return_value="run_id")
    rt.decide = MagicMock(return_value="decision_id")
    rt.record_outcome = MagicMock()
    rt.end_run = MagicMock()
    rt.report_problem = MagicMock()
    rt.set_node = MagicMock()
    return rt


@pytest.fixture
def goal():
    return Goal(id="g1", name="Test", description="Fanout tests")


def _make_fanout_graph(
    branch_nodes: list[NodeSpec],
    fan_in_node: NodeSpec | None = None,
    source_node: NodeSpec | None = None,
) -> GraphSpec:
    """
    Build a diamond graph:

         source
        /  |  \\
      b0  b1  b2 ...
        \\  |  /
         fan_in
    """
    if source_node is None:
        source_node = NodeSpec(
            id="source", name="Source", description="entry",
            node_type="function", output_keys=["data"],
        )

    nodes = [source_node] + branch_nodes
    terminal_nodes = [b.id for b in branch_nodes]

    edges = [
        EdgeSpec(
            id=f"source_to_{b.id}",
            source="source",
            target=b.id,
            condition=EdgeCondition.ON_SUCCESS,
        )
        for b in branch_nodes
    ]

    if fan_in_node is not None:
        nodes.append(fan_in_node)
        terminal_nodes = [fan_in_node.id]
        for b in branch_nodes:
            edges.append(
                EdgeSpec(
                    id=f"{b.id}_to_{fan_in_node.id}",
                    source=b.id,
                    target=fan_in_node.id,
                    condition=EdgeCondition.ON_SUCCESS,
                )
            )

    return GraphSpec(
        id="fanout_graph",
        goal_id="g1",
        name="Fanout Graph",
        entry_node="source",
        nodes=nodes,
        edges=edges,
        terminal_nodes=terminal_nodes,
    )


# === 1. Fan-out triggers with multiple ON_SUCCESS edges ===


@pytest.mark.asyncio
async def test_fanout_triggers_on_multiple_success_edges(runtime, goal):
    """Fan-out should activate when a node has >1 ON_SUCCESS outgoing edges."""
    b1 = NodeSpec(id="b1", name="B1", description="branch 1", node_type="function", output_keys=["b1_out"])
    b2 = NodeSpec(id="b2", name="B2", description="branch 2", node_type="function", output_keys=["b2_out"])

    graph = _make_fanout_graph([b1, b2])

    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    source_impl = SuccessNode({"data": "x"})
    b1_impl = SuccessNode({"b1_out": "done1"})
    b2_impl = SuccessNode({"b2_out": "done2"})
    executor.register_node("source", source_impl)
    executor.register_node("b1", b1_impl)
    executor.register_node("b2", b2_impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    assert b1_impl.executed
    assert b2_impl.executed


# === 2. All branches execute concurrently ===


@pytest.mark.asyncio
async def test_branches_execute_concurrently(runtime, goal):
    """All fan-out branches should be launched via asyncio.gather (concurrent)."""
    order = []
    b1 = NodeSpec(id="b1", name="B1", description="branch 1", node_type="function", output_keys=["b1_done"])
    b2 = NodeSpec(id="b2", name="B2", description="branch 2", node_type="function", output_keys=["b2_done"])

    graph = _make_fanout_graph([b1, b2])

    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    executor.register_node("source", SuccessNode({"data": "x"}))
    executor.register_node("b1", TimingNode("b1", order))
    executor.register_node("b2", TimingNode("b2", order))

    result = await executor.execute(graph, goal, {})

    assert result.success
    # Both executed
    assert "b1" in order
    assert "b2" in order


# === 3. Convergence at fan-in node ===


@pytest.mark.asyncio
async def test_convergence_at_fan_in_node(runtime, goal):
    """After fan-out branches complete, execution should continue at convergence node."""
    b1 = NodeSpec(id="b1", name="B1", description="branch 1", node_type="function", output_keys=["b1_out"])
    b2 = NodeSpec(id="b2", name="B2", description="branch 2", node_type="function", output_keys=["b2_out"])
    merge = NodeSpec(id="merge", name="Merge", description="fan-in", node_type="function", output_keys=["merged"])

    graph = _make_fanout_graph([b1, b2], fan_in_node=merge)

    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    executor.register_node("source", SuccessNode({"data": "x"}))
    executor.register_node("b1", SuccessNode({"b1_out": "1"}))
    executor.register_node("b2", SuccessNode({"b2_out": "2"}))
    merge_impl = SuccessNode({"merged": "done"})
    executor.register_node("merge", merge_impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    assert merge_impl.executed
    assert "merge" in result.path


# === 4. fail_all strategy ===


@pytest.mark.asyncio
async def test_fail_all_strategy_raises_on_branch_failure(runtime, goal):
    """fail_all should raise RuntimeError if any branch fails."""
    b1 = NodeSpec(id="b1", name="B1", description="ok branch", node_type="function", output_keys=["b1_out"])
    b2 = NodeSpec(id="b2", name="B2", description="bad branch", node_type="function", output_keys=["b2_out"], max_retries=1)

    graph = _make_fanout_graph([b1, b2])

    config = ParallelExecutionConfig(on_branch_failure="fail_all")
    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True, parallel_config=config)
    executor.register_node("source", SuccessNode({"data": "x"}))
    executor.register_node("b1", SuccessNode({"b1_out": "ok"}))
    executor.register_node("b2", FailNode())

    result = await executor.execute(graph, goal, {})

    # fail_all raises RuntimeError which gets caught by the outer try/except
    assert not result.success
    assert "failed" in result.error.lower()


# === 5. continue_others strategy ===


@pytest.mark.asyncio
async def test_continue_others_strategy_allows_partial_success(runtime, goal):
    """continue_others should let successful branches complete even if one fails."""
    b1 = NodeSpec(id="b1", name="B1", description="ok", node_type="function", output_keys=["b1_out"])
    b2 = NodeSpec(id="b2", name="B2", description="fail", node_type="function", output_keys=["b2_out"], max_retries=1)

    graph = _make_fanout_graph([b1, b2])

    config = ParallelExecutionConfig(on_branch_failure="continue_others")
    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True, parallel_config=config)
    executor.register_node("source", SuccessNode({"data": "x"}))
    b1_impl = SuccessNode({"b1_out": "ok"})
    executor.register_node("b1", b1_impl)
    executor.register_node("b2", FailNode())

    result = await executor.execute(graph, goal, {})

    # Should not fail because continue_others tolerates branch failures
    assert result.success or b1_impl.executed


# === 6. wait_all strategy ===


@pytest.mark.asyncio
async def test_wait_all_strategy_collects_all_results(runtime, goal):
    """wait_all should wait for all branches before proceeding."""
    b1 = NodeSpec(id="b1", name="B1", description="ok", node_type="function", output_keys=["b1_out"])
    b2 = NodeSpec(id="b2", name="B2", description="fail", node_type="function", output_keys=["b2_out"], max_retries=1)

    graph = _make_fanout_graph([b1, b2])

    config = ParallelExecutionConfig(on_branch_failure="wait_all")
    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True, parallel_config=config)
    executor.register_node("source", SuccessNode({"data": "x"}))
    b1_impl = SuccessNode({"b1_out": "ok"})
    b2_impl = FailNode()
    executor.register_node("b1", b1_impl)
    executor.register_node("b2", b2_impl)

    result = await executor.execute(graph, goal, {})

    # Both branches should have executed regardless
    assert b1_impl.executed
    assert b2_impl.attempt_count >= 1


# === 7. Per-branch retry ===


@pytest.mark.asyncio
async def test_per_branch_retry(runtime, goal):
    """Each branch should retry up to its node's max_retries."""
    b1 = NodeSpec(id="b1", name="B1", description="flaky", node_type="function", output_keys=["b1_out"], max_retries=5)
    b2 = NodeSpec(id="b2", name="B2", description="solid", node_type="function", output_keys=["b2_out"])

    graph = _make_fanout_graph([b1, b2])

    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    executor.register_node("source", SuccessNode({"data": "x"}))
    flaky = FlakyNode(fail_times=3, output={"b1_out": "recovered"})
    executor.register_node("b1", flaky)
    executor.register_node("b2", SuccessNode({"b2_out": "ok"}))

    result = await executor.execute(graph, goal, {})

    assert result.success
    assert flaky.attempt_count == 4  # 3 fails + 1 success


# === 8. Single-edge path unaffected ===


@pytest.mark.asyncio
async def test_single_edge_no_parallel_overhead(runtime, goal):
    """A single outgoing edge should follow normal sequential path, not fan-out."""
    n1 = NodeSpec(id="n1", name="N1", description="entry", node_type="function", output_keys=["out1"])
    n2 = NodeSpec(id="n2", name="N2", description="next", node_type="function", input_keys=["out1"], output_keys=["out2"])

    graph = GraphSpec(
        id="seq_graph", goal_id="g1", name="Sequential",
        entry_node="n1",
        nodes=[n1, n2],
        edges=[EdgeSpec(id="e1", source="n1", target="n2", condition=EdgeCondition.ON_SUCCESS)],
        terminal_nodes=["n2"],
    )

    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    executor.register_node("n1", SuccessNode({"out1": "a"}))
    n2_impl = SuccessNode({"out2": "b"})
    executor.register_node("n2", n2_impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    assert n2_impl.executed
    assert result.path == ["n1", "n2"]


# === 9. detect_fan_out_nodes static analysis ===


def test_detect_fan_out_nodes():
    """GraphSpec.detect_fan_out_nodes should identify fan-out topology."""
    b1 = NodeSpec(id="b1", name="B1", description="b", node_type="function", output_keys=["x"])
    b2 = NodeSpec(id="b2", name="B2", description="b", node_type="function", output_keys=["y"])
    graph = _make_fanout_graph([b1, b2])

    fan_outs = graph.detect_fan_out_nodes()

    assert "source" in fan_outs
    assert set(fan_outs["source"]) == {"b1", "b2"}


# === 10. detect_fan_in_nodes static analysis ===


def test_detect_fan_in_nodes():
    """GraphSpec.detect_fan_in_nodes should identify convergence topology."""
    b1 = NodeSpec(id="b1", name="B1", description="b", node_type="function", output_keys=["x"])
    b2 = NodeSpec(id="b2", name="B2", description="b", node_type="function", output_keys=["y"])
    merge = NodeSpec(id="merge", name="Merge", description="m", node_type="function", output_keys=["z"])
    graph = _make_fanout_graph([b1, b2], fan_in_node=merge)

    fan_ins = graph.detect_fan_in_nodes()

    assert "merge" in fan_ins
    assert set(fan_ins["merge"]) == {"b1", "b2"}


# === 11. Parallel disabled falls back to sequential ===


@pytest.mark.asyncio
async def test_parallel_disabled_uses_sequential(runtime, goal):
    """When enable_parallel_execution=False, multi-edge should follow first match only."""
    b1 = NodeSpec(id="b1", name="B1", description="b1", node_type="function", output_keys=["b1_out"])
    b2 = NodeSpec(id="b2", name="B2", description="b2", node_type="function", output_keys=["b2_out"])

    graph = _make_fanout_graph([b1, b2])

    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=False)
    executor.register_node("source", SuccessNode({"data": "x"}))
    b1_impl = SuccessNode({"b1_out": "ok"})
    b2_impl = SuccessNode({"b2_out": "ok"})
    executor.register_node("b1", b1_impl)
    executor.register_node("b2", b2_impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    # Only one branch should have executed (sequential follows first edge)
    executed_count = sum([b1_impl.executed, b2_impl.executed])
    assert executed_count == 1
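One caveat on test_branches_execute_concurrently above: asserting that both labels landed in order proves that both branches ran, not that they overlapped in time. A stricter variant could use a rendezvous node that only succeeds when both branches are in flight at once; a sketch, assuming Python 3.11+ for asyncio.Barrier and the same NodeProtocol/NodeResult types as the tests above:

import asyncio

class RendezvousNode(NodeProtocol):
    """Succeeds only if the sibling branch reaches the barrier at the same time (sketch)."""

    def __init__(self, label: str, barrier: asyncio.Barrier):
        self.label = label
        self.barrier = barrier

    async def execute(self, ctx: NodeContext) -> NodeResult:
        try:
            # Under true fan-out both branches arrive and the barrier releases;
            # under sequential execution this times out and the node fails.
            await asyncio.wait_for(self.barrier.wait(), timeout=1.0)
        except TimeoutError:
            return NodeResult(success=False, error=f"{self.label} ran alone")
        return NodeResult(success=True, output={f"{self.label}_done": True}, tokens_used=1, latency_ms=1)

A test would create barrier = asyncio.Barrier(2), register RendezvousNode("b1", barrier) and RendezvousNode("b2", barrier), and assert result.success, which can only hold if the branches truly ran concurrently.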