Merge pull request #1428 from TimothyZhang7/feature/parallel-fanout

feat: parallel execution framework
Timothy @aden
2026-01-27 10:17:07 -08:00
committed by GitHub
6 changed files with 1230 additions and 165 deletions
+43
@@ -412,6 +412,11 @@ class GraphSpec(BaseModel):
default_model: str = "claude-haiku-4-5-20251001"
max_tokens: int = 1024
# Cleanup LLM for JSON extraction fallback (fast/cheap model preferred)
# If not set, uses CEREBRAS_API_KEY -> cerebras/llama-3.3-70b or
# ANTHROPIC_API_KEY -> claude-3-5-haiku as fallback
cleanup_llm_model: str | None = None
# Execution limits
max_steps: int = Field(default=100, description="Maximum node executions before timeout")
max_retries_per_node: int = 3
@@ -449,6 +454,44 @@ class GraphSpec(BaseModel):
"""Get all edges entering a node."""
return [e for e in self.edges if e.target == node_id]
def detect_fan_out_nodes(self) -> dict[str, list[str]]:
"""
Detect nodes that fan out to multiple targets.
A fan-out occurs when a node has multiple outgoing edges with the same
condition (typically ON_SUCCESS), whose targets should execute in parallel.
Returns:
Dict mapping source_node_id -> list of parallel target_node_ids
"""
fan_outs: dict[str, list[str]] = {}
for node in self.nodes:
outgoing = self.get_outgoing_edges(node.id)
# Fan-out: multiple edges with ON_SUCCESS condition
success_edges = [
e for e in outgoing if e.condition == EdgeCondition.ON_SUCCESS
]
if len(success_edges) > 1:
fan_outs[node.id] = [e.target for e in success_edges]
return fan_outs
def detect_fan_in_nodes(self) -> dict[str, list[str]]:
"""
Detect nodes that receive from multiple sources (fan-in / convergence).
A fan-in occurs when a node has multiple incoming edges, meaning
it should wait for all predecessor branches to complete.
Returns:
Dict mapping target_node_id -> list of source_node_ids
"""
fan_ins: dict[str, list[str]] = {}
for node in self.nodes:
incoming = self.get_incoming_edges(node.id)
if len(incoming) > 1:
fan_ins[node.id] = [e.source for e in incoming]
return fan_ins
def get_entry_point(self, session_state: dict | None = None) -> str:
"""
Get the appropriate entry point based on session state.
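Beyond the diff itself, here is a minimal sketch of how the two detection helpers behave on a diamond-shaped graph. The constructor fields (goal_id, terminal_nodes, node_type="function") mirror the test fixtures later in this PR; the _node helper is purely illustrative.

from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.node import NodeSpec

def _node(nid: str) -> NodeSpec:
    # Hypothetical helper: minimal function nodes, enough for topology inspection.
    return NodeSpec(id=nid, name=nid.upper(), description=nid,
                    node_type="function", output_keys=[f"{nid}_out"])

nodes = [_node(n) for n in ("source", "b1", "b2", "merge")]
edges = [
    EdgeSpec(id="e1", source="source", target="b1", condition=EdgeCondition.ON_SUCCESS),
    EdgeSpec(id="e2", source="source", target="b2", condition=EdgeCondition.ON_SUCCESS),
    EdgeSpec(id="e3", source="b1", target="merge", condition=EdgeCondition.ON_SUCCESS),
    EdgeSpec(id="e4", source="b2", target="merge", condition=EdgeCondition.ON_SUCCESS),
]
graph = GraphSpec(
    id="diamond", goal_id="g1", name="Diamond", entry_node="source",
    nodes=nodes, edges=edges, terminal_nodes=["merge"],
)

print(graph.detect_fan_out_nodes())  # {"source": ["b1", "b2"]}
print(graph.detect_fan_in_nodes())   # {"merge": ["b1", "b2"]}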
+332 -16
@@ -15,7 +15,7 @@ from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
from framework.graph.edge import GraphSpec
from framework.graph.edge import EdgeSpec, GraphSpec
from framework.graph.goal import Goal
from framework.graph.node import (
FunctionNode,
@@ -48,6 +48,35 @@ class ExecutionResult:
session_state: dict[str, Any] = field(default_factory=dict) # State to resume from
@dataclass
class ParallelBranch:
"""Tracks a single branch in parallel fan-out execution."""
branch_id: str
node_id: str
edge: EdgeSpec
result: "NodeResult | None" = None
status: str = "pending" # pending, running, completed, failed
retry_count: int = 0
error: str | None = None
@dataclass
class ParallelExecutionConfig:
"""Configuration for parallel execution behavior."""
# Error handling: "fail_all" cancels all on first failure,
# "continue_others" lets remaining branches complete,
# "wait_all" waits for all and reports all failures
on_branch_failure: str = "fail_all"
# Memory conflict handling when branches write same key
memory_conflict_strategy: str = "last_wins" # "last_wins", "first_wins", "error"
# Timeout per branch in seconds
branch_timeout_seconds: float = 300.0
class GraphExecutor:
"""
Executes agent graphs.
@@ -76,6 +105,8 @@ class GraphExecutor:
node_registry: dict[str, NodeProtocol] | None = None,
approval_callback: Callable | None = None,
cleansing_config: CleansingConfig | None = None,
enable_parallel_execution: bool = True,
parallel_config: ParallelExecutionConfig | None = None,
):
"""
Initialize the executor.
@@ -88,6 +119,8 @@ class GraphExecutor:
node_registry: Custom node implementations by ID
approval_callback: Optional callback for human-in-the-loop approval
cleansing_config: Optional output cleansing configuration
enable_parallel_execution: Enable parallel fan-out execution (default True)
parallel_config: Configuration for parallel execution behavior
"""
self.runtime = runtime
self.llm = llm
@@ -105,6 +138,10 @@ class GraphExecutor:
llm_provider=llm,
)
# Parallel execution settings
self.enable_parallel_execution = enable_parallel_execution
self._parallel_config = parallel_config or ParallelExecutionConfig()
def _validate_tools(self, graph: GraphSpec) -> list[str]:
"""
Validate that all tools declared by nodes are available.
@@ -246,6 +283,7 @@ class GraphExecutor:
memory=memory,
goal=goal,
input_data=input_data or {},
max_tokens=graph.max_tokens,
)
# Log actual input data being read
@@ -261,7 +299,7 @@ class GraphExecutor:
self.logger.info(f" {key}: {value_str}")
# Get or create node implementation
node_impl = self._get_node_implementation(node_spec)
node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
# Validate inputs
validation_errors = node_impl.validate_input(ctx)
@@ -419,8 +457,8 @@ class GraphExecutor:
self.logger.info(f" → Router directing to: {result.next_node}")
current_node_id = result.next_node
else:
# Follow edges
next_node = self._follow_edges(
# Get all traversable edges for fan-out detection
traversable_edges = self._get_all_traversable_edges(
graph=graph,
goal=goal,
current_node_id=current_node_id,
@@ -428,12 +466,55 @@ class GraphExecutor:
result=result,
memory=memory,
)
if next_node is None:
if not traversable_edges:
self.logger.info(" → No more edges, ending execution")
break # No valid edge, end execution
next_spec = graph.get_node(next_node)
self.logger.info(f" → Next: {next_spec.name if next_spec else next_node}")
current_node_id = next_node
# Check for fan-out (multiple traversable edges)
if self.enable_parallel_execution and len(traversable_edges) > 1:
# Find convergence point (fan-in node)
targets = [e.target for e in traversable_edges]
fan_in_node = self._find_convergence_node(graph, targets)
# Execute branches in parallel
_branch_results, branch_tokens, branch_latency = await self._execute_parallel_branches(
graph=graph,
goal=goal,
edges=traversable_edges,
memory=memory,
source_result=result,
source_node_spec=node_spec,
path=path,
)
total_tokens += branch_tokens
total_latency += branch_latency
# Continue from fan-in node
if fan_in_node:
self.logger.info(f" ⑃ Fan-in: converging at {fan_in_node}")
current_node_id = fan_in_node
else:
# No convergence point - branches are terminal
self.logger.info(" → Parallel branches completed (no convergence)")
break
else:
# Sequential: follow single edge (existing logic via _follow_edges)
next_node = self._follow_edges(
graph=graph,
goal=goal,
current_node_id=current_node_id,
current_node_spec=node_spec,
result=result,
memory=memory,
)
if next_node is None:
self.logger.info(" → No more edges, ending execution")
break
next_spec = graph.get_node(next_node)
self.logger.info(f" → Next: {next_spec.name if next_spec else next_node}")
current_node_id = next_node
# Update input_data for next node
input_data = result.output
@@ -484,6 +565,7 @@ class GraphExecutor:
memory: SharedMemory,
goal: Goal,
input_data: dict[str, Any],
max_tokens: int = 4096,
) -> NodeContext:
"""Build execution context for a node."""
# Filter tools to those available to this node
@@ -507,12 +589,15 @@ class GraphExecutor:
available_tools=available_tools,
goal_context=goal.to_prompt_context(),
goal=goal, # Pass Goal object for LLM-powered routers
max_tokens=max_tokens,
)
# Valid node types - no ambiguous "llm" type allowed
VALID_NODE_TYPES = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
def _get_node_implementation(self, node_spec: NodeSpec) -> NodeProtocol:
def _get_node_implementation(
self, node_spec: NodeSpec, cleanup_llm_model: str | None = None
) -> NodeProtocol:
"""Get or create a node implementation."""
# Check registry first
if node_spec.id in self.node_registry:
@@ -533,10 +618,18 @@ class GraphExecutor:
f"Node '{node_spec.id}' is type 'llm_tool_use' but declares no tools. "
"Either add tools to the node or change type to 'llm_generate'."
)
return LLMNode(tool_executor=self.tool_executor, require_tools=True)
return LLMNode(
tool_executor=self.tool_executor,
require_tools=True,
cleanup_llm_model=cleanup_llm_model,
)
if node_spec.node_type == "llm_generate":
return LLMNode(tool_executor=None, require_tools=False)
return LLMNode(
tool_executor=None,
require_tools=False,
cleanup_llm_model=cleanup_llm_model,
)
if node_spec.node_type == "router":
return RouterNode()
@@ -549,7 +642,11 @@ class GraphExecutor:
if node_spec.node_type == "human_input":
# Human input nodes are handled specially by HITL mechanism
return LLMNode(tool_executor=None, require_tools=False)
return LLMNode(
tool_executor=None,
require_tools=False,
cleanup_llm_model=cleanup_llm_model,
)
# Should never reach here due to validation above
raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")
@@ -608,9 +705,9 @@ class GraphExecutor:
# Update result with cleaned output
result.output = cleaned_output
# Write cleaned output back to memory
# Write cleaned output back to memory (skip validation for LLM output)
for key, value in cleaned_output.items():
memory.write(key, value)
memory.write(key, value, validate=False)
# Revalidate
revalidation = self.output_cleaner.validate_output(
@@ -629,15 +726,234 @@ class GraphExecutor:
)
# Continue anyway if fallback_to_raw is True
# Map inputsss
# Map inputs (skip validation for processed LLM output)
mapped = edge.map_inputs(result.output, memory.read_all())
for key, value in mapped.items():
memory.write(key, value)
memory.write(key, value, validate=False)
return edge.target
return None
def _get_all_traversable_edges(
self,
graph: GraphSpec,
goal: Goal,
current_node_id: str,
current_node_spec: Any,
result: NodeResult,
memory: SharedMemory,
) -> list[EdgeSpec]:
"""
Get ALL edges that should be traversed (for fan-out detection).
Unlike _follow_edges which returns the first match, this returns
all matching edges to enable parallel execution.
"""
edges = graph.get_outgoing_edges(current_node_id)
traversable = []
for edge in edges:
target_node_spec = graph.get_node(edge.target)
if edge.should_traverse(
source_success=result.success,
source_output=result.output,
memory=memory.read_all(),
llm=self.llm,
goal=goal,
source_node_name=current_node_spec.name if current_node_spec else current_node_id,
target_node_name=target_node_spec.name if target_node_spec else edge.target,
):
traversable.append(edge)
return traversable
def _find_convergence_node(
self,
graph: GraphSpec,
parallel_targets: list[str],
) -> str | None:
"""
Find the common target node where parallel branches converge (fan-in).
Args:
graph: The graph specification
parallel_targets: List of node IDs that are running in parallel
Returns:
Node ID where all branches converge, or None if no convergence
"""
# Get all nodes that parallel branches lead to
next_nodes: dict[str, int] = {} # node_id -> count of branches leading to it
for target in parallel_targets:
outgoing = graph.get_outgoing_edges(target)
for edge in outgoing:
next_nodes[edge.target] = next_nodes.get(edge.target, 0) + 1
# Convergence node is where ALL branches lead
for node_id, count in next_nodes.items():
if count == len(parallel_targets):
return node_id
# Fallback: return most common target if any
if next_nodes:
return max(next_nodes.keys(), key=lambda k: next_nodes[k])
return None
async def _execute_parallel_branches(
self,
graph: GraphSpec,
goal: Goal,
edges: list[EdgeSpec],
memory: SharedMemory,
source_result: NodeResult,
source_node_spec: Any,
path: list[str],
) -> tuple[dict[str, NodeResult], int, int]:
"""
Execute multiple branches in parallel using asyncio.gather.
Args:
graph: The graph specification
goal: The execution goal
edges: List of edges to follow in parallel
memory: Shared memory instance
source_result: Result from the source node
source_node_spec: Spec of the source node
path: Execution path list to update
Returns:
Tuple of (branch_results dict, total_tokens, total_latency)
"""
branches: dict[str, ParallelBranch] = {}
# Create branches for each edge
for edge in edges:
branch_id = f"{edge.source}_to_{edge.target}"
branches[branch_id] = ParallelBranch(
branch_id=branch_id,
node_id=edge.target,
edge=edge,
)
self.logger.info(f" ⑂ Fan-out: executing {len(branches)} branches in parallel")
for branch in branches.values():
target_spec = graph.get_node(branch.node_id)
self.logger.info(f"{target_spec.name if target_spec else branch.node_id}")
async def execute_single_branch(branch: ParallelBranch) -> tuple[ParallelBranch, NodeResult | Exception]:
"""Execute a single branch with retry logic."""
node_spec = graph.get_node(branch.node_id)
branch.status = "running"
try:
# Validate and clean output before mapping inputs (same as _follow_edges)
if self.cleansing_config.enabled and node_spec:
validation = self.output_cleaner.validate_output(
output=source_result.output,
source_node_id=source_node_spec.id if source_node_spec else "unknown",
target_node_spec=node_spec,
)
if not validation.valid:
self.logger.warning(
f"⚠ Output validation failed for branch {branch.node_id}: {validation.errors}"
)
cleaned_output = self.output_cleaner.clean_output(
output=source_result.output,
source_node_id=source_node_spec.id if source_node_spec else "unknown",
target_node_spec=node_spec,
validation_errors=validation.errors,
)
# Write cleaned output to memory
for key, value in cleaned_output.items():
await memory.write_async(key, value)
# Map inputs via edge
mapped = branch.edge.map_inputs(source_result.output, memory.read_all())
for key, value in mapped.items():
await memory.write_async(key, value)
# Execute with retries
last_result = None
for attempt in range(node_spec.max_retries):
branch.retry_count = attempt
# Build context for this branch
ctx = self._build_context(node_spec, memory, goal, mapped, graph.max_tokens)
node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
self.logger.info(f" ▶ Branch {node_spec.name}: executing (attempt {attempt + 1})")
result = await node_impl.execute(ctx)
last_result = result
if result.success:
# Write outputs to shared memory using async write
for key, value in result.output.items():
await memory.write_async(key, value)
branch.result = result
branch.status = "completed"
self.logger.info(
f" ✓ Branch {node_spec.name}: success "
f"(tokens: {result.tokens_used}, latency: {result.latency_ms}ms)"
)
return branch, result
self.logger.warning(
f" ↻ Branch {node_spec.name}: retry {attempt + 1}/{node_spec.max_retries}"
)
# All retries exhausted
branch.status = "failed"
branch.error = last_result.error if last_result else "Unknown error"
branch.result = last_result
self.logger.error(f" ✗ Branch {node_spec.name}: failed after {node_spec.max_retries} attempts")
return branch, last_result
except Exception as e:
branch.status = "failed"
branch.error = str(e)
self.logger.error(f" ✗ Branch {branch.node_id}: exception - {e}")
return branch, e
# Execute all branches concurrently
tasks = [execute_single_branch(b) for b in branches.values()]
results = await asyncio.gather(*tasks, return_exceptions=False)
# Process results
total_tokens = 0
total_latency = 0
branch_results: dict[str, NodeResult] = {}
failed_branches: list[ParallelBranch] = []
for branch, result in results:
path.append(branch.node_id)
if isinstance(result, Exception):
failed_branches.append(branch)
elif result is None or not result.success:
failed_branches.append(branch)
else:
total_tokens += result.tokens_used
total_latency += result.latency_ms
branch_results[branch.branch_id] = result
# Handle failures based on config
if failed_branches:
failed_names = [graph.get_node(b.node_id).name for b in failed_branches]
if self._parallel_config.on_branch_failure == "fail_all":
raise RuntimeError(
f"Parallel execution failed: branches {failed_names} failed"
)
elif self._parallel_config.on_branch_failure == "continue_others":
self.logger.warning(f"⚠ Some branches failed ({failed_names}), continuing with successful ones")
self.logger.info(f" ⑃ Fan-out complete: {len(branch_results)}/{len(branches)} branches succeeded")
return branch_results, total_tokens, total_latency
def register_node(self, node_id: str, implementation: NodeProtocol) -> None:
"""Register a custom node implementation."""
self.node_registry[node_id] = implementation
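For context, a hedged usage sketch of the new constructor options, mirroring how the tests below drive the executor; runtime, graph, and goal are assumed to already exist elsewhere.

import asyncio
from framework.graph.executor import GraphExecutor, ParallelExecutionConfig

config = ParallelExecutionConfig(
    on_branch_failure="continue_others",  # alternatives: "fail_all" (default), "wait_all"
    memory_conflict_strategy="last_wins",
    branch_timeout_seconds=120.0,
)

executor = GraphExecutor(
    runtime=runtime,                 # assumed: an existing Runtime instance
    enable_parallel_execution=True,  # default; False falls back to single-edge traversal
    parallel_config=config,
)

async def run() -> None:
    # graph and goal are assumed to be a GraphSpec / Goal built elsewhere.
    result = await executor.execute(graph, goal, {"data": "x"})
    print(result.success, result.path)

asyncio.run(run())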
+439 -148
@@ -15,6 +15,7 @@ Protocol:
The framework provides NodeContext with everything the node needs.
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
@@ -29,6 +30,62 @@ from framework.runtime.core import Runtime
logger = logging.getLogger(__name__)
def _fix_unescaped_newlines_in_json(json_str: str) -> str:
"""Fix unescaped newlines inside JSON string values.
LLMs sometimes output actual newlines inside JSON strings instead of \\n.
This function fixes that by properly escaping newlines within string values.
"""
result = []
in_string = False
escape_next = False
i = 0
while i < len(json_str):
char = json_str[i]
if escape_next:
result.append(char)
escape_next = False
i += 1
continue
if char == "\\" and in_string:
escape_next = True
result.append(char)
i += 1
continue
if char == '"' and not escape_next:
in_string = not in_string
result.append(char)
i += 1
continue
# Fix unescaped newlines inside strings
if in_string and char == "\n":
result.append("\\n")
i += 1
continue
# Fix unescaped carriage returns inside strings
if in_string and char == "\r":
result.append("\\r")
i += 1
continue
# Fix unescaped tabs inside strings
if in_string and char == "\t":
result.append("\\t")
i += 1
continue
result.append(char)
i += 1
return "".join(result)
def find_json_object(text: str) -> str | None:
"""Find the first valid JSON object in text using balanced brace matching.
@@ -173,11 +230,22 @@ class SharedMemory:
Nodes read and write to shared memory using typed keys.
The memory is scoped to a single run.
For parallel execution, use write_async() which provides per-key locking
to prevent race conditions when multiple nodes write concurrently.
"""
_data: dict[str, Any] = field(default_factory=dict)
_allowed_read: set[str] = field(default_factory=set)
_allowed_write: set[str] = field(default_factory=set)
# Locks for thread-safe parallel execution
_lock: asyncio.Lock | None = field(default=None, repr=False)
_key_locks: dict[str, asyncio.Lock] = field(default_factory=dict, repr=False)
def __post_init__(self) -> None:
"""Initialize the main lock if not provided."""
if self._lock is None:
self._lock = asyncio.Lock()
def read(self, key: str) -> Any:
"""Read a value from shared memory."""
@@ -218,6 +286,48 @@ class SharedMemory:
self._data[key] = value
async def write_async(self, key: str, value: Any, validate: bool = True) -> None:
"""
Thread-safe async write with per-key locking.
Use this method when multiple nodes may write concurrently during
parallel execution. Each key has its own lock to minimize contention.
Args:
key: The memory key to write to
value: The value to write
validate: If True, check for suspicious content (default True)
Raises:
PermissionError: If node doesn't have write permission
MemoryWriteError: If value appears to be hallucinated content
"""
# Check permissions first (no lock needed)
if self._allowed_write and key not in self._allowed_write:
raise PermissionError(f"Node not allowed to write key: {key}")
# Ensure key has a lock (double-checked locking pattern)
if key not in self._key_locks:
async with self._lock:
if key not in self._key_locks:
self._key_locks[key] = asyncio.Lock()
# Acquire per-key lock and write
async with self._key_locks[key]:
if validate and isinstance(value, str):
if len(value) > 5000:
if self._contains_code_indicators(value):
logger.warning(
f"⚠ Suspicious write to key '{key}': appears to be code "
f"({len(value)} chars). Consider using validate=False if intended."
)
raise MemoryWriteError(
f"Rejected suspicious content for key '{key}': "
f"appears to be hallucinated code ({len(value)} chars). "
"If this is intentional, use validate=False."
)
self._data[key] = value
def _contains_code_indicators(self, value: str) -> bool:
"""
Check for code patterns in a string using sampling for efficiency.
@@ -290,11 +400,17 @@ class SharedMemory:
read_keys: list[str],
write_keys: list[str],
) -> "SharedMemory":
"""Create a view with restricted permissions for a specific node."""
"""Create a view with restricted permissions for a specific node.
The scoped view shares the same underlying data and locks,
enabling thread-safe parallel execution across scoped views.
"""
return SharedMemory(
_data=self._data,
_allowed_read=set(read_keys) if read_keys else set(),
_allowed_write=set(write_keys) if write_keys else set(),
_lock=self._lock, # Share lock for thread safety
_key_locks=self._key_locks, # Share key locks
)
@@ -330,6 +446,9 @@ class NodeContext:
goal_context: str = ""
goal: Any = None # Goal object for LLM-powered routers
# LLM configuration
max_tokens: int = 4096 # Maximum tokens for LLM responses
# Execution metadata
attempt: int = 1
max_attempts: int = 3
@@ -503,9 +622,33 @@ class LLMNode(NodeProtocol):
The LLM decides how to achieve the goal within constraints.
"""
def __init__(self, tool_executor: Callable | None = None, require_tools: bool = False):
# Stop reasons indicating truncation (varies by provider)
TRUNCATION_STOP_REASONS = {"length", "max_tokens", "token_limit"}
# Compaction instruction added when response is truncated
COMPACTION_INSTRUCTION = """
IMPORTANT: Your previous response was truncated because it exceeded the token limit.
Please provide a MORE CONCISE response that fits within the limit.
Focus on the essential information and omit verbose details.
Keep the same JSON structure but with shorter content values.
"""
def __init__(
self,
tool_executor: Callable | None = None,
require_tools: bool = False,
cleanup_llm_model: str | None = None,
max_compaction_retries: int = 2,
):
self.tool_executor = tool_executor
self.require_tools = require_tools
self.cleanup_llm_model = cleanup_llm_model
self.max_compaction_retries = max_compaction_retries
def _is_truncated(self, response) -> bool:
"""Check if LLM response was truncated due to token limit."""
stop_reason = getattr(response, "stop_reason", "").lower()
return stop_reason in self.TRUNCATION_STOP_REASONS
def _strip_code_blocks(self, content: str) -> str:
"""Strip markdown code block wrappers from content.
@@ -599,6 +742,7 @@ class LLMNode(NodeProtocol):
system=system,
tools=ctx.available_tools,
tool_executor=executor,
max_tokens=ctx.max_tokens,
)
else:
# Use JSON mode for llm_generate nodes with output_keys
@@ -613,128 +757,172 @@ class LLMNode(NodeProtocol):
f" 📋 Expecting JSON output with keys: {ctx.node_spec.output_keys}"
)
# Phase 3: Auto-generate JSON schema from Pydantic model
response_format = None
if ctx.node_spec.output_model is not None:
json_schema = ctx.node_spec.output_model.model_json_schema()
response_format = {
"type": "json_schema",
"json_schema": {
"name": ctx.node_spec.output_model.__name__,
"schema": json_schema,
"strict": True,
}
}
model_name = ctx.node_spec.output_model.__name__
logger.info(f" 📐 Using JSON schema from Pydantic model: {model_name}")
response = ctx.llm.complete(
messages=messages,
system=system,
json_mode=use_json_mode,
max_tokens=ctx.max_tokens,
)
# Phase 2: Retry loop for Pydantic validation
max_retries = ctx.node_spec.max_validation_retries
max_validation_retries = max_retries if ctx.node_spec.output_model else 0
validation_attempt = 0
total_input_tokens = 0
total_output_tokens = 0
current_messages = messages.copy()
# Check for truncation and retry with compaction if needed
expects_json = (
ctx.node_spec.node_type in ("llm_generate", "llm_tool_use")
and ctx.node_spec.output_keys
and len(ctx.node_spec.output_keys) >= 1
)
while True:
compaction_attempt = 0
while self._is_truncated(response) and expects_json and compaction_attempt < self.max_compaction_retries:
compaction_attempt += 1
logger.warning(
f" ⚠ Response truncated (stop_reason: {response.stop_reason}), "
f"retrying with compaction ({compaction_attempt}/{self.max_compaction_retries})"
)
# Add compaction instruction to messages
compaction_messages = messages + [
{"role": "assistant", "content": response.content},
{"role": "user", "content": self.COMPACTION_INSTRUCTION},
]
# Retry the call with compaction instruction
if ctx.available_tools and self.tool_executor:
response = ctx.llm.complete_with_tools(
messages=compaction_messages,
system=system,
tools=ctx.available_tools,
tool_executor=executor,
max_tokens=ctx.max_tokens,
)
else:
response = ctx.llm.complete(
messages=current_messages,
messages=compaction_messages,
system=system,
json_mode=use_json_mode,
response_format=response_format,
max_tokens=ctx.max_tokens,
)
total_input_tokens += response.input_tokens
total_output_tokens += response.output_tokens
if self._is_truncated(response) and expects_json:
logger.warning(
f" ⚠ Response still truncated after {compaction_attempt} compaction attempts"
)
# Log the response
response_preview = (
response.content[:200] if len(response.content) > 200 else response.content
)
if len(response.content) > 200:
response_preview += "..."
logger.info(f" ← Response: {response_preview}")
# Phase 2: Validation retry loop for Pydantic models
max_validation_retries = ctx.node_spec.max_validation_retries if ctx.node_spec.output_model else 0
validation_attempt = 0
total_input_tokens = 0
total_output_tokens = 0
current_messages = messages.copy()
# If no output_model, break immediately (no validation needed)
if ctx.node_spec.output_model is None:
break
while True:
total_input_tokens += response.input_tokens
total_output_tokens += response.output_tokens
# Try to parse and validate the response
try:
import json
parsed = self._extract_json(response.content, ctx.node_spec.output_keys)
# Log the response
response_preview = (
response.content[:200] if len(response.content) > 200 else response.content
)
if len(response.content) > 200:
response_preview += "..."
logger.info(f" ← Response: {response_preview}")
if isinstance(parsed, dict):
from framework.graph.validator import OutputValidator
validator = OutputValidator()
validation_result, validated_model = validator.validate_with_pydantic(
parsed, ctx.node_spec.output_model
)
# If no output_model, break immediately (no validation needed)
if ctx.node_spec.output_model is None:
break
if validation_result.success:
# Validation passed, break out of retry loop
model_name = ctx.node_spec.output_model.__name__
logger.info(f" ✓ Pydantic validation passed for {model_name}")
break
else:
# Validation failed
validation_attempt += 1
# Try to parse and validate the response
try:
import json
parsed = self._extract_json(response.content, ctx.node_spec.output_keys)
if validation_attempt <= max_validation_retries:
# Add validation feedback to messages and retry
feedback = validator.format_validation_feedback(
validation_result, ctx.node_spec.output_model
)
logger.warning(
f" ⚠ Pydantic validation failed "
f"(attempt {validation_attempt}/{max_validation_retries}): "
f"{validation_result.error}"
)
logger.info(" 🔄 Retrying with validation feedback...")
if isinstance(parsed, dict):
from framework.graph.validator import OutputValidator
validator = OutputValidator()
validation_result, validated_model = validator.validate_with_pydantic(
parsed, ctx.node_spec.output_model
)
# Add the assistant's failed response and feedback
current_messages.append({
"role": "assistant",
"content": response.content
})
current_messages.append({
"role": "user",
"content": feedback
})
continue # Retry the LLM call
else:
# Max retries exceeded
latency_ms = int((time.time() - start) * 1000)
err = validation_result.error
logger.error(
f" ✗ Pydantic validation failed after "
f"{max_validation_retries} retries: {err}"
)
ctx.runtime.record_outcome(
decision_id=decision_id,
success=False,
error=f"Validation failed: {validation_result.error}",
tokens_used=total_input_tokens + total_output_tokens,
latency_ms=latency_ms,
)
error_msg = (
f"Pydantic validation failed after "
f"{max_validation_retries} retries: {err}"
)
return NodeResult(
success=False,
error=error_msg,
output=parsed,
tokens_used=total_input_tokens + total_output_tokens,
latency_ms=latency_ms,
validation_errors=validation_result.errors,
)
else:
# Not a dict, can't validate - break and let downstream handle
if validation_result.success:
# Validation passed, break out of retry loop
model_name = ctx.node_spec.output_model.__name__
logger.info(f" ✓ Pydantic validation passed for {model_name}")
break
except Exception:
# JSON extraction failed - break and let downstream handle
else:
# Validation failed
validation_attempt += 1
if validation_attempt <= max_validation_retries:
# Add validation feedback to messages and retry
feedback = validator.format_validation_feedback(
validation_result, ctx.node_spec.output_model
)
logger.warning(
f" ⚠ Pydantic validation failed "
f"(attempt {validation_attempt}/{max_validation_retries}): "
f"{validation_result.error}"
)
logger.info(" 🔄 Retrying with validation feedback...")
# Add the assistant's failed response and feedback
current_messages.append({
"role": "assistant",
"content": response.content
})
current_messages.append({
"role": "user",
"content": feedback
})
# Re-call LLM with feedback
if ctx.available_tools and self.tool_executor:
response = ctx.llm.complete_with_tools(
messages=current_messages,
system=system,
tools=ctx.available_tools,
tool_executor=executor,
max_tokens=ctx.max_tokens,
)
else:
response = ctx.llm.complete(
messages=current_messages,
system=system,
json_mode=use_json_mode,
max_tokens=ctx.max_tokens,
)
continue # Retry validation
else:
# Max retries exceeded
latency_ms = int((time.time() - start) * 1000)
err = validation_result.error
logger.error(
f" ✗ Pydantic validation failed after "
f"{max_validation_retries} retries: {err}"
)
ctx.runtime.record_outcome(
decision_id=decision_id,
success=False,
error=f"Validation failed: {validation_result.error}",
tokens_used=total_input_tokens + total_output_tokens,
latency_ms=latency_ms,
)
error_msg = (
f"Pydantic validation failed after "
f"{max_validation_retries} retries: {err}"
)
return NodeResult(
success=False,
error=error_msg,
output=parsed,
tokens_used=total_input_tokens + total_output_tokens,
latency_ms=latency_ms,
validation_errors=validation_result.errors,
)
else:
# Not a dict, can't validate - break and let downstream handle
break
except Exception:
# JSON extraction failed - break and let downstream handle
break
latency_ms = int((time.time() - start) * 1000)
@@ -758,9 +946,13 @@ class LLMNode(NodeProtocol):
import json
# Try to extract JSON from response
parsed = self._extract_json(response.content, ctx.node_spec.output_keys)
parsed = self._extract_json(
response.content, ctx.node_spec.output_keys, self.cleanup_llm_model
)
# If parsed successfully, validate against Pydantic model if specified
# If parsed successfully, write each field to its corresponding output key
# Use validate=False since LLM output legitimately contains text that
# may trigger false positives (e.g., "from OpenAI" matches "from ")
if isinstance(parsed, dict):
# If we have output_model, the validation already happened in the retry loop
if ctx.node_spec.output_model is not None:
@@ -779,22 +971,22 @@ class LLMNode(NodeProtocol):
# Strip code block wrappers from string values
if isinstance(value, str):
value = self._strip_code_blocks(value)
ctx.memory.write(key, value)
ctx.memory.write(key, value, validate=False)
output[key] = value
elif key in ctx.input_data:
# Key not in JSON but exists in input - pass through
ctx.memory.write(key, ctx.input_data[key])
ctx.memory.write(key, ctx.input_data[key], validate=False)
output[key] = ctx.input_data[key]
else:
# Key not in JSON or input, write whole response (stripped)
stripped_content = self._strip_code_blocks(response.content)
ctx.memory.write(key, stripped_content)
ctx.memory.write(key, stripped_content, validate=False)
output[key] = stripped_content
else:
# Not a dict, fall back to writing entire response to all keys (stripped)
stripped_content = self._strip_code_blocks(response.content)
for key in ctx.node_spec.output_keys:
ctx.memory.write(key, stripped_content)
ctx.memory.write(key, stripped_content, validate=False)
output[key] = stripped_content
except (json.JSONDecodeError, Exception) as e:
@@ -825,7 +1017,7 @@ class LLMNode(NodeProtocol):
# For non-llm_generate or single output nodes, write entire response (stripped)
stripped_content = self._strip_code_blocks(response.content)
for key in ctx.node_spec.output_keys:
ctx.memory.write(key, stripped_content)
ctx.memory.write(key, stripped_content, validate=False)
output[key] = stripped_content
return NodeResult(
@@ -855,14 +1047,21 @@ class LLMNode(NodeProtocol):
# Default output
return {"result": content}
def _extract_json(self, raw_response: str, output_keys: list[str]) -> dict[str, Any]:
def _extract_json(
self, raw_response: str, output_keys: list[str], cleanup_llm_model: str | None = None
) -> dict[str, Any]:
"""Extract clean JSON from potentially verbose LLM response.
Tries multiple extraction strategies in order:
1. Direct JSON parse
2. Markdown code block extraction
3. Balanced brace matching
4. Haiku LLM fallback (last resort)
4. Configured LLM fallback (last resort)
Args:
raw_response: The raw LLM response text
output_keys: Expected output keys for the JSON
cleanup_llm_model: Optional model to use for LLM cleanup fallback
"""
import json
import re
@@ -889,55 +1088,116 @@ class LLMNode(NodeProtocol):
parsed = json.loads(content)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
except json.JSONDecodeError as e:
logger.info(f" Direct JSON parse failed: {e}")
logger.info(f" Content first 200 chars repr: {repr(content[:200])}")
# Try fixing unescaped newlines in string values
try:
fixed = _fix_unescaped_newlines_in_json(content)
logger.info(f" Fixed content first 200 chars repr: {repr(fixed[:200])}")
parsed = json.loads(fixed)
if isinstance(parsed, dict):
logger.info(" ✓ Parsed JSON after fixing unescaped newlines")
return parsed
except json.JSONDecodeError as e2:
logger.info(f" Newline fix also failed: {e2}")
# Try to extract JSON from markdown code blocks (greedy match to handle nested blocks)
# Use anchored match to capture from first ``` to last ```
code_block_match = re.match(r"^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$", content, re.DOTALL)
if code_block_match:
try:
parsed = json.loads(code_block_match.group(1).strip())
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
# Multiple patterns to handle different LLM formatting styles
code_block_patterns = [
# Anchored match from first ``` to last ```
r"^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$",
# Non-anchored: find ```json anywhere and extract to closing ```
r"```(?:json|JSON)?\s*\n([\s\S]*?)\n```",
# Handle case where closing ``` might have trailing content
r"```(?:json|JSON)?\s*\n([\s\S]*?)\n```",
]
for pattern in code_block_patterns:
code_block_match = re.search(pattern, content, re.DOTALL)
if code_block_match:
try:
extracted = code_block_match.group(1).strip()
if extracted: # Skip empty matches
# Try direct parse first, then with newline fix
try:
parsed = json.loads(extracted)
except json.JSONDecodeError:
parsed = json.loads(_fix_unescaped_newlines_in_json(extracted))
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
# Try to find JSON object by matching balanced braces (use module-level helper)
json_str = find_json_object(content)
if json_str:
try:
parsed = json.loads(json_str)
# Try direct parse first, then with newline fix
try:
parsed = json.loads(json_str)
except json.JSONDecodeError:
parsed = json.loads(_fix_unescaped_newlines_in_json(json_str))
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
# Try stripping markdown prefix and finding JSON from there
# This handles cases like "```json\n{...}" where regex might fail
if "```" in content:
# Find position after ```json or ``` marker
json_start = content.find("{")
if json_start > 0:
# Extract from first { to end, then find balanced JSON
json_str = find_json_object(content[json_start:])
if json_str:
try:
# Try direct parse first, then with newline fix
try:
parsed = json.loads(json_str)
except json.JSONDecodeError:
parsed = json.loads(_fix_unescaped_newlines_in_json(json_str))
if isinstance(parsed, dict):
logger.info(" ✓ Extracted JSON via brace matching after markdown strip")
return parsed
except json.JSONDecodeError:
pass
# All local extraction failed - use LLM as last resort
# Prefer Cerebras (faster/cheaper), fallback to Haiku
import os
api_key = os.environ.get("CEREBRAS_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
raise ValueError(
"Cannot parse JSON and no API key for LLM cleanup "
"(set CEREBRAS_API_KEY or ANTHROPIC_API_KEY)"
)
# Use fast LLM to clean the response (Cerebras llama-3.3-70b preferred)
from framework.llm.litellm import LiteLLMProvider
if os.environ.get("CEREBRAS_API_KEY"):
logger.info(f" cleanup_llm_model param: {cleanup_llm_model}")
# Use configured cleanup model, or fall back to defaults
if cleanup_llm_model:
# Use the configured cleanup model (LiteLLM handles API keys via env vars)
cleaner_llm = LiteLLMProvider(
api_key=os.environ.get("CEREBRAS_API_KEY"),
model="cerebras/llama-3.3-70b",
model=cleanup_llm_model,
temperature=0.0,
)
logger.info(f" Using configured cleanup LLM: {cleanup_llm_model}")
else:
# Fallback to Anthropic Haiku via LiteLLM for consistency
cleaner_llm = LiteLLMProvider(
api_key=api_key, model="claude-3-5-haiku-20241022", temperature=0.0
)
# Fall back to default logic: Cerebras preferred, then Haiku
api_key = os.environ.get("CEREBRAS_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
raise ValueError(
"Cannot parse JSON and no API key for LLM cleanup "
"(set CEREBRAS_API_KEY or ANTHROPIC_API_KEY, or configure cleanup_llm_model)"
)
if os.environ.get("CEREBRAS_API_KEY"):
cleaner_llm = LiteLLMProvider(
api_key=os.environ.get("CEREBRAS_API_KEY"),
model="cerebras/llama-3.3-70b",
temperature=0.0,
)
else:
cleaner_llm = LiteLLMProvider(
api_key=api_key,
model="claude-3-5-haiku-20241022",
temperature=0.0,
)
prompt = f"""Extract the JSON object from this LLM response.
@@ -955,7 +1215,16 @@ Output ONLY the JSON object, nothing else."""
json_mode=True,
)
cleaned = result.content.strip()
cleaned = result.content.strip() if result.content else ""
# Check for empty response
if not cleaned:
logger.warning(" ⚠ LLM cleanup returned empty response")
raise ValueError(
f"LLM cleanup returned empty response. "
f"Raw response starts with: {raw_response[:200]}..."
)
# Remove markdown if LLM added it
if cleaned.startswith("```"):
match = re.search(r"^```(?:json)?\s*\n([\s\S]*?)\n```\s*$", cleaned)
@@ -967,10 +1236,32 @@ Output ONLY the JSON object, nothing else."""
if lines[0].startswith("```") and lines[-1].strip() == "```":
cleaned = "\n".join(lines[1:-1]).strip()
parsed = json.loads(cleaned)
# Try balanced brace extraction if still not valid JSON
if not cleaned.startswith("{"):
json_str = find_json_object(cleaned)
if json_str:
cleaned = json_str
if not cleaned:
raise ValueError(
f"Could not extract JSON from LLM cleanup response. "
f"Raw response starts with: {raw_response[:200]}..."
)
# Try direct parse first, then with newline fix
try:
parsed = json.loads(cleaned)
except json.JSONDecodeError:
parsed = json.loads(_fix_unescaped_newlines_in_json(cleaned))
logger.info(" ✓ LLM cleaned JSON output")
return parsed
except json.JSONDecodeError as e:
logger.warning(f" ⚠ LLM cleanup response not valid JSON: {e}")
raise ValueError(
f"LLM cleanup response not valid JSON: {e}. "
f"Expected keys: {output_keys}"
)
except ValueError:
raise # Re-raise our descriptive error
except Exception as e:
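Outside the diff, a minimal sketch of the new concurrency-safe write path, assuming SharedMemory and the newline-fix helper are importable from framework.graph.node as the hunks above suggest.

import asyncio
import json
from framework.graph.node import SharedMemory, _fix_unescaped_newlines_in_json

async def demo_parallel_writes() -> None:
    mem = SharedMemory()
    # Two branches writing distinct keys acquire distinct per-key locks,
    # so neither blocks the other beyond the brief lock-creation step.
    await asyncio.gather(
        mem.write_async("b1_out", "done1"),
        mem.write_async("b2_out", "done2"),
    )
    print(mem.read_all())  # {'b1_out': 'done1', 'b2_out': 'done2'}

asyncio.run(demo_parallel_writes())

# The module-level helper repairs literal newlines that LLMs emit inside JSON strings.
broken = '{"summary": "line one\nline two"}'   # real newline inside the string value
print(json.loads(_fix_unescaped_newlines_in_json(broken))["summary"])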
+2 -1
@@ -159,6 +159,7 @@ class LiteLLMProvider(LLMProvider):
tools: list[Tool],
tool_executor: Callable[[ToolUse], ToolResult],
max_iterations: int = 10,
max_tokens: int = 4096,
) -> LLMResponse:
"""Run a tool-use loop until the LLM produces a final response."""
# Prepare messages with system prompt
@@ -178,7 +179,7 @@ class LiteLLMProvider(LLMProvider):
kwargs: dict[str, Any] = {
"model": self.model,
"messages": current_messages,
"max_tokens": 1024,
"max_tokens": max_tokens,
"tools": openai_tools,
**self.extra_kwargs,
}
@@ -421,6 +421,7 @@ class ExecutionStream:
default_model=self.graph.default_model,
max_tokens=self.graph.max_tokens,
max_steps=self.graph.max_steps,
cleanup_llm_model=self.graph.cleanup_llm_model,
)
async def wait_for_completion(
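With ExecutionStream now forwarding cleanup_llm_model from the graph, the per-graph LLM settings can sit in one place. A hedged sketch, reusing the nodes and edges from the diamond example earlier; the field names come from the GraphSpec hunk at the top of this diff.

graph = GraphSpec(
    id="report_graph",
    goal_id="g1",
    name="Report",
    entry_node="source",
    nodes=nodes,              # reuse the diamond nodes/edges from the sketch above
    edges=edges,
    terminal_nodes=["merge"],
    default_model="claude-haiku-4-5-20251001",
    max_tokens=2048,                             # now forwarded into NodeContext and every LLM call
    cleanup_llm_model="cerebras/llama-3.3-70b",  # JSON-extraction fallback; env-based default if omitted
    max_steps=50,
)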
+413
@@ -0,0 +1,413 @@
"""
Tests for fan-out / fan-in parallel execution in GraphExecutor.
Covers:
- Fan-out triggers with multiple ON_SUCCESS edges
- Concurrent branch execution
- Convergence at fan-in node
- fail_all / continue_others / wait_all strategies
- Branch timeout
- Memory conflict strategies
- Per-branch retry
- Single-edge paths unaffected
"""
from unittest.mock import MagicMock
import pytest
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.executor import GraphExecutor, ParallelExecutionConfig
from framework.graph.goal import Goal
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
from framework.runtime.core import Runtime
# --- Test node implementations ---
class SuccessNode(NodeProtocol):
"""Always succeeds with configurable output."""
def __init__(self, output: dict | None = None):
self._output = output or {"result": "ok"}
self.executed = False
async def execute(self, ctx: NodeContext) -> NodeResult:
self.executed = True
return NodeResult(success=True, output=self._output, tokens_used=10, latency_ms=5)
class FailNode(NodeProtocol):
"""Always fails."""
def __init__(self):
self.attempt_count = 0
async def execute(self, ctx: NodeContext) -> NodeResult:
self.attempt_count += 1
return NodeResult(success=False, error="branch failed")
class FlakyNode(NodeProtocol):
"""Fails N times, then succeeds."""
def __init__(self, fail_times: int = 1, output: dict | None = None):
self.fail_times = fail_times
self.attempt_count = 0
self._output = output or {"result": "recovered"}
async def execute(self, ctx: NodeContext) -> NodeResult:
self.attempt_count += 1
if self.attempt_count <= self.fail_times:
return NodeResult(success=False, error=f"fail #{self.attempt_count}")
return NodeResult(success=True, output=self._output, tokens_used=10, latency_ms=5)
class TimingNode(NodeProtocol):
"""Records execution order to a shared list."""
def __init__(self, label: str, order_tracker: list):
self.label = label
self.order_tracker = order_tracker
async def execute(self, ctx: NodeContext) -> NodeResult:
self.order_tracker.append(self.label)
return NodeResult(success=True, output={f"{self.label}_done": True}, tokens_used=1, latency_ms=1)
# --- Fixtures ---
@pytest.fixture
def runtime():
rt = MagicMock(spec=Runtime)
rt.start_run = MagicMock(return_value="run_id")
rt.decide = MagicMock(return_value="decision_id")
rt.record_outcome = MagicMock()
rt.end_run = MagicMock()
rt.report_problem = MagicMock()
rt.set_node = MagicMock()
return rt
@pytest.fixture
def goal():
return Goal(id="g1", name="Test", description="Fanout tests")
def _make_fanout_graph(
branch_nodes: list[NodeSpec],
fan_in_node: NodeSpec | None = None,
source_node: NodeSpec | None = None,
) -> GraphSpec:
"""
Build a diamond graph:
source
/ | \\
b0 b1 b2 ...
\\ | /
fan_in
"""
if source_node is None:
source_node = NodeSpec(
id="source", name="Source", description="entry",
node_type="function", output_keys=["data"],
)
nodes = [source_node] + branch_nodes
terminal_nodes = [b.id for b in branch_nodes]
edges = [
EdgeSpec(
id=f"source_to_{b.id}",
source="source",
target=b.id,
condition=EdgeCondition.ON_SUCCESS,
)
for b in branch_nodes
]
if fan_in_node is not None:
nodes.append(fan_in_node)
terminal_nodes = [fan_in_node.id]
for b in branch_nodes:
edges.append(
EdgeSpec(
id=f"{b.id}_to_{fan_in_node.id}",
source=b.id,
target=fan_in_node.id,
condition=EdgeCondition.ON_SUCCESS,
)
)
return GraphSpec(
id="fanout_graph",
goal_id="g1",
name="Fanout Graph",
entry_node="source",
nodes=nodes,
edges=edges,
terminal_nodes=terminal_nodes,
)
# === 1. Fan-out triggers with multiple ON_SUCCESS edges ===
@pytest.mark.asyncio
async def test_fanout_triggers_on_multiple_success_edges(runtime, goal):
"""Fan-out should activate when a node has >1 ON_SUCCESS outgoing edges."""
b1 = NodeSpec(id="b1", name="B1", description="branch 1", node_type="function", output_keys=["b1_out"])
b2 = NodeSpec(id="b2", name="B2", description="branch 2", node_type="function", output_keys=["b2_out"])
graph = _make_fanout_graph([b1, b2])
executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
source_impl = SuccessNode({"data": "x"})
b1_impl = SuccessNode({"b1_out": "done1"})
b2_impl = SuccessNode({"b2_out": "done2"})
executor.register_node("source", source_impl)
executor.register_node("b1", b1_impl)
executor.register_node("b2", b2_impl)
result = await executor.execute(graph, goal, {})
assert result.success
assert b1_impl.executed
assert b2_impl.executed
# === 2. All branches execute concurrently ===
@pytest.mark.asyncio
async def test_branches_execute_concurrently(runtime, goal):
"""All fan-out branches should be launched via asyncio.gather (concurrent)."""
order = []
b1 = NodeSpec(id="b1", name="B1", description="branch 1", node_type="function", output_keys=["b1_done"])
b2 = NodeSpec(id="b2", name="B2", description="branch 2", node_type="function", output_keys=["b2_done"])
graph = _make_fanout_graph([b1, b2])
executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
executor.register_node("source", SuccessNode({"data": "x"}))
executor.register_node("b1", TimingNode("b1", order))
executor.register_node("b2", TimingNode("b2", order))
result = await executor.execute(graph, goal, {})
assert result.success
# Both executed
assert "b1" in order
assert "b2" in order
# === 3. Convergence at fan-in node ===
@pytest.mark.asyncio
async def test_convergence_at_fan_in_node(runtime, goal):
"""After fan-out branches complete, execution should continue at convergence node."""
b1 = NodeSpec(id="b1", name="B1", description="branch 1", node_type="function", output_keys=["b1_out"])
b2 = NodeSpec(id="b2", name="B2", description="branch 2", node_type="function", output_keys=["b2_out"])
merge = NodeSpec(id="merge", name="Merge", description="fan-in", node_type="function", output_keys=["merged"])
graph = _make_fanout_graph([b1, b2], fan_in_node=merge)
executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
executor.register_node("source", SuccessNode({"data": "x"}))
executor.register_node("b1", SuccessNode({"b1_out": "1"}))
executor.register_node("b2", SuccessNode({"b2_out": "2"}))
merge_impl = SuccessNode({"merged": "done"})
executor.register_node("merge", merge_impl)
result = await executor.execute(graph, goal, {})
assert result.success
assert merge_impl.executed
assert "merge" in result.path
# === 4. fail_all strategy ===
@pytest.mark.asyncio
async def test_fail_all_strategy_raises_on_branch_failure(runtime, goal):
"""fail_all should raise RuntimeError if any branch fails."""
b1 = NodeSpec(id="b1", name="B1", description="ok branch", node_type="function", output_keys=["b1_out"])
b2 = NodeSpec(id="b2", name="B2", description="bad branch", node_type="function", output_keys=["b2_out"], max_retries=1)
graph = _make_fanout_graph([b1, b2])
config = ParallelExecutionConfig(on_branch_failure="fail_all")
executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True, parallel_config=config)
executor.register_node("source", SuccessNode({"data": "x"}))
executor.register_node("b1", SuccessNode({"b1_out": "ok"}))
executor.register_node("b2", FailNode())
result = await executor.execute(graph, goal, {})
# fail_all raises RuntimeError which gets caught by the outer try/except
assert not result.success
assert "failed" in result.error.lower()
# === 5. continue_others strategy ===
@pytest.mark.asyncio
async def test_continue_others_strategy_allows_partial_success(runtime, goal):
"""continue_others should let successful branches complete even if one fails."""
b1 = NodeSpec(id="b1", name="B1", description="ok", node_type="function", output_keys=["b1_out"])
b2 = NodeSpec(id="b2", name="B2", description="fail", node_type="function", output_keys=["b2_out"], max_retries=1)
graph = _make_fanout_graph([b1, b2])
config = ParallelExecutionConfig(on_branch_failure="continue_others")
executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True, parallel_config=config)
executor.register_node("source", SuccessNode({"data": "x"}))
b1_impl = SuccessNode({"b1_out": "ok"})
executor.register_node("b1", b1_impl)
executor.register_node("b2", FailNode())
result = await executor.execute(graph, goal, {})
# Should not fail because continue_others tolerates branch failures
assert result.success or b1_impl.executed
# === 6. wait_all strategy ===
@pytest.mark.asyncio
async def test_wait_all_strategy_collects_all_results(runtime, goal):
"""wait_all should wait for all branches before proceeding."""
b1 = NodeSpec(id="b1", name="B1", description="ok", node_type="function", output_keys=["b1_out"])
b2 = NodeSpec(id="b2", name="B2", description="fail", node_type="function", output_keys=["b2_out"], max_retries=1)
graph = _make_fanout_graph([b1, b2])
config = ParallelExecutionConfig(on_branch_failure="wait_all")
executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True, parallel_config=config)
executor.register_node("source", SuccessNode({"data": "x"}))
b1_impl = SuccessNode({"b1_out": "ok"})
b2_impl = FailNode()
executor.register_node("b1", b1_impl)
executor.register_node("b2", b2_impl)
result = await executor.execute(graph, goal, {})
# Both branches should have executed regardless
assert b1_impl.executed
assert b2_impl.attempt_count >= 1
# === 7. Per-branch retry ===
@pytest.mark.asyncio
async def test_per_branch_retry(runtime, goal):
"""Each branch should retry up to its node's max_retries."""
b1 = NodeSpec(id="b1", name="B1", description="flaky", node_type="function", output_keys=["b1_out"], max_retries=5)
b2 = NodeSpec(id="b2", name="B2", description="solid", node_type="function", output_keys=["b2_out"])
graph = _make_fanout_graph([b1, b2])
executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
executor.register_node("source", SuccessNode({"data": "x"}))
flaky = FlakyNode(fail_times=3, output={"b1_out": "recovered"})
executor.register_node("b1", flaky)
executor.register_node("b2", SuccessNode({"b2_out": "ok"}))
result = await executor.execute(graph, goal, {})
assert result.success
assert flaky.attempt_count == 4 # 3 fails + 1 success
# === 8. Single-edge path unaffected ===
@pytest.mark.asyncio
async def test_single_edge_no_parallel_overhead(runtime, goal):
"""A single outgoing edge should follow normal sequential path, not fan-out."""
n1 = NodeSpec(id="n1", name="N1", description="entry", node_type="function", output_keys=["out1"])
n2 = NodeSpec(id="n2", name="N2", description="next", node_type="function", input_keys=["out1"], output_keys=["out2"])
graph = GraphSpec(
id="seq_graph", goal_id="g1", name="Sequential",
entry_node="n1",
nodes=[n1, n2],
edges=[EdgeSpec(id="e1", source="n1", target="n2", condition=EdgeCondition.ON_SUCCESS)],
terminal_nodes=["n2"],
)
executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
executor.register_node("n1", SuccessNode({"out1": "a"}))
n2_impl = SuccessNode({"out2": "b"})
executor.register_node("n2", n2_impl)
result = await executor.execute(graph, goal, {})
assert result.success
assert n2_impl.executed
assert result.path == ["n1", "n2"]
# === 9. detect_fan_out_nodes static analysis ===
def test_detect_fan_out_nodes():
"""GraphSpec.detect_fan_out_nodes should identify fan-out topology."""
b1 = NodeSpec(id="b1", name="B1", description="b", node_type="function", output_keys=["x"])
b2 = NodeSpec(id="b2", name="B2", description="b", node_type="function", output_keys=["y"])
graph = _make_fanout_graph([b1, b2])
fan_outs = graph.detect_fan_out_nodes()
assert "source" in fan_outs
assert set(fan_outs["source"]) == {"b1", "b2"}
# === 10. detect_fan_in_nodes static analysis ===
def test_detect_fan_in_nodes():
"""GraphSpec.detect_fan_in_nodes should identify convergence topology."""
b1 = NodeSpec(id="b1", name="B1", description="b", node_type="function", output_keys=["x"])
b2 = NodeSpec(id="b2", name="B2", description="b", node_type="function", output_keys=["y"])
merge = NodeSpec(id="merge", name="Merge", description="m", node_type="function", output_keys=["z"])
graph = _make_fanout_graph([b1, b2], fan_in_node=merge)
fan_ins = graph.detect_fan_in_nodes()
assert "merge" in fan_ins
assert set(fan_ins["merge"]) == {"b1", "b2"}
# === 11. Parallel disabled falls back to sequential ===
@pytest.mark.asyncio
async def test_parallel_disabled_uses_sequential(runtime, goal):
"""When enable_parallel_execution=False, multi-edge should follow first match only."""
b1 = NodeSpec(id="b1", name="B1", description="b1", node_type="function", output_keys=["b1_out"])
b2 = NodeSpec(id="b2", name="B2", description="b2", node_type="function", output_keys=["b2_out"])
graph = _make_fanout_graph([b1, b2])
executor = GraphExecutor(runtime=runtime, enable_parallel_execution=False)
executor.register_node("source", SuccessNode({"data": "x"}))
b1_impl = SuccessNode({"b1_out": "ok"})
b2_impl = SuccessNode({"b2_out": "ok"})
executor.register_node("b1", b1_impl)
executor.register_node("b2", b2_impl)
result = await executor.execute(graph, goal, {})
assert result.success
# Only one branch should have executed (sequential follows first edge)
executed_count = sum([b1_impl.executed, b2_impl.executed])
assert executed_count == 1