692 lines
22 KiB
Python
692 lines
22 KiB
Python
"""
Tests for fan-out / fan-in parallel execution in GraphExecutor.

Covers:
- Fan-out triggers with multiple ON_SUCCESS edges
- Concurrent branch execution
- Convergence at fan-in node
- fail_all / continue_others / wait_all strategies
- Branch timeout
- Buffer (memory) conflict strategies: last_wins / first_wins / error
- Per-branch retry
- Single-edge paths unaffected
- Static fan-out / fan-in detection on GraphSpec
- Sequential fallback when parallel execution is disabled
"""
|
|
|
|
import asyncio
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
|
from framework.graph.executor import GraphExecutor, ParallelExecutionConfig
|
|
from framework.graph.goal import Goal
|
|
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
|
|
from framework.runtime.core import Runtime
|
|
|
|
# --- Test node implementations ---
|
|
|
|
|
|
class SuccessNode(NodeProtocol):
    """Test double that always succeeds.

    Returns ``output`` (default ``{"result": "ok"}``) with fixed token/latency
    figures and flips ``executed`` to True so tests can check it was invoked.
    """

    def __init__(self, output: dict | None = None):
        # Fall back to the default only when no output was supplied; an explicit
        # empty dict is a legitimate "succeed with no output" configuration and
        # must not be silently replaced (which ``output or {...}`` would do).
        self._output = {"result": "ok"} if output is None else output
        self.executed = False

    async def execute(self, ctx: NodeContext) -> NodeResult:
        self.executed = True
        return NodeResult(success=True, output=self._output, tokens_used=10, latency_ms=5)
|
|
|
|
|
|
class FailNode(NodeProtocol):
    """Test double whose every execution reports failure.

    ``attempt_count`` tallies invocations so tests can assert how often the
    executor retried this node.
    """

    def __init__(self):
        self.attempt_count = 0

    async def execute(self, ctx: NodeContext) -> NodeResult:
        self.attempt_count = self.attempt_count + 1
        failure = NodeResult(success=False, error="branch failed")
        return failure
|
|
|
|
|
|
class FlakyNode(NodeProtocol):
    """Test double that fails its first ``fail_times`` executions, then succeeds.

    ``attempt_count`` counts every invocation, failed or not, so tests can pin
    the number of attempts the executor made.
    """

    def __init__(self, fail_times: int = 1, output: dict | None = None):
        self.fail_times = fail_times
        self.attempt_count = 0
        # ``is None`` rather than ``or`` so an explicit empty dict is kept as-is.
        self._output = {"result": "recovered"} if output is None else output

    async def execute(self, ctx: NodeContext) -> NodeResult:
        self.attempt_count += 1
        if self.attempt_count <= self.fail_times:
            return NodeResult(success=False, error=f"fail #{self.attempt_count}")
        return NodeResult(success=True, output=self._output, tokens_used=10, latency_ms=5)
|
|
|
|
|
|
class TimingNode(NodeProtocol):
    """Test double that appends its label to a shared list each time it runs.

    On success it emits ``{"<label>_done": True}``.
    """

    def __init__(self, label: str, order_tracker: list):
        self.label = label
        self.order_tracker = order_tracker

    async def execute(self, ctx: NodeContext) -> NodeResult:
        name = self.label
        self.order_tracker.append(name)
        payload = {f"{name}_done": True}
        return NodeResult(success=True, output=payload, tokens_used=1, latency_ms=1)
|
|
|
|
|
|
class SlowNode(NodeProtocol):
    """Test double that sleeps ``delay`` seconds before succeeding (timeout tests).

    ``executed`` is only set once the sleep completes, so it stays False if the
    coroutine is cancelled while sleeping.
    """

    def __init__(self, delay: float = 10.0):
        self.delay = delay
        self.executed = False

    async def execute(self, ctx: NodeContext) -> NodeResult:
        pause = self.delay
        await asyncio.sleep(pause)
        self.executed = True
        return NodeResult(
            success=True,
            output={"result": "slow"},
            tokens_used=1,
            latency_ms=1,
        )
|
|
|
|
|
|
# --- Fixtures ---
|
|
|
|
|
|
@pytest.fixture
def runtime():
    """Provide a ``Runtime`` mock with every hook the executor calls stubbed out.

    ``start_run`` and ``decide`` return fixed ids; the other hooks are plain mocks.
    """
    mock_runtime = MagicMock(spec=Runtime)
    mock_runtime.start_run = MagicMock(return_value="run_id")
    mock_runtime.decide = MagicMock(return_value="decision_id")
    for hook_name in ("record_outcome", "end_run", "report_problem", "set_node"):
        setattr(mock_runtime, hook_name, MagicMock())
    return mock_runtime
|
|
|
|
|
|
@pytest.fixture
def goal():
    """Provide the minimal ``Goal`` (id ``g1``) that every graph in this module targets."""
    test_goal = Goal(id="g1", name="Test", description="Fanout tests")
    return test_goal
|
|
|
|
|
|
def _make_fanout_graph(
    branch_nodes: list[NodeSpec],
    fan_in_node: NodeSpec | None = None,
    source_node: NodeSpec | None = None,
) -> GraphSpec:
    """
    Build a diamond graph:

              source
             /  |  \\
           b0   b1  b2 ...
             \\  |  /
              fan_in

    Args:
        branch_nodes: Nodes reached from the source via ON_SUCCESS edges.
        fan_in_node: Optional convergence node. When given, every branch gets an
            ON_SUCCESS edge into it and it becomes the only terminal node;
            otherwise the branch nodes themselves are terminal.
        source_node: Optional custom entry node. Defaults to an ``event_loop``
            node with id ``"source"`` and output key ``"data"``.

    Returns:
        A ``GraphSpec`` with id ``"fanout_graph"`` bound to goal ``"g1"``.
    """
    if source_node is None:
        source_node = NodeSpec(
            id="source",
            name="Source",
            description="entry",
            node_type="event_loop",
            output_keys=["data"],
        )

    # Wire edges and the entry point to the actual source id so that a
    # caller-supplied source node with a different id still yields a connected
    # graph (previously "source" was hard-coded here).
    source_id = source_node.id

    nodes = [source_node] + branch_nodes
    terminal_nodes = [b.id for b in branch_nodes]

    edges = [
        EdgeSpec(
            id=f"{source_id}_to_{b.id}",
            source=source_id,
            target=b.id,
            condition=EdgeCondition.ON_SUCCESS,
        )
        for b in branch_nodes
    ]

    if fan_in_node is not None:
        nodes.append(fan_in_node)
        terminal_nodes = [fan_in_node.id]
        for b in branch_nodes:
            edges.append(
                EdgeSpec(
                    id=f"{b.id}_to_{fan_in_node.id}",
                    source=b.id,
                    target=fan_in_node.id,
                    condition=EdgeCondition.ON_SUCCESS,
                )
            )

    return GraphSpec(
        id="fanout_graph",
        goal_id="g1",
        name="Fanout Graph",
        entry_node=source_id,
        nodes=nodes,
        edges=edges,
        terminal_nodes=terminal_nodes,
    )
|
|
|
|
|
|
# === 1. Fan-out triggers with multiple ON_SUCCESS edges ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fanout_triggers_on_multiple_success_edges(runtime, goal):
    """A node with more than one ON_SUCCESS out-edge should fan out to every target."""
    spec_b1 = NodeSpec(
        id="b1", name="B1", description="branch 1", node_type="event_loop", output_keys=["b1_out"]
    )
    spec_b2 = NodeSpec(
        id="b2", name="B2", description="branch 2", node_type="event_loop", output_keys=["b2_out"]
    )
    graph = _make_fanout_graph([spec_b1, spec_b2])

    impls = {
        "source": SuccessNode({"data": "x"}),
        "b1": SuccessNode({"b1_out": "done1"}),
        "b2": SuccessNode({"b2_out": "done2"}),
    }
    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    for node_id, impl in impls.items():
        executor.register_node(node_id, impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    # Both branch targets must have been invoked.
    assert impls["b1"].executed
    assert impls["b2"].executed
|
|
|
|
|
|
# === 2. All branches execute concurrently ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_branches_execute_concurrently(runtime, goal):
    """All fan-out branches should be in flight at the same time.

    Each branch node registers itself as "active", yields to the event loop for
    a short sleep, then records its label via ``TimingNode``. If the executor
    launches the branches concurrently (e.g. via ``asyncio.gather``), both are
    active at once and the observed peak is 2; a sequential executor peaks at 1.
    """
    order = []
    state = {"active": 0, "peak": 0}

    class _OverlapProbe(TimingNode):
        """TimingNode that also tracks how many probes are executing simultaneously."""

        async def execute(self, ctx: NodeContext) -> NodeResult:
            state["active"] += 1
            state["peak"] = max(state["peak"], state["active"])
            try:
                # Yield long enough for a concurrently launched sibling to start.
                await asyncio.sleep(0.05)
            finally:
                state["active"] -= 1
            return await super().execute(ctx)

    b1 = NodeSpec(
        id="b1", name="B1", description="branch 1", node_type="event_loop", output_keys=["b1_done"]
    )
    b2 = NodeSpec(
        id="b2", name="B2", description="branch 2", node_type="event_loop", output_keys=["b2_done"]
    )
    graph = _make_fanout_graph([b1, b2])

    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    executor.register_node("source", SuccessNode({"data": "x"}))
    executor.register_node("b1", _OverlapProbe("b1", order))
    executor.register_node("b2", _OverlapProbe("b2", order))

    result = await executor.execute(graph, goal, {})

    assert result.success
    # Both executed
    assert "b1" in order
    assert "b2" in order
    # ...and they overlapped in time rather than running one after the other.
    assert state["peak"] == 2, f"branches did not overlap (peak concurrency {state['peak']})"
|
|
|
|
|
|
# === 3. Convergence at fan-in node ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_convergence_at_fan_in_node(runtime, goal):
    """Once every branch finishes, execution should resume at the fan-in node."""
    branch_specs = [
        NodeSpec(
            id="b1", name="B1", description="branch 1", node_type="event_loop", output_keys=["b1_out"]
        ),
        NodeSpec(
            id="b2", name="B2", description="branch 2", node_type="event_loop", output_keys=["b2_out"]
        ),
    ]
    merge_spec = NodeSpec(
        id="merge",
        name="Merge",
        description="fan-in",
        node_type="event_loop",
        output_keys=["merged"],
    )
    graph = _make_fanout_graph(branch_specs, fan_in_node=merge_spec)

    merge_impl = SuccessNode({"merged": "done"})
    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    for node_id, impl in (
        ("source", SuccessNode({"data": "x"})),
        ("b1", SuccessNode({"b1_out": "1"})),
        ("b2", SuccessNode({"b2_out": "2"})),
        ("merge", merge_impl),
    ):
        executor.register_node(node_id, impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    # The convergence node ran and is recorded on the execution path.
    assert merge_impl.executed
    assert "merge" in result.path
|
|
|
|
|
|
# === 4. fail_all strategy ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fail_all_strategy_raises_on_branch_failure(runtime, goal):
    """Under fail_all, a single failing branch must turn the whole run into a failure."""
    healthy_spec = NodeSpec(
        id="b1", name="B1", description="ok branch", node_type="event_loop", output_keys=["b1_out"]
    )
    failing_spec = NodeSpec(
        id="b2",
        name="B2",
        description="bad branch",
        node_type="event_loop",
        output_keys=["b2_out"],
        max_retries=1,
    )
    graph = _make_fanout_graph([healthy_spec, failing_spec])

    executor = GraphExecutor(
        runtime=runtime,
        enable_parallel_execution=True,
        parallel_config=ParallelExecutionConfig(on_branch_failure="fail_all"),
    )
    registrations = (
        ("source", SuccessNode({"data": "x"})),
        ("b1", SuccessNode({"b1_out": "ok"})),
        ("b2", FailNode()),
    )
    for node_id, impl in registrations:
        executor.register_node(node_id, impl)

    result = await executor.execute(graph, goal, {})

    # fail_all raises a RuntimeError inside the executor; execute() catches it
    # and reports an unsuccessful result instead of propagating it here.
    assert not result.success
    assert "failed" in result.error.lower()
|
|
|
|
|
|
# === 5. continue_others strategy ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_continue_others_strategy_allows_partial_success(runtime, goal):
    """continue_others should let successful branches complete even if one fails."""
    b1 = NodeSpec(
        id="b1", name="B1", description="ok", node_type="event_loop", output_keys=["b1_out"]
    )
    b2 = NodeSpec(
        id="b2",
        name="B2",
        description="fail",
        node_type="event_loop",
        output_keys=["b2_out"],
        max_retries=1,
    )
    graph = _make_fanout_graph([b1, b2])

    config = ParallelExecutionConfig(on_branch_failure="continue_others")
    executor = GraphExecutor(
        runtime=runtime, enable_parallel_execution=True, parallel_config=config
    )
    executor.register_node("source", SuccessNode({"data": "x"}))
    b1_impl = SuccessNode({"b1_out": "ok"})
    b2_impl = FailNode()
    executor.register_node("b1", b1_impl)
    executor.register_node("b2", b2_impl)

    await executor.execute(graph, goal, {})

    # The healthy branch must run to completion despite its sibling failing,
    # and the failing branch must still have been attempted.
    # NOTE(review): overall result.success is not pinned here — the failing
    # branch is a terminal node in this graph, so the run outcome depends on how
    # the executor reports a failed terminal branch; pin it once that is settled.
    assert b1_impl.executed
    assert b2_impl.attempt_count >= 1
|
|
|
|
|
|
# === 6. wait_all strategy ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_wait_all_strategy_collects_all_results(runtime, goal):
    """wait_all must let every branch run before execution moves on."""
    branch_specs = [
        NodeSpec(
            id="b1", name="B1", description="ok", node_type="event_loop", output_keys=["b1_out"]
        ),
        NodeSpec(
            id="b2",
            name="B2",
            description="fail",
            node_type="event_loop",
            output_keys=["b2_out"],
            max_retries=1,
        ),
    ]
    graph = _make_fanout_graph(branch_specs)

    ok_branch = SuccessNode({"b1_out": "ok"})
    bad_branch = FailNode()
    executor = GraphExecutor(
        runtime=runtime,
        enable_parallel_execution=True,
        parallel_config=ParallelExecutionConfig(on_branch_failure="wait_all"),
    )
    executor.register_node("source", SuccessNode({"data": "x"}))
    executor.register_node("b1", ok_branch)
    executor.register_node("b2", bad_branch)

    # The run outcome is not checked; only that both branches were attempted.
    await executor.execute(graph, goal, {})

    assert ok_branch.executed
    assert bad_branch.attempt_count >= 1
|
|
|
|
|
|
# === 7. Per-branch retry ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_per_branch_retry(runtime, goal):
    """Each branch may retry up to its own NodeSpec.max_retries, independently."""
    flaky_spec = NodeSpec(
        id="b1",
        name="B1",
        description="flaky",
        node_type="event_loop",
        output_keys=["b1_out"],
        max_retries=5,
    )
    solid_spec = NodeSpec(
        id="b2", name="B2", description="solid", node_type="event_loop", output_keys=["b2_out"]
    )
    graph = _make_fanout_graph([flaky_spec, solid_spec])

    flaky = FlakyNode(fail_times=3, output={"b1_out": "recovered"})
    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    for node_id, impl in (
        ("source", SuccessNode({"data": "x"})),
        ("b1", flaky),
        ("b2", SuccessNode({"b2_out": "ok"})),
    ):
        executor.register_node(node_id, impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    # Three scripted failures followed by the successful fourth attempt.
    assert flaky.attempt_count == 4
|
|
|
|
|
|
# === 8. Single-edge path unaffected ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_single_edge_no_parallel_overhead(runtime, goal):
    """With a single outgoing edge the executor must stay on the plain sequential path."""
    first = NodeSpec(
        id="n1", name="N1", description="entry", node_type="event_loop", output_keys=["out1"]
    )
    second = NodeSpec(
        id="n2",
        name="N2",
        description="next",
        node_type="event_loop",
        input_keys=["out1"],
        output_keys=["out2"],
    )
    only_edge = EdgeSpec(id="e1", source="n1", target="n2", condition=EdgeCondition.ON_SUCCESS)
    graph = GraphSpec(
        id="seq_graph",
        goal_id="g1",
        name="Sequential",
        entry_node="n1",
        nodes=[first, second],
        edges=[only_edge],
        terminal_nodes=["n2"],
    )

    second_impl = SuccessNode({"out2": "b"})
    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    executor.register_node("n1", SuccessNode({"out1": "a"}))
    executor.register_node("n2", second_impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    assert second_impl.executed
    # A straight two-node walk, with no fan-out detour on the path.
    assert result.path == ["n1", "n2"]
|
|
|
|
|
|
# === 9. detect_fan_out_nodes static analysis ===
|
|
|
|
|
|
def test_detect_fan_out_nodes():
    """detect_fan_out_nodes() should map the source node to both of its branch targets."""
    graph = _make_fanout_graph(
        [
            NodeSpec(id="b1", name="B1", description="b", node_type="event_loop", output_keys=["x"]),
            NodeSpec(id="b2", name="B2", description="b", node_type="event_loop", output_keys=["y"]),
        ]
    )

    detected = graph.detect_fan_out_nodes()

    assert "source" in detected
    assert set(detected["source"]) == {"b1", "b2"}
|
|
|
|
|
|
# === 10. detect_fan_in_nodes static analysis ===
|
|
|
|
|
|
def test_detect_fan_in_nodes():
    """detect_fan_in_nodes() should map the merge node to every branch feeding into it."""
    branches = [
        NodeSpec(id="b1", name="B1", description="b", node_type="event_loop", output_keys=["x"]),
        NodeSpec(id="b2", name="B2", description="b", node_type="event_loop", output_keys=["y"]),
    ]
    merge_spec = NodeSpec(
        id="merge", name="Merge", description="m", node_type="event_loop", output_keys=["z"]
    )
    graph = _make_fanout_graph(branches, fan_in_node=merge_spec)

    detected = graph.detect_fan_in_nodes()

    assert "merge" in detected
    assert set(detected["merge"]) == {"b1", "b2"}
|
|
|
|
|
|
# === 11. Parallel disabled falls back to sequential ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parallel_disabled_uses_sequential(runtime, goal):
    """With parallel execution off, a multi-edge node should follow only its first match."""
    graph = _make_fanout_graph(
        [
            NodeSpec(
                id="b1", name="B1", description="b1", node_type="event_loop", output_keys=["b1_out"]
            ),
            NodeSpec(
                id="b2", name="B2", description="b2", node_type="event_loop", output_keys=["b2_out"]
            ),
        ]
    )

    branch_impls = (SuccessNode({"b1_out": "ok"}), SuccessNode({"b2_out": "ok"}))
    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=False)
    executor.register_node("source", SuccessNode({"data": "x"}))
    for node_id, impl in zip(("b1", "b2"), branch_impls):
        executor.register_node(node_id, impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    # Exactly one branch ran: the sequential path picks a single edge.
    ran = [impl for impl in branch_impls if impl.executed]
    assert len(ran) == 1
|
|
|
|
|
|
# === 12. Branch timeout cancels slow branch ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_branch_timeout_cancels_slow_branch(runtime, goal):
    """A branch exceeding branch_timeout_seconds should be cancelled.

    ``SlowNode.executed`` is only set after its 10 s sleep completes, so it still
    being False when ``execute`` returns shows the slow branch was cut off by the
    0.1 s timeout rather than awaited to completion.
    """
    b1 = NodeSpec(
        id="b1", name="B1", description="slow", node_type="event_loop", output_keys=["b1_out"]
    )
    b2 = NodeSpec(
        id="b2", name="B2", description="fast", node_type="event_loop", output_keys=["b2_out"]
    )
    graph = _make_fanout_graph([b1, b2])

    config = ParallelExecutionConfig(branch_timeout_seconds=0.1, on_branch_failure="fail_all")
    executor = GraphExecutor(
        runtime=runtime, enable_parallel_execution=True, parallel_config=config
    )
    slow_impl = SlowNode(delay=10.0)
    executor.register_node("source", SuccessNode({"data": "x"}))
    executor.register_node("b1", slow_impl)
    executor.register_node("b2", SuccessNode({"b2_out": "ok"}))

    result = await executor.execute(graph, goal, {})

    # fail_all: one branch timed out → execution fails
    assert not result.success
    assert "failed" in result.error.lower()
    # The slow branch never got past its sleep.
    assert not slow_impl.executed
|
|
|
|
|
|
# === 13. Branch timeout with continue_others ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_branch_timeout_with_continue_others(runtime, goal):
    """Under continue_others a fast branch still finishes when its sibling times out."""
    slow_spec = NodeSpec(
        id="b1", name="B1", description="slow", node_type="event_loop", output_keys=["b1_out"]
    )
    fast_spec = NodeSpec(
        id="b2", name="B2", description="fast", node_type="event_loop", output_keys=["b2_out"]
    )
    graph = _make_fanout_graph([slow_spec, fast_spec])

    executor = GraphExecutor(
        runtime=runtime,
        enable_parallel_execution=True,
        parallel_config=ParallelExecutionConfig(
            branch_timeout_seconds=0.1, on_branch_failure="continue_others"
        ),
    )
    fast_impl = SuccessNode({"b2_out": "ok"})
    for node_id, impl in (
        ("source", SuccessNode({"data": "x"})),
        ("b1", SlowNode(delay=10.0)),
        ("b2", fast_impl),
    ):
        executor.register_node(node_id, impl)

    await executor.execute(graph, goal, {})

    # The timeout on b1 is tolerated; b2 must still have run.
    assert fast_impl.executed
|
|
|
|
|
|
# === 14. Branch timeout with fail_all (explicit) ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_branch_timeout_with_fail_all(runtime, goal):
    """Under fail_all, branch timeouts must surface as an execution failure."""
    graph = _make_fanout_graph(
        [
            NodeSpec(
                id="b1", name="B1", description="slow", node_type="event_loop", output_keys=["b1_out"]
            ),
            NodeSpec(
                id="b2",
                name="B2",
                description="also slow",
                node_type="event_loop",
                output_keys=["b2_out"],
            ),
        ]
    )

    timeout_config = ParallelExecutionConfig(
        branch_timeout_seconds=0.1, on_branch_failure="fail_all"
    )
    executor = GraphExecutor(
        runtime=runtime, enable_parallel_execution=True, parallel_config=timeout_config
    )
    executor.register_node("source", SuccessNode({"data": "x"}))
    for node_id in ("b1", "b2"):
        executor.register_node(node_id, SlowNode(delay=10.0))

    result = await executor.execute(graph, goal, {})

    assert not result.success
|
|
|
|
|
|
# === 15. Memory conflict: last_wins ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_memory_conflict_last_wins(runtime, goal):
    """Under last_wins, two branches writing the same key must not raise."""
    # The specs declare distinct output_keys so the graph passes validation,
    # while the node implementations both emit "shared_key" at runtime — the
    # collision that buffer_conflict_strategy exists to resolve.
    graph = _make_fanout_graph(
        [
            NodeSpec(
                id="b1", name="B1", description="b1", node_type="event_loop", output_keys=["b1_out"]
            ),
            NodeSpec(
                id="b2", name="B2", description="b2", node_type="event_loop", output_keys=["b2_out"]
            ),
        ]
    )

    executor = GraphExecutor(
        runtime=runtime,
        enable_parallel_execution=True,
        parallel_config=ParallelExecutionConfig(buffer_conflict_strategy="last_wins"),
    )
    executor.register_node("source", SuccessNode({"data": "x"}))
    executor.register_node("b1", SuccessNode({"shared_key": "from_b1", "b1_out": "ok"}))
    executor.register_node("b2", SuccessNode({"shared_key": "from_b2", "b2_out": "ok"}))

    result = await executor.execute(graph, goal, {})

    assert result.success
    # Whichever branch wrote last, the key ends up holding one branch's value.
    assert result.output.get("shared_key") in ("from_b1", "from_b2")
|
|
|
|
|
|
# === 16. Memory conflict: first_wins ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_memory_conflict_first_wins(runtime, goal):
    """first_wins should keep the first branch's value and skip later writes."""
    b1 = NodeSpec(
        id="b1", name="B1", description="b1", node_type="event_loop", output_keys=["b1_out"]
    )
    b2 = NodeSpec(
        id="b2", name="B2", description="b2", node_type="event_loop", output_keys=["b2_out"]
    )
    graph = _make_fanout_graph([b1, b2])

    config = ParallelExecutionConfig(buffer_conflict_strategy="first_wins")
    executor = GraphExecutor(
        runtime=runtime, enable_parallel_execution=True, parallel_config=config
    )
    executor.register_node("source", SuccessNode({"data": "x"}))
    # Both impls write "shared_key" — triggers conflict detection at runtime
    executor.register_node("b1", SuccessNode({"shared_key": "from_b1", "b1_out": "ok"}))
    executor.register_node("b2", SuccessNode({"shared_key": "from_b2", "b2_out": "ok"}))

    result = await executor.execute(graph, goal, {})

    assert result.success
    # The conflicting key must be resolved to one branch's value rather than
    # dropped. Which branch writes first is not deterministic under concurrent
    # execution, so either value is accepted.
    assert result.output.get("shared_key") in ("from_b1", "from_b2")
|
|
|
|
|
|
# === 17. Memory conflict: error raises ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_memory_conflict_error_raises(runtime, goal):
    """With the "error" strategy, two branches writing the same key must fail the run."""
    graph = _make_fanout_graph(
        [
            NodeSpec(
                id="b1", name="B1", description="b1", node_type="event_loop", output_keys=["b1_out"]
            ),
            NodeSpec(
                id="b2", name="B2", description="b2", node_type="event_loop", output_keys=["b2_out"]
            ),
        ]
    )

    strict_config = ParallelExecutionConfig(buffer_conflict_strategy="error")
    executor = GraphExecutor(
        runtime=runtime, enable_parallel_execution=True, parallel_config=strict_config
    )
    for node_id, impl in (
        ("source", SuccessNode({"data": "x"})),
        ("b1", SuccessNode({"shared_key": "from_b1", "b1_out": "ok"})),
        ("b2", SuccessNode({"shared_key": "from_b2", "b2_out": "ok"})),
    ):
        executor.register_node(node_id, impl)

    result = await executor.execute(graph, goal, {})

    assert not result.success
    # The conflict RuntimeError is caught inside execute_single_branch, which
    # fails that branch; fail_all then raises its own "failed" error, which is
    # the message surfaced here.
    assert "failed" in result.error.lower()
|