Compare commits

...

1 Commits

Author SHA1 Message Date
Timothy 8db2e72f5e feat: selective back edges 2026-02-03 19:25:46 -08:00
19 changed files with 789 additions and 43 deletions
+13 -3
View File
@@ -519,7 +519,7 @@ NODE_SPECS = {
"load_data",
"save_data",
],
max_node_visits=3,
max_node_visits=5,
system_prompt=(
"You are a Contact Extractor agent. Your inputs 'user_profiles' and "
"'relevance_scores' are filenames pointing to JSON data files.\n\n"
@@ -555,8 +555,9 @@ NODE_SPECS = {
input_keys=["contact_list"],
output_keys=["approved_contacts", "redo_extraction"],
nullable_output_keys=["approved_contacts", "redo_extraction"],
max_node_visits=3,
max_node_visits=5,
tools=["load_data", "save_data"],
allowed_navigation_targets=["extractor"],
system_prompt=(
"You are the Review agent at a human checkpoint. Your input 'contact_list' "
"is a filename pointing to a JSON data file.\n\n"
@@ -570,6 +571,10 @@ NODE_SPECS = {
" set_output(key='approved_contacts', value='approved_contacts.json')\n\n"
" IF REDO REQUESTED: call:\n"
" set_output(key='redo_extraction', value='true')\n\n"
"NAVIGATION:\n"
"If the operator wants to go back and re-extract contacts with different "
"criteria, you can use navigate_to(target='extractor') instead of "
"set_output(key='redo_extraction').\n\n"
"CRITICAL RULE: Call set_output EXACTLY ONCE with EXACTLY ONE key.\n"
"NEVER call set_output twice. NEVER set both keys.\n"
"The two output keys are mutually exclusive — setting both will cause an error."
@@ -631,8 +636,9 @@ NODE_SPECS = {
input_keys=["draft_emails"],
output_keys=["approved_emails", "revise_campaigns"],
nullable_output_keys=["approved_emails", "revise_campaigns"],
max_node_visits=3,
max_node_visits=5,
tools=["load_data", "save_data"],
allowed_navigation_targets=["review", "campaign_builder"],
system_prompt=(
"You are the Approval agent at the final human checkpoint. Your input "
"'draft_emails' is a filename pointing to a JSON data file.\n\n"
@@ -645,6 +651,10 @@ NODE_SPECS = {
" set_output(key='approved_emails', value='approved_emails.json')\n\n"
" IF REVISION REQUESTED: call:\n"
" set_output(key='revise_campaigns', value='true')\n\n"
"NAVIGATION:\n"
"If the operator wants to go back to review the contact list, use "
"navigate_to(target='review'). If they want to revise the email drafts, "
"use navigate_to(target='campaign_builder').\n\n"
"CRITICAL RULE: Call set_output EXACTLY ONCE with EXACTLY ONE key.\n"
"NEVER call set_output twice. NEVER set both keys.\n"
"The two output keys are mutually exclusive — setting both will cause an error."
+2 -2
View File
@@ -15,7 +15,7 @@ You cannot skip steps or bypass validation.
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from enum import StrEnum
from pathlib import Path
from typing import Any
@@ -26,7 +26,7 @@ from framework.graph.goal import Goal
from framework.graph.node import NodeSpec
class BuildPhase(str, Enum):
class BuildPhase(StrEnum):
"""Current phase of the build process."""
INIT = "init" # Just started
+2 -2
View File
@@ -8,7 +8,7 @@ containing one or more keys (e.g., api_key, access_token, refresh_token).
from __future__ import annotations
from datetime import UTC, datetime
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field, SecretStr
@@ -19,7 +19,7 @@ def _utc_now() -> datetime:
return datetime.now(UTC)
class CredentialType(str, Enum):
class CredentialType(StrEnum):
"""Types of credentials the store can manage."""
API_KEY = "api_key"
@@ -11,11 +11,11 @@ from __future__ import annotations
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from enum import Enum
from enum import StrEnum
from typing import Any
class TokenPlacement(str, Enum):
class TokenPlacement(StrEnum):
"""Where to place the access token in HTTP requests."""
HEADER_BEARER = "header_bearer"
+13 -2
View File
@@ -21,7 +21,7 @@ allowing the LLM to evaluate whether proceeding along an edge makes sense
given the current goal, context, and execution state.
"""
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
@@ -29,7 +29,7 @@ from pydantic import BaseModel, Field
from framework.graph.safe_eval import safe_eval
class EdgeCondition(str, Enum):
class EdgeCondition(StrEnum):
"""When an edge should be traversed."""
ALWAYS = "always" # Always after source completes
@@ -644,4 +644,15 @@ class GraphSpec(BaseModel):
else:
seen_keys[key] = node_id
# Validate allowed_navigation_targets reference existing nodes
node_ids = {n.id for n in self.nodes}
for node in self.nodes:
targets = getattr(node, "allowed_navigation_targets", [])
for target_id in targets:
if target_id not in node_ids:
errors.append(
f"Node '{node.id}' has allowed_navigation_target "
f"'{target_id}' which doesn't exist in the graph"
)
return errors
+112 -1
View File
@@ -180,6 +180,9 @@ class EventLoopNode(NodeProtocol):
# Client-facing input blocking state
self._input_ready = asyncio.Event()
self._shutdown = False
# Dynamic navigation state (hybrid execution pattern)
self._navigate_target: str | None = None
self._navigate_reason: str | None = None
def validate_input(self, ctx: NodeContext) -> list[str]:
"""Validate hard requirements only.
@@ -209,6 +212,10 @@ class EventLoopNode(NodeProtocol):
if ctx.llm is None:
return NodeResult(success=False, error="LLM provider not available")
# 1b. Reset navigation state
self._navigate_target = None
self._navigate_reason = None
# 2. Restore or create new conversation + accumulator
conversation, accumulator, start_iteration = await self._restore(ctx)
if conversation is None:
@@ -226,11 +233,16 @@ class EventLoopNode(NodeProtocol):
if initial_message:
await conversation.add_user_message(initial_message)
# 3. Build tool list: node tools + synthetic set_output tool
# 3. Build tool list: node tools + synthetic tools
tools = list(ctx.available_tools)
set_output_tool = self._build_set_output_tool(ctx.node_spec.output_keys)
if set_output_tool:
tools.append(set_output_tool)
navigate_to_tool = self._build_navigate_to_tool(
getattr(ctx.node_spec, "allowed_navigation_targets", []),
)
if navigate_to_tool:
tools.append(navigate_to_tool)
logger.info(
"[%s] Tools available (%d): %s | client_facing=%s | judge=%s",
@@ -321,6 +333,32 @@ class EventLoopNode(NodeProtocol):
# 6g. Write cursor checkpoint
await self._write_cursor(ctx, conversation, accumulator, iteration)
# 6g'. Check if navigate_to was called
if self._navigate_target:
target = self._navigate_target
reason = self._navigate_reason or ""
self._navigate_target = None
self._navigate_reason = None
logger.info(
"[%s] iter=%d: navigate_to '%s' (reason: %s)",
node_id,
iteration,
target,
reason,
)
await self._publish_loop_completed(stream_id, node_id, iteration + 1)
latency_ms = int((time.time() - start_time) * 1000)
return NodeResult(
success=True,
output=accumulator.to_dict(),
next_node=target,
route_reason=f"User navigation: {reason}",
tokens_used=total_input_tokens + total_output_tokens,
latency_ms=latency_ms,
)
# 6h. Client-facing input wait
logger.info(
"[%s] iter=%d: 6h check — client_facing=%s, tool_results=%d",
@@ -641,6 +679,19 @@ class EventLoopNode(NodeProtocol):
# Async write-through for set_output
if not result.is_error:
await accumulator.set(tc.tool_input["key"], tc.tool_input["value"])
elif tc.tool_name == "navigate_to":
result = self._handle_navigate_to(
tc.tool_input,
getattr(ctx.node_spec, "allowed_navigation_targets", []),
)
result = ToolResult(
tool_use_id=tc.tool_use_id,
content=result.content,
is_error=result.is_error,
)
if not result.is_error:
self._navigate_target = tc.tool_input["target"]
self._navigate_reason = tc.tool_input.get("reason", "")
else:
# Execute real tool
result = await self._execute_tool(tc)
@@ -806,6 +857,66 @@ class EventLoopNode(NodeProtocol):
is_error=False,
)
# -------------------------------------------------------------------
# Dynamic navigation (navigate_to synthetic tool)
# -------------------------------------------------------------------
def _build_navigate_to_tool(self, allowed_targets: list[str]) -> Tool | None:
    """Build the synthetic ``navigate_to`` tool for dynamic graph navigation.

    Args:
        allowed_targets: Node IDs this node may navigate to. An empty list
            means dynamic navigation is disabled for the node, and no tool
            is created.

    Returns:
        A ``Tool`` whose ``target`` parameter is constrained via a JSON-schema
        ``enum`` to the allowed node IDs, or ``None`` when navigation is
        disabled.
    """
    if not allowed_targets:
        return None
    targets_str = ", ".join(f"'{t}'" for t in allowed_targets)
    return Tool(
        name="navigate_to",
        description=(
            "Navigate the pipeline to a different stage. Use this when the user "
            "asks to go back to a previous step or move to a different stage. "
            f"Allowed targets: {targets_str}. "
            "IMPORTANT: Calling this exits the current stage immediately. "
            # FIX: the previous text claimed set_output values "will be
            # discarded", but the navigation branch in execute() returns
            # output=accumulator.to_dict(), so outputs set before navigating
            # are preserved (and tests assert exactly that). Describe the
            # real behavior so the LLM is not misinformed.
            "Any outputs you've already set with set_output are preserved "
            "and carried along with the navigation."
        ),
        parameters={
            "type": "object",
            "properties": {
                "target": {
                    "type": "string",
                    "description": f"Node ID to navigate to. Must be one of: {allowed_targets}",
                    "enum": allowed_targets,
                },
                "reason": {
                    "type": "string",
                    "description": "Brief reason for navigation.",
                },
            },
            # Both fields required so navigation is always explained/auditable.
            "required": ["target", "reason"],
        },
    )
def _handle_navigate_to(
    self,
    tool_input: dict[str, Any],
    allowed_targets: list[str],
) -> ToolResult:
    """Validate a navigate_to call and report the outcome (synchronous).

    Produces an error ToolResult when the requested target is not in
    ``allowed_targets``; otherwise a success message echoing the target
    and the caller-supplied reason.
    """
    requested = tool_input.get("target", "")
    why = tool_input.get("reason", "")
    # Guard clause: anything outside the allowlist (including "") is rejected.
    if requested in allowed_targets:
        return ToolResult(
            tool_use_id="",
            content=f"Navigating to '{requested}'. Reason: {why}",
            is_error=False,
        )
    return ToolResult(
        tool_use_id="",
        content=f"Invalid navigation target '{requested}'. Allowed: {allowed_targets}",
        is_error=True,
    )
# -------------------------------------------------------------------
# Judge evaluation
# -------------------------------------------------------------------
+39 -5
View File
@@ -255,6 +255,7 @@ class GraphExecutor:
total_latency = 0
node_retry_counts: dict[str, int] = {} # Track retries per node
node_visit_counts: dict[str, int] = {} # Track visits for feedback loops
_is_retry = False # True when looping back for a retry (not a new visit)
# Determine entry point (may differ if resuming)
current_node_id = graph.get_entry_point(session_state)
@@ -284,7 +285,11 @@ class GraphExecutor:
raise RuntimeError(f"Node not found: {current_node_id}")
# Enforce max_node_visits (feedback/callback edge support)
node_visit_counts[current_node_id] = node_visit_counts.get(current_node_id, 0) + 1
# Don't increment visit count on retries — retries are not new visits
if not _is_retry:
cnt = node_visit_counts.get(current_node_id, 0) + 1
node_visit_counts[current_node_id] = cnt
_is_retry = False
max_visits = getattr(node_spec, "max_node_visits", 1)
if max_visits > 0 and node_visit_counts[current_node_id] > max_visits:
self.logger.warning(
@@ -433,6 +438,7 @@ class GraphExecutor:
self.logger.info(
f" ↻ Retrying ({node_retry_counts[current_node_id]}/{max_retries})..."
)
_is_retry = True
continue
else:
# Max retries exceeded - fail the execution
@@ -524,11 +530,39 @@ class GraphExecutor:
break
# Determine next node
_navigated = False
if result.next_node:
# Router explicitly set next node
self.logger.info(f" → Router directing to: {result.next_node}")
current_node_id = result.next_node
else:
# Dynamic routing: router or navigate_to
_nav_valid = True
if node_spec.node_type != "router":
# Validate navigation target for non-router nodes
allowed = getattr(node_spec, "allowed_navigation_targets", [])
target_spec = graph.get_node(result.next_node)
if target_spec is None:
self.logger.warning(
f" ! Navigation target '{result.next_node}' "
f"not found in graph"
)
_nav_valid = False
elif allowed and result.next_node not in allowed:
self.logger.warning(
f" ! Navigation to '{result.next_node}' blocked: "
f"not in allowed_navigation_targets {allowed}"
)
_nav_valid = False
if _nav_valid:
reason = result.route_reason or ""
self.logger.info(
f" → Navigating to: {result.next_node}"
+ (f" ({reason})" if reason else "")
)
current_node_id = result.next_node
_navigated = True
if not _navigated:
# Get all traversable edges for fan-out detection
traversable_edges = self._get_all_traversable_edges(
graph=graph,
+2 -2
View File
@@ -12,13 +12,13 @@ Goals are:
"""
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class GoalStatus(str, Enum):
class GoalStatus(StrEnum):
"""Lifecycle status of a goal."""
DRAFT = "draft" # Being defined
+2 -2
View File
@@ -6,11 +6,11 @@ where agents need to gather input from humans.
"""
from dataclasses import dataclass, field
from enum import Enum
from enum import StrEnum
from typing import Any
class HITLInputType(str, Enum):
class HITLInputType(StrEnum):
"""Type of input expected from human."""
FREE_TEXT = "free_text" # Open-ended text response
+6
View File
@@ -237,6 +237,12 @@ class NodeSpec(BaseModel):
description="If True, this node streams output to the end user and can request input.",
)
# Dynamic navigation (hybrid execution pattern)
allowed_navigation_targets: list[str] = Field(
default_factory=list,
description="Node IDs this node can navigate to via navigate_to tool. Empty = disabled.",
)
model_config = {"extra": "allow", "arbitrary_types_allowed": True}
+6 -6
View File
@@ -11,13 +11,13 @@ The Plan is the contract between the external planner and the executor:
"""
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class ActionType(str, Enum):
class ActionType(StrEnum):
"""Types of actions a PlanStep can perform."""
LLM_CALL = "llm_call" # Call LLM for generation
@@ -27,7 +27,7 @@ class ActionType(str, Enum):
CODE_EXECUTION = "code_execution" # Execute dynamic code (sandboxed)
class StepStatus(str, Enum):
class StepStatus(StrEnum):
"""Status of a plan step."""
PENDING = "pending"
@@ -56,7 +56,7 @@ class StepStatus(str, Enum):
return self == StepStatus.COMPLETED
class ApprovalDecision(str, Enum):
class ApprovalDecision(StrEnum):
"""Human decision on a step requiring approval."""
APPROVE = "approve" # Execute as planned
@@ -91,7 +91,7 @@ class ApprovalResult(BaseModel):
model_config = {"extra": "allow"}
class JudgmentAction(str, Enum):
class JudgmentAction(StrEnum):
"""Actions the judge can take after evaluating a step."""
ACCEPT = "accept" # Step completed successfully, continue
@@ -423,7 +423,7 @@ class Plan(BaseModel):
}
class ExecutionStatus(str, Enum):
class ExecutionStatus(StrEnum):
"""Status of plan execution."""
COMPLETED = "completed"
+2 -2
View File
@@ -12,13 +12,13 @@ import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any
logger = logging.getLogger(__name__)
class EventType(str, Enum):
class EventType(StrEnum):
"""Types of events that can be published."""
# Execution lifecycle
+3 -3
View File
@@ -11,13 +11,13 @@ import asyncio
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from enum import StrEnum
from typing import Any
logger = logging.getLogger(__name__)
class IsolationLevel(str, Enum):
class IsolationLevel(StrEnum):
"""State isolation level for concurrent executions."""
ISOLATED = "isolated" # Private state per execution
@@ -25,7 +25,7 @@ class IsolationLevel(str, Enum):
SYNCHRONIZED = "synchronized" # Shared with write locks (strong consistency)
class StateScope(str, Enum):
class StateScope(StrEnum):
"""Scope for state operations."""
EXECUTION = "execution" # Local to a single execution
+2 -2
View File
@@ -10,13 +10,13 @@ This is MORE important than actions because:
"""
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field, computed_field
class DecisionType(str, Enum):
class DecisionType(StrEnum):
"""Types of decisions an agent can make."""
TOOL_SELECTION = "tool_selection" # Which tool to use
+2 -2
View File
@@ -6,7 +6,7 @@ summaries and metrics that Builder needs to understand what happened.
"""
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field, computed_field
@@ -14,7 +14,7 @@ from pydantic import BaseModel, Field, computed_field
from framework.schemas.decision import Decision, Outcome
class RunStatus(str, Enum):
class RunStatus(StrEnum):
"""Status of a run."""
RUNNING = "running"
+2 -2
View File
@@ -6,13 +6,13 @@ programmatic/MCP-based approval.
"""
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class ApprovalAction(str, Enum):
class ApprovalAction(StrEnum):
"""Actions a user can take on a generated test."""
APPROVE = "approve" # Accept as-is
+3 -3
View File
@@ -6,13 +6,13 @@ but require mandatory user approval before being stored.
"""
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class ApprovalStatus(str, Enum):
class ApprovalStatus(StrEnum):
"""Status of user approval for a generated test."""
PENDING = "pending" # Awaiting user review
@@ -21,7 +21,7 @@ class ApprovalStatus(str, Enum):
REJECTED = "rejected" # User declined (with reason)
class TestType(str, Enum):
class TestType(StrEnum):
"""Type of test based on what it validates."""
__test__ = False # Not a pytest test class
+2 -2
View File
@@ -6,13 +6,13 @@ categorization for guiding iteration strategy.
"""
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class ErrorCategory(str, Enum):
class ErrorCategory(StrEnum):
"""
Category of test failure for guiding iteration.
+574
View File
@@ -0,0 +1,574 @@
"""Tests for the navigate_to dynamic routing feature (hybrid execution pattern).
Tests the navigate_to synthetic tool on EventLoopNode and the executor's
validation of navigation targets.
"""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import MagicMock
import pytest
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.event_loop_node import (
EventLoopNode,
LoopConfig,
)
from framework.graph.node import NodeContext, NodeResult, NodeSpec, SharedMemory
from framework.llm.provider import LLMProvider, LLMResponse, Tool
from framework.llm.stream_events import (
FinishEvent,
TextDeltaEvent,
ToolCallEvent,
)
from framework.runtime.core import Runtime
# ---------------------------------------------------------------------------
# Mock LLM (same pattern as test_event_loop_node.py)
# ---------------------------------------------------------------------------
class MockStreamingLLM(LLMProvider):
    """Mock LLM provider that replays pre-programmed StreamEvent sequences.

    Every stream() invocation is recorded in ``stream_calls`` (so tests can
    inspect the messages/system/tools the node sent) and answered with the
    next scenario, cycling when calls outnumber scenarios.
    """

    def __init__(self, scenarios: list[list] | None = None):
        self.scenarios = scenarios or []
        self._cursor = 0  # index of the next scenario to replay
        self.stream_calls: list[dict] = []

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator:
        # Record the call before answering so assertions see it even if the
        # consumer abandons the stream early.
        self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
        if not self.scenarios:
            return
        playback = self.scenarios[self._cursor % len(self.scenarios)]
        self._cursor += 1
        for event in playback:
            yield event

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(content="Summary.", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def text_scenario(text: str) -> list:
    """Stream scenario: emit *text* once, then finish with no tool calls."""
    finish = FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock")
    return [TextDeltaEvent(content=text, snapshot=text), finish]
def tool_call_scenario(tool_name: str, tool_input: dict, tool_use_id: str = "call_1") -> list:
    """Stream scenario: one tool call followed by a 'tool_calls' finish."""
    call = ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input)
    done = FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock")
    return [call, done]
@pytest.fixture
def runtime():
    """MagicMock Runtime with the lifecycle hooks EventLoopNode calls stubbed."""
    mock_rt = MagicMock(spec=Runtime)
    mock_rt.start_run = MagicMock(return_value="run_1")
    mock_rt.decide = MagicMock(return_value="dec_1")
    # Pure side-effect hooks: no meaningful return values needed.
    for hook in ("record_outcome", "end_run", "report_problem", "set_node"):
        setattr(mock_rt, hook, MagicMock())
    return mock_rt
@pytest.fixture
def memory():
    """Provide a fresh, empty SharedMemory instance per test."""
    return SharedMemory()
def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None):
    """Assemble a NodeContext wiring the given mocks to a single node.

    ``tools`` and ``input_data`` default to empty collections when omitted.
    """
    return NodeContext(
        runtime=runtime,
        node_id=node_spec.id,
        node_spec=node_spec,
        memory=memory,
        input_data=input_data or {},
        llm=llm,
        available_tools=tools or [],
    )
# ===========================================================================
# navigate_to tool availability
# ===========================================================================
class TestNavigateToToolAvailability:
    """The navigate_to tool must be gated on NodeSpec.allowed_navigation_targets."""

    @pytest.mark.asyncio
    async def test_no_tool_without_targets(self, runtime, memory):
        """navigate_to tool should NOT appear when allowed_navigation_targets is empty."""
        spec = NodeSpec(
            id="node_a",
            name="Node A",
            description="test",
            node_type="event_loop",
            output_keys=["result"],
            allowed_navigation_targets=[],
        )
        llm = MockStreamingLLM(
            scenarios=[
                # Scenario order matters: set_output first, then plain text so
                # the loop can finish normally.
                tool_call_scenario("set_output", {"key": "result", "value": "done"}),
                text_scenario("Done."),
            ]
        )
        ctx = build_ctx(runtime, spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        await node.execute(ctx)
        assert llm.stream_calls, "LLM should have been called"
        # Inspect the tool list sent on the very first LLM call.
        tools_sent = llm.stream_calls[0]["tools"] or []
        tool_names = [t.name for t in tools_sent]
        assert "navigate_to" not in tool_names
        assert "set_output" in tool_names

    @pytest.mark.asyncio
    async def test_tool_present_with_targets(self, runtime, memory):
        """navigate_to tool should appear when allowed_navigation_targets is set."""
        spec = NodeSpec(
            id="node_a",
            name="Node A",
            description="test",
            node_type="event_loop",
            output_keys=["result"],
            allowed_navigation_targets=["node_b", "node_c"],
        )
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("set_output", {"key": "result", "value": "done"}),
                text_scenario("Done."),
            ]
        )
        ctx = build_ctx(runtime, spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        await node.execute(ctx)
        # Both synthetic tools must coexist when targets are configured.
        tools_sent = llm.stream_calls[0]["tools"] or []
        tool_names = [t.name for t in tools_sent]
        assert "navigate_to" in tool_names
        assert "set_output" in tool_names
# ===========================================================================
# navigate_to tool execution
# ===========================================================================
class TestNavigateToExecution:
    """End-to-end EventLoopNode.execute() behavior when navigate_to is invoked."""

    @pytest.mark.asyncio
    async def test_valid_target_returns_next_node(self, runtime, memory):
        """Calling navigate_to with a valid target should set NodeResult.next_node."""
        spec = NodeSpec(
            id="approval",
            name="Approval",
            description="test",
            node_type="event_loop",
            output_keys=["approved", "revise"],
            nullable_output_keys=["approved", "revise"],
            allowed_navigation_targets=["review", "campaign"],
        )
        llm = MockStreamingLLM(
            scenarios=[
                # Single turn: navigate immediately, no outputs set.
                tool_call_scenario(
                    "navigate_to",
                    {"target": "review", "reason": "User wants to go back"},
                ),
            ]
        )
        ctx = build_ctx(runtime, spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)
        assert result.success is True
        assert result.next_node == "review"
        # The model-supplied reason should surface in the routing explanation.
        assert "User wants to go back" in (result.route_reason or "")

    @pytest.mark.asyncio
    async def test_invalid_target_continues_loop(self, runtime, memory):
        """Calling navigate_to with an invalid target should error and continue."""
        spec = NodeSpec(
            id="approval",
            name="Approval",
            description="test",
            node_type="event_loop",
            output_keys=["result"],
            allowed_navigation_targets=["review"],
        )
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: try invalid target
                tool_call_scenario(
                    "navigate_to",
                    {"target": "nonexistent", "reason": "test"},
                ),
                # Turn 2: set output normally
                tool_call_scenario("set_output", {"key": "result", "value": "ok"}),
                # Turn 3: text -> implicit judge accepts
                text_scenario("Done."),
            ]
        )
        ctx = build_ctx(runtime, spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)
        assert result.success is True
        # Rejected navigation must not leak into routing.
        assert result.next_node is None
        assert result.output.get("result") == "ok"

    @pytest.mark.asyncio
    async def test_navigate_to_with_partial_outputs(self, runtime, memory):
        """Outputs set before navigate_to should still be in the result dict
        but the node exits via navigation, not via normal completion."""
        spec = NodeSpec(
            id="approval",
            name="Approval",
            description="test",
            node_type="event_loop",
            output_keys=["draft", "final"],
            nullable_output_keys=["draft", "final"],
            allowed_navigation_targets=["review"],
        )
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: set one output
                tool_call_scenario(
                    "set_output",
                    {"key": "draft", "value": "v1"},
                    tool_use_id="call_1",
                ),
                # Turn 2: navigate away
                tool_call_scenario(
                    "navigate_to",
                    {"target": "review", "reason": "go back"},
                    tool_use_id="call_2",
                ),
            ]
        )
        ctx = build_ctx(runtime, spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)
        assert result.success is True
        assert result.next_node == "review"
        # Partial output set in turn 1 survives the navigation exit.
        assert result.output.get("draft") == "v1"
# ===========================================================================
# navigate_to handler unit tests
# ===========================================================================
class TestNavigateToHandler:
    """Direct unit tests for the synchronous _handle_navigate_to handler."""

    def test_valid_target(self):
        """A target in the allowlist yields a non-error result naming it."""
        outcome = EventLoopNode()._handle_navigate_to(
            {"target": "review", "reason": "go back"},
            allowed_targets=["review", "campaign"],
        )
        assert not outcome.is_error
        assert "review" in outcome.content

    def test_invalid_target(self):
        """A target outside the allowlist yields an error result."""
        outcome = EventLoopNode()._handle_navigate_to(
            {"target": "nonexistent", "reason": "test"},
            allowed_targets=["review", "campaign"],
        )
        assert outcome.is_error
        assert "Invalid" in outcome.content

    def test_empty_target(self):
        """An empty target string is rejected like any other bad target."""
        outcome = EventLoopNode()._handle_navigate_to(
            {"target": "", "reason": "test"},
            allowed_targets=["review"],
        )
        assert outcome.is_error
# ===========================================================================
# navigate_to tool builder
# ===========================================================================
class TestNavigateToToolBuilder:
    """Tests for the _build_navigate_to_tool factory."""

    def test_no_tool_with_empty_targets(self):
        """Should return None when no targets allowed."""
        assert EventLoopNode()._build_navigate_to_tool([]) is None

    def test_tool_with_targets(self):
        """Should return a Tool whose target enum matches the allowlist."""
        built = EventLoopNode()._build_navigate_to_tool(["review", "campaign"])
        assert built is not None
        assert built.name == "navigate_to"
        params = built.parameters
        assert params["properties"]["target"]["enum"] == ["review", "campaign"]
        assert "required" in params
        # Both arguments must be mandatory so every navigation is explained.
        for field_name in ("target", "reason"):
            assert field_name in params["required"]
# ===========================================================================
# GraphSpec validation for navigation targets
# ===========================================================================
class TestGraphSpecNavValidation:
    """GraphSpec.validate() must check allowed_navigation_targets referential integrity."""

    def test_valid_targets_pass(self):
        """Navigation targets that reference existing nodes should pass validation."""
        graph = GraphSpec(
            id="test",
            goal_id="test_goal",
            entry_node="a",
            nodes=[
                NodeSpec(
                    id="a",
                    name="A",
                    description="t",
                    node_type="event_loop",
                    allowed_navigation_targets=["b"],
                ),
                NodeSpec(id="b", name="B", description="t", node_type="event_loop"),
            ],
            edges=[
                EdgeSpec(id="a_to_b", source="a", target="b"),
            ],
            terminal_nodes=["b"],
        )
        errors = graph.validate()
        # Only inspect navigation-related messages; other validators may fire too.
        nav_errors = [e for e in errors if "allowed_navigation_target" in e]
        assert nav_errors == []

    def test_invalid_targets_fail(self):
        """Navigation targets referencing non-existent nodes should fail."""
        graph = GraphSpec(
            id="test",
            goal_id="test_goal",
            entry_node="a",
            nodes=[
                NodeSpec(
                    id="a",
                    name="A",
                    description="t",
                    node_type="event_loop",
                    allowed_navigation_targets=["nonexistent"],
                ),
                NodeSpec(id="b", name="B", description="t", node_type="event_loop"),
            ],
            edges=[
                EdgeSpec(id="a_to_b", source="a", target="b"),
            ],
            terminal_nodes=["b"],
        )
        errors = graph.validate()
        nav_errors = [e for e in errors if "allowed_navigation_target" in e]
        # Exactly one error, and it must name the offending target.
        assert len(nav_errors) == 1
        assert "nonexistent" in nav_errors[0]

    def test_empty_targets_pass(self):
        """Nodes with no navigation targets should pass (backward compatible)."""
        graph = GraphSpec(
            id="test",
            goal_id="test_goal",
            entry_node="a",
            nodes=[
                NodeSpec(id="a", name="A", description="t", node_type="event_loop"),
                NodeSpec(id="b", name="B", description="t", node_type="event_loop"),
            ],
            edges=[
                EdgeSpec(id="a_to_b", source="a", target="b"),
            ],
            terminal_nodes=["b"],
        )
        errors = graph.validate()
        nav_errors = [e for e in errors if "allowed_navigation_target" in e]
        assert nav_errors == []
# ===========================================================================
# Executor integration: navigate_to routing
# ===========================================================================
class TestExecutorNavigation:
    """Tests for executor handling of navigate_to results."""

    def _make_graph(self, a_targets: list[str] | None = None) -> GraphSpec:
        """Build a simple A -> B -> C graph where A can navigate."""
        return GraphSpec(
            id="test",
            goal_id="test_goal",
            entry_node="a",
            nodes=[
                NodeSpec(
                    id="a",
                    name="A",
                    description="first",
                    node_type="event_loop",
                    output_keys=["out"],
                    allowed_navigation_targets=a_targets or [],
                    max_node_visits=3,
                ),
                NodeSpec(
                    id="b",
                    name="B",
                    description="second",
                    node_type="event_loop",
                    output_keys=["out"],
                    max_node_visits=3,
                ),
                NodeSpec(
                    id="c",
                    name="C",
                    description="third",
                    node_type="event_loop",
                    output_keys=["out"],
                ),
            ],
            edges=[
                EdgeSpec(
                    id="a_to_b",
                    source="a",
                    target="b",
                    condition=EdgeCondition.ON_SUCCESS,
                ),
                EdgeSpec(
                    id="b_to_c",
                    source="b",
                    target="c",
                    condition=EdgeCondition.ON_SUCCESS,
                ),
            ],
            terminal_nodes=["c"],
        )

    def _make_runtime(self):
        # Minimal Runtime stub: the executor only touches these lifecycle hooks.
        rt = MagicMock(spec=Runtime)
        rt.start_run = MagicMock(return_value="run_1")
        rt.decide = MagicMock(return_value="dec_1")
        rt.record_outcome = MagicMock()
        rt.end_run = MagicMock()
        rt.report_problem = MagicMock()
        rt.set_node = MagicMock()
        return rt

    @pytest.mark.asyncio
    async def test_navigation_changes_next_node(self):
        """When a node returns next_node in allowed_navigation_targets,
        executor should route to it."""
        from framework.graph.executor import GraphExecutor

        graph = self._make_graph(a_targets=["b", "c"])
        call_count = {"a": 0, "b": 0, "c": 0}

        class MockNode:
            def __init__(self, node_id, next_target=None):
                self.node_id = node_id
                self.next_target = next_target

            def validate_input(self, ctx):
                return []

            async def execute(self, ctx):
                call_count[self.node_id] += 1
                # Navigate only on the first visit so a re-entry completes normally.
                if self.next_target and call_count[self.node_id] == 1:
                    return NodeResult(
                        success=True,
                        output={"out": "navigated"},
                        next_node=self.next_target,
                        route_reason="user requested",
                    )
                return NodeResult(success=True, output={"out": f"done_{self.node_id}"})

        rt = self._make_runtime()
        executor = GraphExecutor(
            runtime=rt,
            node_registry={
                "a": MockNode("a", next_target="c"),  # Skip B, go to C
                "b": MockNode("b"),
                "c": MockNode("c"),
            },
        )
        _result = await executor.execute(
            graph=graph,
            goal=MagicMock(name="test_goal", description="test", success_criteria="test"),
        )
        # A should navigate directly to C, skipping B
        assert call_count["a"] == 1
        assert call_count["b"] == 0
        assert call_count["c"] == 1

    @pytest.mark.asyncio
    async def test_unauthorized_navigation_falls_through(self):
        """When next_node is not in allowed_navigation_targets,
        executor should fall through to normal edge evaluation."""
        from framework.graph.executor import GraphExecutor

        # A can only navigate to B, not C
        graph = self._make_graph(a_targets=["b"])
        call_count = {"a": 0, "b": 0, "c": 0}

        class MockNode:
            def __init__(self, node_id):
                self.node_id = node_id

            def validate_input(self, ctx):
                return []

            async def execute(self, ctx):
                call_count[self.node_id] += 1
                if self.node_id == "a":
                    # Try to navigate to C (not allowed)
                    return NodeResult(
                        success=True,
                        output={"out": "try_c"},
                        next_node="c",
                    )
                return NodeResult(success=True, output={"out": f"done_{self.node_id}"})

        rt = self._make_runtime()
        executor = GraphExecutor(
            runtime=rt,
            node_registry={
                "a": MockNode("a"),
                "b": MockNode("b"),
                "c": MockNode("c"),
            },
        )
        _result = await executor.execute(
            graph=graph,
            goal=MagicMock(name="test_goal", description="test", success_criteria="test"),
        )
        # A's navigation to C should be blocked; normal edge A->B fires
        assert call_count["a"] == 1
        assert call_count["b"] == 1
        assert call_count["c"] == 1  # B->C edge fires normally