Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8db2e72f5e |
@@ -519,7 +519,7 @@ NODE_SPECS = {
|
||||
"load_data",
|
||||
"save_data",
|
||||
],
|
||||
max_node_visits=3,
|
||||
max_node_visits=5,
|
||||
system_prompt=(
|
||||
"You are a Contact Extractor agent. Your inputs 'user_profiles' and "
|
||||
"'relevance_scores' are filenames pointing to JSON data files.\n\n"
|
||||
@@ -555,8 +555,9 @@ NODE_SPECS = {
|
||||
input_keys=["contact_list"],
|
||||
output_keys=["approved_contacts", "redo_extraction"],
|
||||
nullable_output_keys=["approved_contacts", "redo_extraction"],
|
||||
max_node_visits=3,
|
||||
max_node_visits=5,
|
||||
tools=["load_data", "save_data"],
|
||||
allowed_navigation_targets=["extractor"],
|
||||
system_prompt=(
|
||||
"You are the Review agent at a human checkpoint. Your input 'contact_list' "
|
||||
"is a filename pointing to a JSON data file.\n\n"
|
||||
@@ -570,6 +571,10 @@ NODE_SPECS = {
|
||||
" set_output(key='approved_contacts', value='approved_contacts.json')\n\n"
|
||||
" IF REDO REQUESTED: call:\n"
|
||||
" set_output(key='redo_extraction', value='true')\n\n"
|
||||
"NAVIGATION:\n"
|
||||
"If the operator wants to go back and re-extract contacts with different "
|
||||
"criteria, you can use navigate_to(target='extractor') instead of "
|
||||
"set_output(key='redo_extraction').\n\n"
|
||||
"CRITICAL RULE: Call set_output EXACTLY ONCE with EXACTLY ONE key.\n"
|
||||
"NEVER call set_output twice. NEVER set both keys.\n"
|
||||
"The two output keys are mutually exclusive — setting both will cause an error."
|
||||
@@ -631,8 +636,9 @@ NODE_SPECS = {
|
||||
input_keys=["draft_emails"],
|
||||
output_keys=["approved_emails", "revise_campaigns"],
|
||||
nullable_output_keys=["approved_emails", "revise_campaigns"],
|
||||
max_node_visits=3,
|
||||
max_node_visits=5,
|
||||
tools=["load_data", "save_data"],
|
||||
allowed_navigation_targets=["review", "campaign_builder"],
|
||||
system_prompt=(
|
||||
"You are the Approval agent at the final human checkpoint. Your input "
|
||||
"'draft_emails' is a filename pointing to a JSON data file.\n\n"
|
||||
@@ -645,6 +651,10 @@ NODE_SPECS = {
|
||||
" set_output(key='approved_emails', value='approved_emails.json')\n\n"
|
||||
" IF REVISION REQUESTED: call:\n"
|
||||
" set_output(key='revise_campaigns', value='true')\n\n"
|
||||
"NAVIGATION:\n"
|
||||
"If the operator wants to go back to review the contact list, use "
|
||||
"navigate_to(target='review'). If they want to revise the email drafts, "
|
||||
"use navigate_to(target='campaign_builder').\n\n"
|
||||
"CRITICAL RULE: Call set_output EXACTLY ONCE with EXACTLY ONE key.\n"
|
||||
"NEVER call set_output twice. NEVER set both keys.\n"
|
||||
"The two output keys are mutually exclusive — setting both will cause an error."
|
||||
|
||||
@@ -15,7 +15,7 @@ You cannot skip steps or bypass validation.
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -26,7 +26,7 @@ from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeSpec
|
||||
|
||||
|
||||
class BuildPhase(str, Enum):
|
||||
class BuildPhase(StrEnum):
|
||||
"""Current phase of the build process."""
|
||||
|
||||
INIT = "init" # Just started
|
||||
|
||||
@@ -8,7 +8,7 @@ containing one or more keys (e.g., api_key, access_token, refresh_token).
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
@@ -19,7 +19,7 @@ def _utc_now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class CredentialType(str, Enum):
|
||||
class CredentialType(StrEnum):
|
||||
"""Types of credentials the store can manage."""
|
||||
|
||||
API_KEY = "api_key"
|
||||
|
||||
@@ -11,11 +11,11 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class TokenPlacement(str, Enum):
|
||||
class TokenPlacement(StrEnum):
|
||||
"""Where to place the access token in HTTP requests."""
|
||||
|
||||
HEADER_BEARER = "header_bearer"
|
||||
|
||||
@@ -21,7 +21,7 @@ allowing the LLM to evaluate whether proceeding along an edge makes sense
|
||||
given the current goal, context, and execution state.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -29,7 +29,7 @@ from pydantic import BaseModel, Field
|
||||
from framework.graph.safe_eval import safe_eval
|
||||
|
||||
|
||||
class EdgeCondition(str, Enum):
|
||||
class EdgeCondition(StrEnum):
|
||||
"""When an edge should be traversed."""
|
||||
|
||||
ALWAYS = "always" # Always after source completes
|
||||
@@ -644,4 +644,15 @@ class GraphSpec(BaseModel):
|
||||
else:
|
||||
seen_keys[key] = node_id
|
||||
|
||||
# Validate allowed_navigation_targets reference existing nodes
|
||||
node_ids = {n.id for n in self.nodes}
|
||||
for node in self.nodes:
|
||||
targets = getattr(node, "allowed_navigation_targets", [])
|
||||
for target_id in targets:
|
||||
if target_id not in node_ids:
|
||||
errors.append(
|
||||
f"Node '{node.id}' has allowed_navigation_target "
|
||||
f"'{target_id}' which doesn't exist in the graph"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
@@ -180,6 +180,9 @@ class EventLoopNode(NodeProtocol):
|
||||
# Client-facing input blocking state
|
||||
self._input_ready = asyncio.Event()
|
||||
self._shutdown = False
|
||||
# Dynamic navigation state (hybrid execution pattern)
|
||||
self._navigate_target: str | None = None
|
||||
self._navigate_reason: str | None = None
|
||||
|
||||
def validate_input(self, ctx: NodeContext) -> list[str]:
|
||||
"""Validate hard requirements only.
|
||||
@@ -209,6 +212,10 @@ class EventLoopNode(NodeProtocol):
|
||||
if ctx.llm is None:
|
||||
return NodeResult(success=False, error="LLM provider not available")
|
||||
|
||||
# 1b. Reset navigation state
|
||||
self._navigate_target = None
|
||||
self._navigate_reason = None
|
||||
|
||||
# 2. Restore or create new conversation + accumulator
|
||||
conversation, accumulator, start_iteration = await self._restore(ctx)
|
||||
if conversation is None:
|
||||
@@ -226,11 +233,16 @@ class EventLoopNode(NodeProtocol):
|
||||
if initial_message:
|
||||
await conversation.add_user_message(initial_message)
|
||||
|
||||
# 3. Build tool list: node tools + synthetic set_output tool
|
||||
# 3. Build tool list: node tools + synthetic tools
|
||||
tools = list(ctx.available_tools)
|
||||
set_output_tool = self._build_set_output_tool(ctx.node_spec.output_keys)
|
||||
if set_output_tool:
|
||||
tools.append(set_output_tool)
|
||||
navigate_to_tool = self._build_navigate_to_tool(
|
||||
getattr(ctx.node_spec, "allowed_navigation_targets", []),
|
||||
)
|
||||
if navigate_to_tool:
|
||||
tools.append(navigate_to_tool)
|
||||
|
||||
logger.info(
|
||||
"[%s] Tools available (%d): %s | client_facing=%s | judge=%s",
|
||||
@@ -321,6 +333,32 @@ class EventLoopNode(NodeProtocol):
|
||||
# 6g. Write cursor checkpoint
|
||||
await self._write_cursor(ctx, conversation, accumulator, iteration)
|
||||
|
||||
# 6g'. Check if navigate_to was called
|
||||
if self._navigate_target:
|
||||
target = self._navigate_target
|
||||
reason = self._navigate_reason or ""
|
||||
self._navigate_target = None
|
||||
self._navigate_reason = None
|
||||
|
||||
logger.info(
|
||||
"[%s] iter=%d: navigate_to '%s' (reason: %s)",
|
||||
node_id,
|
||||
iteration,
|
||||
target,
|
||||
reason,
|
||||
)
|
||||
|
||||
await self._publish_loop_completed(stream_id, node_id, iteration + 1)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return NodeResult(
|
||||
success=True,
|
||||
output=accumulator.to_dict(),
|
||||
next_node=target,
|
||||
route_reason=f"User navigation: {reason}",
|
||||
tokens_used=total_input_tokens + total_output_tokens,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
# 6h. Client-facing input wait
|
||||
logger.info(
|
||||
"[%s] iter=%d: 6h check — client_facing=%s, tool_results=%d",
|
||||
@@ -641,6 +679,19 @@ class EventLoopNode(NodeProtocol):
|
||||
# Async write-through for set_output
|
||||
if not result.is_error:
|
||||
await accumulator.set(tc.tool_input["key"], tc.tool_input["value"])
|
||||
elif tc.tool_name == "navigate_to":
|
||||
result = self._handle_navigate_to(
|
||||
tc.tool_input,
|
||||
getattr(ctx.node_spec, "allowed_navigation_targets", []),
|
||||
)
|
||||
result = ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=result.content,
|
||||
is_error=result.is_error,
|
||||
)
|
||||
if not result.is_error:
|
||||
self._navigate_target = tc.tool_input["target"]
|
||||
self._navigate_reason = tc.tool_input.get("reason", "")
|
||||
else:
|
||||
# Execute real tool
|
||||
result = await self._execute_tool(tc)
|
||||
@@ -806,6 +857,66 @@ class EventLoopNode(NodeProtocol):
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Dynamic navigation (navigate_to synthetic tool)
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def _build_navigate_to_tool(self, allowed_targets: list[str]) -> Tool | None:
|
||||
"""Build the synthetic navigate_to tool for dynamic graph navigation.
|
||||
|
||||
Only created when allowed_targets is non-empty.
|
||||
"""
|
||||
if not allowed_targets:
|
||||
return None
|
||||
targets_str = ", ".join(f"'{t}'" for t in allowed_targets)
|
||||
return Tool(
|
||||
name="navigate_to",
|
||||
description=(
|
||||
"Navigate the pipeline to a different stage. Use this when the user "
|
||||
"asks to go back to a previous step or move to a different stage. "
|
||||
f"Allowed targets: {targets_str}. "
|
||||
"IMPORTANT: Calling this exits the current stage immediately. "
|
||||
"Any outputs you've set with set_output will be discarded."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": f"Node ID to navigate to. Must be one of: {allowed_targets}",
|
||||
"enum": allowed_targets,
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Brief reason for navigation.",
|
||||
},
|
||||
},
|
||||
"required": ["target", "reason"],
|
||||
},
|
||||
)
|
||||
|
||||
def _handle_navigate_to(
|
||||
self,
|
||||
tool_input: dict[str, Any],
|
||||
allowed_targets: list[str],
|
||||
) -> ToolResult:
|
||||
"""Handle navigate_to tool call. Returns ToolResult (sync)."""
|
||||
target = tool_input.get("target", "")
|
||||
reason = tool_input.get("reason", "")
|
||||
|
||||
if target not in allowed_targets:
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=f"Invalid navigation target '{target}'. Allowed: {allowed_targets}",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=f"Navigating to '{target}'. Reason: {reason}",
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Judge evaluation
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
@@ -255,6 +255,7 @@ class GraphExecutor:
|
||||
total_latency = 0
|
||||
node_retry_counts: dict[str, int] = {} # Track retries per node
|
||||
node_visit_counts: dict[str, int] = {} # Track visits for feedback loops
|
||||
_is_retry = False # True when looping back for a retry (not a new visit)
|
||||
|
||||
# Determine entry point (may differ if resuming)
|
||||
current_node_id = graph.get_entry_point(session_state)
|
||||
@@ -284,7 +285,11 @@ class GraphExecutor:
|
||||
raise RuntimeError(f"Node not found: {current_node_id}")
|
||||
|
||||
# Enforce max_node_visits (feedback/callback edge support)
|
||||
node_visit_counts[current_node_id] = node_visit_counts.get(current_node_id, 0) + 1
|
||||
# Don't increment visit count on retries — retries are not new visits
|
||||
if not _is_retry:
|
||||
cnt = node_visit_counts.get(current_node_id, 0) + 1
|
||||
node_visit_counts[current_node_id] = cnt
|
||||
_is_retry = False
|
||||
max_visits = getattr(node_spec, "max_node_visits", 1)
|
||||
if max_visits > 0 and node_visit_counts[current_node_id] > max_visits:
|
||||
self.logger.warning(
|
||||
@@ -433,6 +438,7 @@ class GraphExecutor:
|
||||
self.logger.info(
|
||||
f" ↻ Retrying ({node_retry_counts[current_node_id]}/{max_retries})..."
|
||||
)
|
||||
_is_retry = True
|
||||
continue
|
||||
else:
|
||||
# Max retries exceeded - fail the execution
|
||||
@@ -524,11 +530,39 @@ class GraphExecutor:
|
||||
break
|
||||
|
||||
# Determine next node
|
||||
_navigated = False
|
||||
if result.next_node:
|
||||
# Router explicitly set next node
|
||||
self.logger.info(f" → Router directing to: {result.next_node}")
|
||||
current_node_id = result.next_node
|
||||
else:
|
||||
# Dynamic routing: router or navigate_to
|
||||
_nav_valid = True
|
||||
|
||||
if node_spec.node_type != "router":
|
||||
# Validate navigation target for non-router nodes
|
||||
allowed = getattr(node_spec, "allowed_navigation_targets", [])
|
||||
target_spec = graph.get_node(result.next_node)
|
||||
|
||||
if target_spec is None:
|
||||
self.logger.warning(
|
||||
f" ! Navigation target '{result.next_node}' "
|
||||
f"not found in graph"
|
||||
)
|
||||
_nav_valid = False
|
||||
elif allowed and result.next_node not in allowed:
|
||||
self.logger.warning(
|
||||
f" ! Navigation to '{result.next_node}' blocked: "
|
||||
f"not in allowed_navigation_targets {allowed}"
|
||||
)
|
||||
_nav_valid = False
|
||||
|
||||
if _nav_valid:
|
||||
reason = result.route_reason or ""
|
||||
self.logger.info(
|
||||
f" → Navigating to: {result.next_node}"
|
||||
+ (f" ({reason})" if reason else "")
|
||||
)
|
||||
current_node_id = result.next_node
|
||||
_navigated = True
|
||||
|
||||
if not _navigated:
|
||||
# Get all traversable edges for fan-out detection
|
||||
traversable_edges = self._get_all_traversable_edges(
|
||||
graph=graph,
|
||||
|
||||
@@ -12,13 +12,13 @@ Goals are:
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GoalStatus(str, Enum):
|
||||
class GoalStatus(StrEnum):
|
||||
"""Lifecycle status of a goal."""
|
||||
|
||||
DRAFT = "draft" # Being defined
|
||||
|
||||
@@ -6,11 +6,11 @@ where agents need to gather input from humans.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class HITLInputType(str, Enum):
|
||||
class HITLInputType(StrEnum):
|
||||
"""Type of input expected from human."""
|
||||
|
||||
FREE_TEXT = "free_text" # Open-ended text response
|
||||
|
||||
@@ -237,6 +237,12 @@ class NodeSpec(BaseModel):
|
||||
description="If True, this node streams output to the end user and can request input.",
|
||||
)
|
||||
|
||||
# Dynamic navigation (hybrid execution pattern)
|
||||
allowed_navigation_targets: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Node IDs this node can navigate to via navigate_to tool. Empty = disabled.",
|
||||
)
|
||||
|
||||
model_config = {"extra": "allow", "arbitrary_types_allowed": True}
|
||||
|
||||
|
||||
|
||||
@@ -11,13 +11,13 @@ The Plan is the contract between the external planner and the executor:
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ActionType(str, Enum):
|
||||
class ActionType(StrEnum):
|
||||
"""Types of actions a PlanStep can perform."""
|
||||
|
||||
LLM_CALL = "llm_call" # Call LLM for generation
|
||||
@@ -27,7 +27,7 @@ class ActionType(str, Enum):
|
||||
CODE_EXECUTION = "code_execution" # Execute dynamic code (sandboxed)
|
||||
|
||||
|
||||
class StepStatus(str, Enum):
|
||||
class StepStatus(StrEnum):
|
||||
"""Status of a plan step."""
|
||||
|
||||
PENDING = "pending"
|
||||
@@ -56,7 +56,7 @@ class StepStatus(str, Enum):
|
||||
return self == StepStatus.COMPLETED
|
||||
|
||||
|
||||
class ApprovalDecision(str, Enum):
|
||||
class ApprovalDecision(StrEnum):
|
||||
"""Human decision on a step requiring approval."""
|
||||
|
||||
APPROVE = "approve" # Execute as planned
|
||||
@@ -91,7 +91,7 @@ class ApprovalResult(BaseModel):
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class JudgmentAction(str, Enum):
|
||||
class JudgmentAction(StrEnum):
|
||||
"""Actions the judge can take after evaluating a step."""
|
||||
|
||||
ACCEPT = "accept" # Step completed successfully, continue
|
||||
@@ -423,7 +423,7 @@ class Plan(BaseModel):
|
||||
}
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
|
||||
class ExecutionStatus(StrEnum):
|
||||
"""Status of plan execution."""
|
||||
|
||||
COMPLETED = "completed"
|
||||
|
||||
@@ -12,13 +12,13 @@ import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventType(str, Enum):
|
||||
class EventType(StrEnum):
|
||||
"""Types of events that can be published."""
|
||||
|
||||
# Execution lifecycle
|
||||
|
||||
@@ -11,13 +11,13 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IsolationLevel(str, Enum):
|
||||
class IsolationLevel(StrEnum):
|
||||
"""State isolation level for concurrent executions."""
|
||||
|
||||
ISOLATED = "isolated" # Private state per execution
|
||||
@@ -25,7 +25,7 @@ class IsolationLevel(str, Enum):
|
||||
SYNCHRONIZED = "synchronized" # Shared with write locks (strong consistency)
|
||||
|
||||
|
||||
class StateScope(str, Enum):
|
||||
class StateScope(StrEnum):
|
||||
"""Scope for state operations."""
|
||||
|
||||
EXECUTION = "execution" # Local to a single execution
|
||||
|
||||
@@ -10,13 +10,13 @@ This is MORE important than actions because:
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
|
||||
class DecisionType(str, Enum):
|
||||
class DecisionType(StrEnum):
|
||||
"""Types of decisions an agent can make."""
|
||||
|
||||
TOOL_SELECTION = "tool_selection" # Which tool to use
|
||||
|
||||
@@ -6,7 +6,7 @@ summaries and metrics that Builder needs to understand what happened.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
@@ -14,7 +14,7 @@ from pydantic import BaseModel, Field, computed_field
|
||||
from framework.schemas.decision import Decision, Outcome
|
||||
|
||||
|
||||
class RunStatus(str, Enum):
|
||||
class RunStatus(StrEnum):
|
||||
"""Status of a run."""
|
||||
|
||||
RUNNING = "running"
|
||||
|
||||
@@ -6,13 +6,13 @@ programmatic/MCP-based approval.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ApprovalAction(str, Enum):
|
||||
class ApprovalAction(StrEnum):
|
||||
"""Actions a user can take on a generated test."""
|
||||
|
||||
APPROVE = "approve" # Accept as-is
|
||||
|
||||
@@ -6,13 +6,13 @@ but require mandatory user approval before being stored.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ApprovalStatus(str, Enum):
|
||||
class ApprovalStatus(StrEnum):
|
||||
"""Status of user approval for a generated test."""
|
||||
|
||||
PENDING = "pending" # Awaiting user review
|
||||
@@ -21,7 +21,7 @@ class ApprovalStatus(str, Enum):
|
||||
REJECTED = "rejected" # User declined (with reason)
|
||||
|
||||
|
||||
class TestType(str, Enum):
|
||||
class TestType(StrEnum):
|
||||
"""Type of test based on what it validates."""
|
||||
|
||||
__test__ = False # Not a pytest test class
|
||||
|
||||
@@ -6,13 +6,13 @@ categorization for guiding iteration strategy.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ErrorCategory(str, Enum):
|
||||
class ErrorCategory(StrEnum):
|
||||
"""
|
||||
Category of test failure for guiding iteration.
|
||||
|
||||
|
||||
@@ -0,0 +1,574 @@
|
||||
"""Tests for the navigate_to dynamic routing feature (hybrid execution pattern).
|
||||
|
||||
Tests the navigate_to synthetic tool on EventLoopNode and the executor's
|
||||
validation of navigation targets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.event_loop_node import (
|
||||
EventLoopNode,
|
||||
LoopConfig,
|
||||
)
|
||||
from framework.graph.node import NodeContext, NodeResult, NodeSpec, SharedMemory
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
TextDeltaEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock LLM (same pattern as test_event_loop_node.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockStreamingLLM(LLMProvider):
|
||||
"""Mock LLM that yields pre-programmed StreamEvent sequences."""
|
||||
|
||||
def __init__(self, scenarios: list[list] | None = None):
|
||||
self.scenarios = scenarios or []
|
||||
self._call_index = 0
|
||||
self.stream_calls: list[dict] = []
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator:
|
||||
self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
|
||||
if not self.scenarios:
|
||||
return
|
||||
events = self.scenarios[self._call_index % len(self.scenarios)]
|
||||
self._call_index += 1
|
||||
for event in events:
|
||||
yield event
|
||||
|
||||
def complete(self, messages, system="", **kwargs) -> LLMResponse:
|
||||
return LLMResponse(content="Summary.", model="mock", stop_reason="stop")
|
||||
|
||||
def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
|
||||
return LLMResponse(content="", model="mock", stop_reason="stop")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def text_scenario(text: str) -> list:
|
||||
return [
|
||||
TextDeltaEvent(content=text, snapshot=text),
|
||||
FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"),
|
||||
]
|
||||
|
||||
|
||||
def tool_call_scenario(tool_name: str, tool_input: dict, tool_use_id: str = "call_1") -> list:
|
||||
return [
|
||||
ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input),
|
||||
FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runtime():
|
||||
rt = MagicMock(spec=Runtime)
|
||||
rt.start_run = MagicMock(return_value="run_1")
|
||||
rt.decide = MagicMock(return_value="dec_1")
|
||||
rt.record_outcome = MagicMock()
|
||||
rt.end_run = MagicMock()
|
||||
rt.report_problem = MagicMock()
|
||||
rt.set_node = MagicMock()
|
||||
return rt
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory():
|
||||
return SharedMemory()
|
||||
|
||||
|
||||
def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None):
|
||||
return NodeContext(
|
||||
runtime=runtime,
|
||||
node_id=node_spec.id,
|
||||
node_spec=node_spec,
|
||||
memory=memory,
|
||||
input_data=input_data or {},
|
||||
llm=llm,
|
||||
available_tools=tools or [],
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# navigate_to tool availability
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestNavigateToToolAvailability:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tool_without_targets(self, runtime, memory):
|
||||
"""navigate_to tool should NOT appear when allowed_navigation_targets is empty."""
|
||||
spec = NodeSpec(
|
||||
id="node_a",
|
||||
name="Node A",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
allowed_navigation_targets=[],
|
||||
)
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
tool_call_scenario("set_output", {"key": "result", "value": "done"}),
|
||||
text_scenario("Done."),
|
||||
]
|
||||
)
|
||||
ctx = build_ctx(runtime, spec, memory, llm)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
await node.execute(ctx)
|
||||
|
||||
assert llm.stream_calls, "LLM should have been called"
|
||||
tools_sent = llm.stream_calls[0]["tools"] or []
|
||||
tool_names = [t.name for t in tools_sent]
|
||||
assert "navigate_to" not in tool_names
|
||||
assert "set_output" in tool_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_present_with_targets(self, runtime, memory):
|
||||
"""navigate_to tool should appear when allowed_navigation_targets is set."""
|
||||
spec = NodeSpec(
|
||||
id="node_a",
|
||||
name="Node A",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
allowed_navigation_targets=["node_b", "node_c"],
|
||||
)
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
tool_call_scenario("set_output", {"key": "result", "value": "done"}),
|
||||
text_scenario("Done."),
|
||||
]
|
||||
)
|
||||
ctx = build_ctx(runtime, spec, memory, llm)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
await node.execute(ctx)
|
||||
|
||||
tools_sent = llm.stream_calls[0]["tools"] or []
|
||||
tool_names = [t.name for t in tools_sent]
|
||||
assert "navigate_to" in tool_names
|
||||
assert "set_output" in tool_names
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# navigate_to tool execution
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestNavigateToExecution:
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_target_returns_next_node(self, runtime, memory):
|
||||
"""Calling navigate_to with a valid target should set NodeResult.next_node."""
|
||||
spec = NodeSpec(
|
||||
id="approval",
|
||||
name="Approval",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
output_keys=["approved", "revise"],
|
||||
nullable_output_keys=["approved", "revise"],
|
||||
allowed_navigation_targets=["review", "campaign"],
|
||||
)
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
tool_call_scenario(
|
||||
"navigate_to",
|
||||
{"target": "review", "reason": "User wants to go back"},
|
||||
),
|
||||
]
|
||||
)
|
||||
ctx = build_ctx(runtime, spec, memory, llm)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.next_node == "review"
|
||||
assert "User wants to go back" in (result.route_reason or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_target_continues_loop(self, runtime, memory):
|
||||
"""Calling navigate_to with an invalid target should error and continue."""
|
||||
spec = NodeSpec(
|
||||
id="approval",
|
||||
name="Approval",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
allowed_navigation_targets=["review"],
|
||||
)
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
# Turn 1: try invalid target
|
||||
tool_call_scenario(
|
||||
"navigate_to",
|
||||
{"target": "nonexistent", "reason": "test"},
|
||||
),
|
||||
# Turn 2: set output normally
|
||||
tool_call_scenario("set_output", {"key": "result", "value": "ok"}),
|
||||
# Turn 3: text -> implicit judge accepts
|
||||
text_scenario("Done."),
|
||||
]
|
||||
)
|
||||
ctx = build_ctx(runtime, spec, memory, llm)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.next_node is None
|
||||
assert result.output.get("result") == "ok"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_navigate_to_with_partial_outputs(self, runtime, memory):
|
||||
"""Outputs set before navigate_to should still be in the result dict
|
||||
but the node exits via navigation, not via normal completion."""
|
||||
spec = NodeSpec(
|
||||
id="approval",
|
||||
name="Approval",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
output_keys=["draft", "final"],
|
||||
nullable_output_keys=["draft", "final"],
|
||||
allowed_navigation_targets=["review"],
|
||||
)
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
# Turn 1: set one output
|
||||
tool_call_scenario(
|
||||
"set_output",
|
||||
{"key": "draft", "value": "v1"},
|
||||
tool_use_id="call_1",
|
||||
),
|
||||
# Turn 2: navigate away
|
||||
tool_call_scenario(
|
||||
"navigate_to",
|
||||
{"target": "review", "reason": "go back"},
|
||||
tool_use_id="call_2",
|
||||
),
|
||||
]
|
||||
)
|
||||
ctx = build_ctx(runtime, spec, memory, llm)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.next_node == "review"
|
||||
assert result.output.get("draft") == "v1"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# navigate_to handler unit tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestNavigateToHandler:
    """Unit tests for EventLoopNode._handle_navigate_to."""

    @staticmethod
    def _invoke(args, targets):
        # Run the handler on a fresh node instance.
        return EventLoopNode()._handle_navigate_to(args, allowed_targets=targets)

    def test_valid_target(self):
        """_handle_navigate_to should succeed for valid targets."""
        outcome = self._invoke(
            {"target": "review", "reason": "go back"},
            ["review", "campaign"],
        )
        assert not outcome.is_error
        assert "review" in outcome.content

    def test_invalid_target(self):
        """_handle_navigate_to should error for invalid targets."""
        outcome = self._invoke(
            {"target": "nonexistent", "reason": "test"},
            ["review", "campaign"],
        )
        assert outcome.is_error
        assert "Invalid" in outcome.content

    def test_empty_target(self):
        """_handle_navigate_to should error when target is empty."""
        outcome = self._invoke({"target": "", "reason": "test"}, ["review"])
        assert outcome.is_error
||||
|
||||
|
||||
# ===========================================================================
|
||||
# navigate_to tool builder
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestNavigateToToolBuilder:
    """Tests for EventLoopNode._build_navigate_to_tool."""

    def test_no_tool_with_empty_targets(self):
        """Should return None when no targets allowed."""
        assert EventLoopNode()._build_navigate_to_tool([]) is None

    def test_tool_with_targets(self):
        """Should return a Tool with enum of allowed targets."""
        tool = EventLoopNode()._build_navigate_to_tool(["review", "campaign"])

        assert tool is not None
        assert tool.name == "navigate_to"

        params = tool.parameters
        # The target parameter must be constrained to exactly the allowed set.
        assert params["properties"]["target"]["enum"] == ["review", "campaign"]
        # Both arguments must be mandatory.
        assert "required" in params
        assert "target" in params["required"]
        assert "reason" in params["required"]
||||
|
||||
|
||||
# ===========================================================================
|
||||
# GraphSpec validation for navigation targets
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestGraphSpecNavValidation:
    """GraphSpec.validate() checks on allowed_navigation_targets."""

    @staticmethod
    def _graph(a_targets=None):
        """Two-node graph a -> b; node 'a' optionally declares nav targets."""
        a_kwargs = dict(id="a", name="A", description="t", node_type="event_loop")
        if a_targets is not None:
            a_kwargs["allowed_navigation_targets"] = a_targets
        return GraphSpec(
            id="test",
            goal_id="test_goal",
            entry_node="a",
            nodes=[
                NodeSpec(**a_kwargs),
                NodeSpec(id="b", name="B", description="t", node_type="event_loop"),
            ],
            edges=[
                EdgeSpec(id="a_to_b", source="a", target="b"),
            ],
            terminal_nodes=["b"],
        )

    @staticmethod
    def _nav_errors(graph):
        # Only the errors produced by navigation-target validation.
        return [e for e in graph.validate() if "allowed_navigation_target" in e]

    def test_valid_targets_pass(self):
        """Navigation targets that reference existing nodes should pass validation."""
        assert self._nav_errors(self._graph(a_targets=["b"])) == []

    def test_invalid_targets_fail(self):
        """Navigation targets referencing non-existent nodes should fail."""
        errors = self._nav_errors(self._graph(a_targets=["nonexistent"]))
        assert len(errors) == 1
        assert "nonexistent" in errors[0]

    def test_empty_targets_pass(self):
        """Nodes with no navigation targets should pass (backward compatible)."""
        assert self._nav_errors(self._graph()) == []
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Executor integration: navigate_to routing
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestExecutorNavigation:
    """Tests for executor handling of navigate_to results."""

    def _make_graph(self, a_targets: list[str] | None = None) -> GraphSpec:
        """Build a simple A -> B -> C graph where A can navigate."""
        return GraphSpec(
            id="test",
            goal_id="test_goal",
            entry_node="a",
            nodes=[
                NodeSpec(
                    id="a",
                    name="A",
                    description="first",
                    node_type="event_loop",
                    output_keys=["out"],
                    allowed_navigation_targets=a_targets or [],
                    max_node_visits=3,
                ),
                NodeSpec(
                    id="b",
                    name="B",
                    description="second",
                    node_type="event_loop",
                    output_keys=["out"],
                    max_node_visits=3,
                ),
                NodeSpec(
                    id="c",
                    name="C",
                    description="third",
                    node_type="event_loop",
                    output_keys=["out"],
                ),
            ],
            edges=[
                EdgeSpec(
                    id="a_to_b",
                    source="a",
                    target="b",
                    condition=EdgeCondition.ON_SUCCESS,
                ),
                EdgeSpec(
                    id="b_to_c",
                    source="b",
                    target="c",
                    condition=EdgeCondition.ON_SUCCESS,
                ),
            ],
            terminal_nodes=["c"],
        )

    def _make_runtime(self):
        """Mock Runtime exposing the methods the executor calls during a run."""
        rt = MagicMock(spec=Runtime)
        rt.start_run = MagicMock(return_value="run_1")
        rt.decide = MagicMock(return_value="dec_1")
        rt.record_outcome = MagicMock()
        rt.end_run = MagicMock()
        rt.report_problem = MagicMock()
        rt.set_node = MagicMock()
        return rt

    def _make_goal(self):
        """Mock goal with name/description/success_criteria attributes.

        BUG FIX: MagicMock(name="test_goal") sets the mock's *repr* name,
        not a `.name` attribute — reading goal.name would return a child
        Mock. Assign `.name` after construction instead.
        """
        goal = MagicMock(description="test", success_criteria="test")
        goal.name = "test_goal"
        return goal

    @pytest.mark.asyncio
    async def test_navigation_changes_next_node(self):
        """When a node returns next_node in allowed_navigation_targets,
        executor should route to it."""
        from framework.graph.executor import GraphExecutor

        graph = self._make_graph(a_targets=["b", "c"])
        call_count = {"a": 0, "b": 0, "c": 0}

        class MockNode:
            def __init__(self, node_id, next_target=None):
                self.node_id = node_id
                self.next_target = next_target

            def validate_input(self, ctx):
                return []

            async def execute(self, ctx):
                call_count[self.node_id] += 1
                # Navigate only on the first visit to avoid looping forever.
                if self.next_target and call_count[self.node_id] == 1:
                    return NodeResult(
                        success=True,
                        output={"out": "navigated"},
                        next_node=self.next_target,
                        route_reason="user requested",
                    )
                return NodeResult(success=True, output={"out": f"done_{self.node_id}"})

        rt = self._make_runtime()
        executor = GraphExecutor(
            runtime=rt,
            node_registry={
                "a": MockNode("a", next_target="c"),  # Skip B, go to C
                "b": MockNode("b"),
                "c": MockNode("c"),
            },
        )

        _result = await executor.execute(
            graph=graph,
            goal=self._make_goal(),
        )

        # A should navigate directly to C, skipping B
        assert call_count["a"] == 1
        assert call_count["b"] == 0
        assert call_count["c"] == 1

    @pytest.mark.asyncio
    async def test_unauthorized_navigation_falls_through(self):
        """When next_node is not in allowed_navigation_targets,
        executor should fall through to normal edge evaluation."""
        from framework.graph.executor import GraphExecutor

        # A can only navigate to B, not C
        graph = self._make_graph(a_targets=["b"])
        call_count = {"a": 0, "b": 0, "c": 0}

        class MockNode:
            def __init__(self, node_id):
                self.node_id = node_id

            def validate_input(self, ctx):
                return []

            async def execute(self, ctx):
                call_count[self.node_id] += 1
                if self.node_id == "a":
                    # Try to navigate to C (not allowed)
                    return NodeResult(
                        success=True,
                        output={"out": "try_c"},
                        next_node="c",
                    )
                return NodeResult(success=True, output={"out": f"done_{self.node_id}"})

        rt = self._make_runtime()
        executor = GraphExecutor(
            runtime=rt,
            node_registry={
                "a": MockNode("a"),
                "b": MockNode("b"),
                "c": MockNode("c"),
            },
        )

        _result = await executor.execute(
            graph=graph,
            goal=self._make_goal(),
        )

        # A's navigation to C should be blocked; normal edge A->B fires
        assert call_count["a"] == 1
        assert call_count["b"] == 1
        assert call_count["c"] == 1  # B->C edge fires normally
|
||||
Reference in New Issue
Block a user