Compare commits

...

1 Commit

Author SHA1 Message Date
bryan 3579fd422c wp-7 2026-01-30 19:42:30 -08:00
3 changed files with 312 additions and 5 deletions
framework/graph/executor.py  +37 -4
@@ -380,6 +380,15 @@ class GraphExecutor:
            # [CORRECTED] Use node_spec.max_retries instead of hardcoded 3
            max_retries = getattr(node_spec, "max_retries", 3)
+            if node_spec.node_type == "event_loop":
+                if max_retries > 0:
+                    self.logger.warning(
+                        f"EventLoopNode '{node_spec.id}' has "
+                        f"max_retries={max_retries}. Overriding to 0 "
+                        "— event loop nodes handle retry internally."
+                    )
+                max_retries = 0
+
            if node_retry_counts[current_node_id] < max_retries:
                # Retry - don't increment steps for retries
                steps -= 1

@@ -658,7 +667,14 @@ class GraphExecutor:
    )

    # Valid node types - no ambiguous "llm" type allowed
-    VALID_NODE_TYPES = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
+    VALID_NODE_TYPES = {
+        "llm_tool_use",
+        "llm_generate",
+        "router",
+        "function",
+        "human_input",
+        "event_loop",
+    }

    def _get_node_implementation(
        self, node_spec: NodeSpec, cleanup_llm_model: str | None = None

@@ -713,6 +729,12 @@ class GraphExecutor:
                cleanup_llm_model=cleanup_llm_model,
            )

+        if node_spec.node_type == "event_loop":
+            raise RuntimeError(
+                f"EventLoopNode '{node_spec.id}' not found in registry. "
+                "Register it with executor.register_node() before execution."
+            )
+
        # Should never reach here due to validation above
        raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")

@@ -909,6 +931,17 @@ class GraphExecutor:
            branch.status = "failed"
            branch.error = f"Node {branch.node_id} not found in graph"
            return branch, RuntimeError(branch.error)

+        effective_max_retries = node_spec.max_retries
+        if node_spec.node_type == "event_loop":
+            if effective_max_retries > 1:
+                self.logger.warning(
+                    f"EventLoopNode '{node_spec.id}' has "
+                    f"max_retries={effective_max_retries}. Overriding "
+                    "to 1 — event loop nodes handle retry internally."
+                )
+            effective_max_retries = 1
+
        branch.status = "running"

        try:

@@ -942,7 +975,7 @@
        # Execute with retries
        last_result = None
-        for attempt in range(node_spec.max_retries):
+        for attempt in range(effective_max_retries):
            branch.retry_count = attempt

            # Build context for this branch

@@ -970,7 +1003,7 @@
                self.logger.warning(
                    f" ↻ Branch {node_spec.name}: "
-                    f"retry {attempt + 1}/{node_spec.max_retries}"
+                    f"retry {attempt + 1}/{effective_max_retries}"
                )

        # All retries exhausted

@@ -979,7 +1012,7 @@
            branch.result = last_result
            self.logger.error(
                f" ✗ Branch {node_spec.name}: "
-                f"failed after {node_spec.max_retries} attempts"
+                f"failed after {effective_max_retries} attempts"
            )
            return branch, last_result
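
Usage sketch (not part of this diff): with this change, an event_loop node must be registered on the executor before the graph runs, as the new RuntimeError message states. A minimal example under the assumption of a hypothetical MyEventLoopNode implementation and node id "chat_loop"; the GraphExecutor, register_node, and execute calls mirror the test file below.

    # inside an async function
    executor = GraphExecutor(runtime=runtime)
    # Register the implementation up front; without this, resolving the node
    # raises RuntimeError("EventLoopNode ... not found in registry").
    executor.register_node("chat_loop", MyEventLoopNode())
    result = await executor.execute(graph, goal, {})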
framework/graph/node.py  +10 -1
@@ -153,7 +153,10 @@ class NodeSpec(BaseModel):
    # Node behavior type
    node_type: str = Field(
        default="llm_tool_use",
-        description="Type: 'llm_tool_use', 'llm_generate', 'function', 'router', 'human_input'",
+        description=(
+            "Type: 'llm_tool_use', 'llm_generate', 'function', "
+            "'router', 'human_input', 'event_loop'"
+        ),
    )

    # Data flow

@@ -218,6 +221,12 @@ class NodeSpec(BaseModel):
        description="Maximum retries when Pydantic validation fails (with feedback to LLM)",
    )

+    # Event loop behavior
+    client_facing: bool = Field(
+        default=False,
+        description="If True, this node streams output to the end user and can request input.",
+    )
+
    model_config = {"extra": "allow", "arbitrary_types_allowed": True}
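
For reference, a NodeSpec using the new type and field might be declared as below; the field names come from this diff and the tests, while the id, name, and description values are illustrative only.

    chat_spec = NodeSpec(
        id="chat_loop",              # illustrative id
        name="Chat Loop",
        description="Streams output to the user and can request input",
        node_type="event_loop",      # now accepted by VALID_NODE_TYPES
        client_facing=True,          # new field introduced in this diff
        max_retries=0,               # anything higher is overridden with a warning
        output_keys=["result"],
    )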
+265
@@ -0,0 +1,265 @@
"""
Tests for event_loop node type wiring (Issue #2513).
Covers:
- NodeSpec.client_facing field
- event_loop in VALID_NODE_TYPES
- _get_node_implementation() event_loop branch
- no-retry enforcement in serial execution path
"""
from unittest.mock import AsyncMock, MagicMock
import pytest
from framework.graph.edge import GraphSpec
from framework.graph.executor import GraphExecutor
from framework.graph.goal import Goal
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
from framework.runtime.core import Runtime
class AlwaysFailsNode(NodeProtocol):
"""A test node that always fails."""
def __init__(self):
self.attempt_count = 0
async def execute(self, ctx: NodeContext) -> NodeResult:
self.attempt_count += 1
return NodeResult(success=False, error=f"Permanent error (attempt {self.attempt_count})")
class SucceedsOnceNode(NodeProtocol):
"""A test node that always succeeds."""
async def execute(self, ctx: NodeContext) -> NodeResult:
return NodeResult(success=True, output={"result": "ok"})
@pytest.fixture(autouse=True)
def fast_sleep(monkeypatch):
"""Mock asyncio.sleep to avoid real delays from exponential backoff."""
monkeypatch.setattr("asyncio.sleep", AsyncMock())
@pytest.fixture
def runtime():
"""Create a mock Runtime for testing."""
runtime = MagicMock(spec=Runtime)
runtime.start_run = MagicMock(return_value="test_run_id")
runtime.decide = MagicMock(return_value="test_decision_id")
runtime.record_outcome = MagicMock()
runtime.end_run = MagicMock()
runtime.report_problem = MagicMock()
runtime.set_node = MagicMock()
return runtime
# --- NodeSpec.client_facing tests ---
def test_client_facing_defaults_false():
"""NodeSpec without client_facing should default to False."""
spec = NodeSpec(
id="n1",
name="Node 1",
description="test",
node_type="llm_generate",
)
assert spec.client_facing is False
def test_client_facing_explicit_true():
"""NodeSpec with client_facing=True should retain the value."""
spec = NodeSpec(
id="n1",
name="Node 1",
description="test",
node_type="event_loop",
client_facing=True,
)
assert spec.client_facing is True
# --- VALID_NODE_TYPES tests ---
def test_event_loop_in_valid_node_types():
"""'event_loop' must be in GraphExecutor.VALID_NODE_TYPES."""
assert "event_loop" in GraphExecutor.VALID_NODE_TYPES
def test_event_loop_node_spec_accepted():
"""Creating a NodeSpec with node_type='event_loop' should not raise."""
spec = NodeSpec(
id="el1",
name="Event Loop",
description="test",
node_type="event_loop",
)
assert spec.node_type == "event_loop"
# --- _get_node_implementation() tests ---
def test_unregistered_event_loop_raises(runtime):
"""An event_loop node not in the registry should raise RuntimeError."""
spec = NodeSpec(
id="el1",
name="Event Loop",
description="test",
node_type="event_loop",
)
executor = GraphExecutor(runtime=runtime)
with pytest.raises(RuntimeError, match="not found in registry"):
executor._get_node_implementation(spec)
def test_registered_event_loop_returns_impl(runtime):
"""A registered event_loop node should be returned from the registry."""
spec = NodeSpec(
id="el1",
name="Event Loop",
description="test",
node_type="event_loop",
)
impl = SucceedsOnceNode()
executor = GraphExecutor(runtime=runtime)
executor.register_node("el1", impl)
result = executor._get_node_implementation(spec)
assert result is impl
# --- No-retry enforcement (serial path) ---
@pytest.mark.asyncio
async def test_event_loop_max_retries_forced_zero(runtime):
"""An event_loop node with max_retries=3 should only execute once (no retry)."""
node_spec = NodeSpec(
id="el_fail",
name="Failing Event Loop",
description="event loop that fails",
node_type="event_loop",
max_retries=3,
output_keys=["result"],
)
graph = GraphSpec(
id="test_graph",
goal_id="test_goal",
name="Test Graph",
entry_node="el_fail",
nodes=[node_spec],
edges=[],
terminal_nodes=["el_fail"],
)
goal = Goal(id="test_goal", name="Test", description="test")
executor = GraphExecutor(runtime=runtime)
failing_node = AlwaysFailsNode()
executor.register_node("el_fail", failing_node)
result = await executor.execute(graph, goal, {})
# Event loop nodes get max_retries overridden to 0, meaning execute once then fail
assert not result.success
assert failing_node.attempt_count == 1
@pytest.mark.asyncio
async def test_event_loop_max_retries_zero_no_warning(runtime, caplog):
"""An event_loop node with max_retries=0 should not log a warning."""
node_spec = NodeSpec(
id="el_zero",
name="Zero Retry Event Loop",
description="event loop with 0 retries",
node_type="event_loop",
max_retries=0,
output_keys=["result"],
)
graph = GraphSpec(
id="test_graph",
goal_id="test_goal",
name="Test Graph",
entry_node="el_zero",
nodes=[node_spec],
edges=[],
terminal_nodes=["el_zero"],
)
goal = Goal(id="test_goal", name="Test", description="test")
executor = GraphExecutor(runtime=runtime)
failing_node = AlwaysFailsNode()
executor.register_node("el_zero", failing_node)
import logging
with caplog.at_level(logging.WARNING):
await executor.execute(graph, goal, {})
# max_retries=0 should not trigger the override warning
assert "Overriding to 0" not in caplog.text
@pytest.mark.asyncio
async def test_event_loop_max_retries_positive_logs_warning(runtime, caplog):
"""An event_loop node with max_retries=3 should log a warning about override."""
node_spec = NodeSpec(
id="el_warn",
name="Warning Event Loop",
description="event loop with retries",
node_type="event_loop",
max_retries=3,
output_keys=["result"],
)
graph = GraphSpec(
id="test_graph",
goal_id="test_goal",
name="Test Graph",
entry_node="el_warn",
nodes=[node_spec],
edges=[],
terminal_nodes=["el_warn"],
)
goal = Goal(id="test_goal", name="Test", description="test")
executor = GraphExecutor(runtime=runtime)
failing_node = AlwaysFailsNode()
executor.register_node("el_warn", failing_node)
import logging
with caplog.at_level(logging.WARNING):
await executor.execute(graph, goal, {})
assert "Overriding to 0" in caplog.text
assert "el_warn" in caplog.text
# --- Existing node types unaffected ---
def test_existing_node_types_unchanged():
"""All pre-existing node types must still be in VALID_NODE_TYPES with defaults preserved."""
expected = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
assert expected.issubset(GraphExecutor.VALID_NODE_TYPES)
# Default node_type is still llm_tool_use
spec = NodeSpec(id="x", name="X", description="x")
assert spec.node_type == "llm_tool_use"
# Default max_retries is still 3
assert spec.max_retries == 3
# Default client_facing is False
assert spec.client_facing is False
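
For context (an assumption, not part of this change): the retry override above makes sense if event-loop implementations run their own turn/retry loop internally. A minimal sketch of such a node, using only the NodeProtocol, NodeContext, and NodeResult shapes exercised by the tests above; run_turn and max_turns are invented names.

    class SimpleEventLoopNode(NodeProtocol):
        """Hypothetical sketch: the node retries its own turns internally,
        which is why the executor forces its outer max_retries down."""

        def __init__(self, run_turn, max_turns: int = 5):
            self.run_turn = run_turn    # async callable doing one user-facing turn (assumption)
            self.max_turns = max_turns

        async def execute(self, ctx: NodeContext) -> NodeResult:
            last_error = None
            for _ in range(self.max_turns):  # internal loop replaces executor-level retries
                try:
                    reply = await self.run_turn(ctx)
                    return NodeResult(success=True, output={"result": reply})
                except Exception as exc:
                    last_error = str(exc)
            return NodeResult(success=False, error=last_error or "event loop exhausted")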