feat: consolidate context building
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Graph structures: Goals, Nodes, Edges, and Execution."""
|
||||
|
||||
from framework.graph.context import GraphContext
|
||||
from framework.graph.context_handoff import ContextHandoff, HandoffContext
|
||||
from framework.graph.conversation import ConversationStore, Message, NodeConversation
|
||||
from framework.graph.edge import DEFAULT_MAX_TOKENS, EdgeCondition, EdgeSpec, GraphSpec
|
||||
@@ -17,7 +18,6 @@ from framework.graph.worker_agent import (
|
||||
Activation,
|
||||
FanOutTag,
|
||||
FanOutTracker,
|
||||
GraphContext,
|
||||
WorkerAgent,
|
||||
WorkerCompletion,
|
||||
WorkerLifecycle,
|
||||
|
||||
@@ -0,0 +1,309 @@
|
||||
"""Shared graph execution context helpers.
|
||||
|
||||
This module centralizes:
|
||||
- Graph-run shared state (`GraphContext`)
|
||||
- Scoped buffer permission shaping for a node
|
||||
- Per-node accounts prompt resolution
|
||||
- Canonical `NodeContext` construction
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import DataBuffer, NodeContext, NodeProtocol, NodeSpec
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphContext:
|
||||
"""Shared state for one graph execution run."""
|
||||
|
||||
graph: GraphSpec
|
||||
goal: Goal
|
||||
buffer: DataBuffer
|
||||
runtime: Runtime
|
||||
llm: Any # LLMProvider
|
||||
tools: list[Any] # list[Tool]
|
||||
tool_executor: Any # Callable
|
||||
event_bus: Any # GraphScopedEventBus
|
||||
execution_id: str
|
||||
stream_id: str
|
||||
run_id: str
|
||||
storage_path: Any # Path | None
|
||||
runtime_logger: Any = None
|
||||
node_registry: dict[str, NodeProtocol] = field(default_factory=dict)
|
||||
node_spec_registry: dict[str, NodeSpec] = field(default_factory=dict)
|
||||
parallel_config: Any = None # ParallelExecutionConfig | None
|
||||
is_continuous: bool = False
|
||||
continuous_conversation: Any = None
|
||||
cumulative_tools: list[Any] = field(default_factory=list)
|
||||
cumulative_tool_names: set[str] = field(default_factory=set)
|
||||
cumulative_output_keys: list[str] = field(default_factory=list)
|
||||
accounts_prompt: str = ""
|
||||
accounts_data: list[dict] | None = None
|
||||
tool_provider_map: dict[str, str] | None = None
|
||||
skills_catalog_prompt: str = ""
|
||||
protocols_prompt: str = ""
|
||||
skill_dirs: list[str] = field(default_factory=list)
|
||||
context_warn_ratio: float | None = None
|
||||
batch_init_nudge: str | None = None
|
||||
dynamic_tools_provider: Any = None
|
||||
dynamic_prompt_provider: Any = None
|
||||
iteration_metadata_provider: Any = None
|
||||
loop_config: dict[str, Any] = field(default_factory=dict)
|
||||
path: list[str] = field(default_factory=list)
|
||||
node_visit_counts: dict[str, int] = field(default_factory=dict)
|
||||
_path_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
_visits_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
|
||||
|
||||
def build_scoped_buffer(buffer: DataBuffer, node_spec: NodeSpec) -> DataBuffer:
|
||||
"""Create a node-scoped buffer view.
|
||||
|
||||
When permissions are already restricted, auto-include framework-managed
|
||||
`_`-prefixed keys used by the default skill protocols.
|
||||
"""
|
||||
|
||||
read_keys = list(node_spec.input_keys)
|
||||
write_keys = list(node_spec.output_keys)
|
||||
|
||||
if read_keys or write_keys:
|
||||
from framework.skills.defaults import DATA_BUFFER_KEYS as _skill_keys
|
||||
|
||||
existing_underscore = [k for k in buffer._data if k.startswith("_")]
|
||||
extra_keys = set(_skill_keys) | set(existing_underscore)
|
||||
|
||||
for key in extra_keys:
|
||||
if read_keys and key not in read_keys:
|
||||
read_keys.append(key)
|
||||
if write_keys and key not in write_keys:
|
||||
write_keys.append(key)
|
||||
|
||||
return buffer.with_permissions(read_keys=read_keys, write_keys=write_keys)
|
||||
|
||||
|
||||
def build_node_accounts_prompt(
|
||||
*,
|
||||
accounts_prompt: str,
|
||||
accounts_data: list[dict] | None,
|
||||
tool_provider_map: dict[str, str] | None,
|
||||
node_tool_names: list[str] | None,
|
||||
fallback_to_default: bool = False,
|
||||
) -> str:
|
||||
"""Resolve the accounts prompt for one node."""
|
||||
|
||||
resolved = accounts_prompt
|
||||
if accounts_data and tool_provider_map:
|
||||
from framework.graph.prompt_composer import build_accounts_prompt
|
||||
|
||||
filtered = build_accounts_prompt(
|
||||
accounts_data,
|
||||
tool_provider_map,
|
||||
node_tool_names=node_tool_names,
|
||||
)
|
||||
if filtered or not fallback_to_default:
|
||||
resolved = filtered
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_available_tools(
|
||||
*,
|
||||
node_spec: NodeSpec,
|
||||
tools: list[Any],
|
||||
override_tools: list[Any] | None,
|
||||
) -> list[Any]:
|
||||
"""Select tools available to the current node."""
|
||||
|
||||
if override_tools is not None:
|
||||
return list(override_tools)
|
||||
|
||||
if not node_spec.tools:
|
||||
return []
|
||||
|
||||
return [tool for tool in tools if tool.name in node_spec.tools]
|
||||
|
||||
|
||||
def _derive_input_data(buffer: DataBuffer, input_keys: list[str]) -> dict[str, Any]:
|
||||
"""Collect node inputs from the shared buffer."""
|
||||
|
||||
input_data: dict[str, Any] = {}
|
||||
for key in input_keys:
|
||||
value = buffer.read(key)
|
||||
if value is not None:
|
||||
input_data[key] = value
|
||||
return input_data
|
||||
|
||||
|
||||
def build_node_context(
|
||||
*,
|
||||
runtime: Runtime,
|
||||
node_spec: NodeSpec,
|
||||
buffer: DataBuffer,
|
||||
goal: Goal,
|
||||
llm: Any,
|
||||
tools: list[Any],
|
||||
max_tokens: int,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
derive_input_data_from_buffer: bool = False,
|
||||
runtime_logger: Any = None,
|
||||
pause_event: Any = None,
|
||||
continuous_mode: bool = False,
|
||||
inherited_conversation: Any = None,
|
||||
override_tools: list[Any] | None = None,
|
||||
cumulative_output_keys: list[str] | None = None,
|
||||
event_triggered: bool = False,
|
||||
accounts_prompt: str = "",
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
fallback_to_default_accounts_prompt: bool = False,
|
||||
identity_prompt: str = "",
|
||||
narrative: str = "",
|
||||
execution_id: str = "",
|
||||
run_id: str = "",
|
||||
stream_id: str = "",
|
||||
node_registry: dict[str, NodeSpec] | None = None,
|
||||
all_tools: list[Any] | None = None,
|
||||
shared_node_registry: dict[str, NodeProtocol] | None = None,
|
||||
dynamic_tools_provider: Any = None,
|
||||
dynamic_prompt_provider: Any = None,
|
||||
iteration_metadata_provider: Any = None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
default_skill_warn_ratio: float | None = None,
|
||||
default_skill_batch_nudge: str | None = None,
|
||||
) -> NodeContext:
|
||||
"""Build a canonical `NodeContext` for graph execution."""
|
||||
|
||||
available_tools = _resolve_available_tools(
|
||||
node_spec=node_spec,
|
||||
tools=tools,
|
||||
override_tools=override_tools,
|
||||
)
|
||||
scoped_buffer = build_scoped_buffer(buffer, node_spec)
|
||||
node_accounts_prompt = build_node_accounts_prompt(
|
||||
accounts_prompt=accounts_prompt,
|
||||
accounts_data=accounts_data,
|
||||
tool_provider_map=tool_provider_map,
|
||||
node_tool_names=node_spec.tools,
|
||||
fallback_to_default=fallback_to_default_accounts_prompt,
|
||||
)
|
||||
|
||||
resolved_input_data = (
|
||||
_derive_input_data(buffer, node_spec.input_keys)
|
||||
if input_data is None and derive_input_data_from_buffer
|
||||
else dict(input_data or {})
|
||||
)
|
||||
|
||||
return NodeContext(
|
||||
runtime=runtime,
|
||||
node_id=node_spec.id,
|
||||
node_spec=node_spec,
|
||||
buffer=scoped_buffer,
|
||||
input_data=resolved_input_data,
|
||||
llm=llm,
|
||||
available_tools=available_tools,
|
||||
goal_context=goal.to_prompt_context(),
|
||||
goal=goal,
|
||||
max_tokens=max_tokens,
|
||||
runtime_logger=runtime_logger,
|
||||
pause_event=pause_event,
|
||||
continuous_mode=continuous_mode,
|
||||
inherited_conversation=inherited_conversation,
|
||||
cumulative_output_keys=cumulative_output_keys or [],
|
||||
event_triggered=event_triggered,
|
||||
accounts_prompt=node_accounts_prompt,
|
||||
identity_prompt=identity_prompt,
|
||||
narrative=narrative,
|
||||
execution_id=execution_id,
|
||||
run_id=run_id,
|
||||
stream_id=stream_id,
|
||||
node_registry=node_registry or {},
|
||||
all_tools=list(all_tools or tools),
|
||||
shared_node_registry=shared_node_registry or {},
|
||||
dynamic_tools_provider=dynamic_tools_provider,
|
||||
dynamic_prompt_provider=dynamic_prompt_provider,
|
||||
iteration_metadata_provider=iteration_metadata_provider,
|
||||
skills_catalog_prompt=skills_catalog_prompt,
|
||||
protocols_prompt=protocols_prompt,
|
||||
skill_dirs=list(skill_dirs or []),
|
||||
default_skill_warn_ratio=default_skill_warn_ratio,
|
||||
default_skill_batch_nudge=default_skill_batch_nudge,
|
||||
)
|
||||
|
||||
|
||||
def build_node_context_from_graph_context(
|
||||
graph_context: GraphContext,
|
||||
*,
|
||||
node_spec: NodeSpec,
|
||||
pause_event: Any = None,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
derive_input_data_from_buffer: bool = True,
|
||||
override_tools: list[Any] | None = None,
|
||||
inherited_conversation: Any = None,
|
||||
cumulative_output_keys: list[str] | None = None,
|
||||
event_triggered: bool = False,
|
||||
identity_prompt: str | None = None,
|
||||
narrative: str = "",
|
||||
node_registry: dict[str, NodeSpec] | None = None,
|
||||
fallback_to_default_accounts_prompt: bool = True,
|
||||
) -> NodeContext:
|
||||
"""Build `NodeContext` using shared graph-run state."""
|
||||
|
||||
gc = graph_context
|
||||
resolved_override_tools = override_tools
|
||||
if resolved_override_tools is None and gc.is_continuous and gc.cumulative_tools:
|
||||
resolved_override_tools = list(gc.cumulative_tools)
|
||||
|
||||
resolved_inherited_conversation = inherited_conversation
|
||||
if resolved_inherited_conversation is None and gc.is_continuous:
|
||||
resolved_inherited_conversation = gc.continuous_conversation
|
||||
|
||||
resolved_output_keys = cumulative_output_keys
|
||||
if resolved_output_keys is None and gc.is_continuous:
|
||||
resolved_output_keys = list(gc.cumulative_output_keys)
|
||||
|
||||
return build_node_context(
|
||||
runtime=gc.runtime,
|
||||
node_spec=node_spec,
|
||||
buffer=gc.buffer,
|
||||
goal=gc.goal,
|
||||
llm=gc.llm,
|
||||
tools=gc.tools,
|
||||
max_tokens=gc.graph.max_tokens,
|
||||
input_data=input_data,
|
||||
derive_input_data_from_buffer=derive_input_data_from_buffer,
|
||||
runtime_logger=gc.runtime_logger,
|
||||
pause_event=pause_event,
|
||||
continuous_mode=gc.is_continuous,
|
||||
inherited_conversation=resolved_inherited_conversation,
|
||||
override_tools=resolved_override_tools,
|
||||
cumulative_output_keys=resolved_output_keys,
|
||||
event_triggered=event_triggered,
|
||||
accounts_prompt=gc.accounts_prompt,
|
||||
accounts_data=gc.accounts_data,
|
||||
tool_provider_map=gc.tool_provider_map,
|
||||
fallback_to_default_accounts_prompt=fallback_to_default_accounts_prompt,
|
||||
identity_prompt=identity_prompt if identity_prompt is not None else getattr(gc.graph, "identity_prompt", "") or "",
|
||||
narrative=narrative,
|
||||
execution_id=gc.execution_id,
|
||||
run_id=gc.run_id,
|
||||
stream_id=gc.stream_id,
|
||||
node_registry=node_registry or gc.node_spec_registry,
|
||||
all_tools=gc.tools,
|
||||
shared_node_registry=gc.node_registry,
|
||||
dynamic_tools_provider=gc.dynamic_tools_provider,
|
||||
dynamic_prompt_provider=gc.dynamic_prompt_provider,
|
||||
iteration_metadata_provider=gc.iteration_metadata_provider,
|
||||
skills_catalog_prompt=gc.skills_catalog_prompt,
|
||||
protocols_prompt=gc.protocols_prompt,
|
||||
skill_dirs=gc.skill_dirs,
|
||||
default_skill_warn_ratio=gc.context_warn_ratio,
|
||||
default_skill_batch_nudge=gc.batch_init_nudge,
|
||||
)
|
||||
@@ -17,11 +17,11 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.checkpoint_config import CheckpointConfig
|
||||
from framework.graph.context import GraphContext, build_node_context
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.conversation import LEGACY_RUN_ID, get_run_cursor
|
||||
from framework.graph.node import (
|
||||
NodeContext,
|
||||
NodeProtocol,
|
||||
NodeResult,
|
||||
NodeSpec,
|
||||
@@ -707,112 +707,6 @@ class GraphExecutor:
|
||||
ToolRegistry.reset_execution_context(_ctx_token)
|
||||
|
||||
|
||||
def _build_context(
|
||||
self,
|
||||
node_spec: NodeSpec,
|
||||
buffer: DataBuffer,
|
||||
goal: Goal,
|
||||
input_data: dict[str, Any],
|
||||
max_tokens: int = 4096,
|
||||
continuous_mode: bool = False,
|
||||
inherited_conversation: Any = None,
|
||||
override_tools: list | None = None,
|
||||
cumulative_output_keys: list[str] | None = None,
|
||||
event_triggered: bool = False,
|
||||
identity_prompt: str = "",
|
||||
narrative: str = "",
|
||||
node_registry: dict[str, NodeSpec] | None = None,
|
||||
graph: "GraphSpec | None" = None,
|
||||
) -> NodeContext:
|
||||
"""Build execution context for a node."""
|
||||
# Filter tools to those available to this node
|
||||
if override_tools is not None:
|
||||
# Continuous mode: use cumulative tool set
|
||||
available_tools = list(override_tools)
|
||||
else:
|
||||
available_tools = []
|
||||
if node_spec.tools:
|
||||
available_tools = [t for t in self.tools if t.name in node_spec.tools]
|
||||
|
||||
# Create scoped buffer view.
|
||||
# When permissions are restricted (non-empty key lists), auto-include
|
||||
# _-prefixed keys used by default skill protocols so agents can read/write
|
||||
# operational state (e.g. _working_notes, _batch_ledger) regardless of
|
||||
# what the node declares. When key lists are empty (unrestricted), leave
|
||||
# unchanged — empty means "allow all".
|
||||
read_keys = list(node_spec.input_keys)
|
||||
write_keys = list(node_spec.output_keys)
|
||||
# Only extend lists that were already restricted (non-empty).
|
||||
# Empty means "allow all" — adding keys would accidentally
|
||||
# activate the permission check and block legitimate reads/writes.
|
||||
if read_keys or write_keys:
|
||||
from framework.skills.defaults import DATA_BUFFER_KEYS as _skill_keys
|
||||
|
||||
existing_underscore = [k for k in buffer._data if k.startswith("_")]
|
||||
extra_keys = set(_skill_keys) | set(existing_underscore)
|
||||
# Only inject into read_keys when it was already non-empty — an empty
|
||||
# read_keys means "allow all reads" and injecting skill keys would
|
||||
# inadvertently restrict reads to skill keys only.
|
||||
for k in extra_keys:
|
||||
if read_keys and k not in read_keys:
|
||||
read_keys.append(k)
|
||||
if write_keys and k not in write_keys:
|
||||
write_keys.append(k)
|
||||
|
||||
scoped_buffer = buffer.with_permissions(
|
||||
read_keys=read_keys,
|
||||
write_keys=write_keys,
|
||||
)
|
||||
|
||||
# Build per-node accounts prompt (filtered to this node's tools)
|
||||
node_accounts_prompt = self.accounts_prompt
|
||||
if self.accounts_data and self.tool_provider_map:
|
||||
from framework.graph.prompt_composer import build_accounts_prompt
|
||||
|
||||
node_accounts_prompt = build_accounts_prompt(
|
||||
self.accounts_data,
|
||||
self.tool_provider_map,
|
||||
node_tool_names=node_spec.tools,
|
||||
)
|
||||
|
||||
goal_context = goal.to_prompt_context()
|
||||
|
||||
return NodeContext(
|
||||
runtime=self.runtime,
|
||||
node_id=node_spec.id,
|
||||
node_spec=node_spec,
|
||||
buffer=scoped_buffer,
|
||||
input_data=input_data,
|
||||
llm=self.llm,
|
||||
available_tools=available_tools,
|
||||
goal_context=goal_context,
|
||||
goal=goal, # Pass Goal object for LLM-powered routers
|
||||
max_tokens=max_tokens,
|
||||
runtime_logger=self.runtime_logger,
|
||||
pause_event=self._pause_requested, # Pass pause event for granular control
|
||||
continuous_mode=continuous_mode,
|
||||
inherited_conversation=inherited_conversation,
|
||||
cumulative_output_keys=cumulative_output_keys or [],
|
||||
event_triggered=event_triggered,
|
||||
accounts_prompt=node_accounts_prompt,
|
||||
identity_prompt=identity_prompt,
|
||||
narrative=narrative,
|
||||
execution_id=self._execution_id,
|
||||
run_id=self._run_id,
|
||||
stream_id=self._stream_id,
|
||||
node_registry=node_registry or {},
|
||||
all_tools=list(self.tools), # Full catalog for subagent tool resolution
|
||||
shared_node_registry=self.node_registry, # For subagent escalation routing
|
||||
dynamic_tools_provider=self.dynamic_tools_provider,
|
||||
dynamic_prompt_provider=self.dynamic_prompt_provider,
|
||||
iteration_metadata_provider=self.iteration_metadata_provider,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
default_skill_warn_ratio=self.context_warn_ratio,
|
||||
default_skill_batch_nudge=self.batch_init_nudge,
|
||||
)
|
||||
|
||||
VALID_NODE_TYPES = {
|
||||
"event_loop",
|
||||
"gcu",
|
||||
@@ -1103,14 +997,36 @@ class GraphExecutor:
|
||||
branch.retry_count = attempt
|
||||
|
||||
# Build context for this branch
|
||||
ctx = self._build_context(
|
||||
node_spec,
|
||||
buffer,
|
||||
goal,
|
||||
mapped,
|
||||
graph.max_tokens,
|
||||
ctx = build_node_context(
|
||||
runtime=self.runtime,
|
||||
node_spec=node_spec,
|
||||
buffer=buffer,
|
||||
goal=goal,
|
||||
llm=self.llm,
|
||||
tools=self.tools,
|
||||
max_tokens=graph.max_tokens,
|
||||
input_data=mapped,
|
||||
runtime_logger=self.runtime_logger,
|
||||
pause_event=self._pause_requested,
|
||||
accounts_prompt=self.accounts_prompt,
|
||||
accounts_data=self.accounts_data,
|
||||
tool_provider_map=self.tool_provider_map,
|
||||
identity_prompt="",
|
||||
narrative="",
|
||||
execution_id=self._execution_id,
|
||||
run_id=self._run_id,
|
||||
stream_id=self._stream_id,
|
||||
node_registry=node_registry,
|
||||
graph=graph,
|
||||
all_tools=self.tools,
|
||||
shared_node_registry=self.node_registry,
|
||||
dynamic_tools_provider=self.dynamic_tools_provider,
|
||||
dynamic_prompt_provider=self.dynamic_prompt_provider,
|
||||
iteration_metadata_provider=self.iteration_metadata_provider,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
default_skill_warn_ratio=self.context_warn_ratio,
|
||||
default_skill_batch_nudge=self.batch_init_nudge,
|
||||
)
|
||||
node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
|
||||
|
||||
@@ -1353,7 +1269,6 @@ class GraphExecutor:
|
||||
from framework.graph.worker_agent import (
|
||||
Activation,
|
||||
FanOutTag,
|
||||
GraphContext,
|
||||
WorkerAgent,
|
||||
WorkerCompletion,
|
||||
WorkerLifecycle,
|
||||
@@ -1399,8 +1314,9 @@ class GraphExecutor:
|
||||
for node_spec in graph.nodes:
|
||||
workers[node_spec.id] = WorkerAgent(node_spec=node_spec, graph_context=gc)
|
||||
|
||||
# Identify entry workers (zero incoming edges) and terminal workers
|
||||
entry_worker_ids = [wid for wid, w in workers.items() if w.is_entry]
|
||||
# Identify entry workers (graph entry node, not based on edge count)
|
||||
# A node can be the entry point AND have incoming feedback edges.
|
||||
entry_worker_ids = [graph.entry_node]
|
||||
terminal_worker_ids = set(graph.terminal_nodes or [])
|
||||
|
||||
self.logger.info(
|
||||
@@ -1457,6 +1373,9 @@ class GraphExecutor:
|
||||
|
||||
def _check_graph_done() -> bool:
|
||||
"""Check whether active graph work has reached a terminal state."""
|
||||
# Step-limit guard (equivalent to old while-loop's max_steps)
|
||||
if len(gc.path) >= graph.max_steps:
|
||||
return True
|
||||
if not terminal_worker_ids:
|
||||
# No terminals: check if all workers are done
|
||||
return all(
|
||||
@@ -1774,6 +1693,15 @@ class GraphExecutor:
|
||||
error = last_result.error
|
||||
elif task_error is not None:
|
||||
error = str(task_error)
|
||||
|
||||
# Route ON_FAILURE activations
|
||||
outgoing_activations = worker._last_activations
|
||||
if outgoing_activations:
|
||||
for activation in outgoing_activations:
|
||||
_route_activation(
|
||||
activation, workers, pending_tasks,
|
||||
has_event_subscription=False,
|
||||
)
|
||||
elif task_error is not None:
|
||||
error = str(task_error)
|
||||
else:
|
||||
@@ -1800,6 +1728,11 @@ class GraphExecutor:
|
||||
|
||||
# Quality assessment
|
||||
has_failures = bool(failed_workers) or execution_error is not None
|
||||
# If all terminal workers completed successfully, intermediate failures
|
||||
# (handled by ON_FAILURE edges) don't count against overall success.
|
||||
if terminal_worker_ids and completed_terminals >= terminal_worker_ids:
|
||||
terminal_failures = terminal_worker_ids & set(failed_workers.keys())
|
||||
has_failures = bool(terminal_failures) or execution_error is not None
|
||||
exec_quality = "failed" if has_failures else "clean"
|
||||
|
||||
saved_buffer = buffer.read_all()
|
||||
|
||||
@@ -19,17 +19,15 @@ from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.context import GraphContext, build_node_context_from_graph_context
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec
|
||||
from framework.graph.node import (
|
||||
DataBuffer,
|
||||
NodeContext,
|
||||
NodeProtocol,
|
||||
NodeResult,
|
||||
NodeSpec,
|
||||
)
|
||||
from framework.graph.validator import OutputValidator
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -106,58 +104,6 @@ class RetryState:
|
||||
is_event_loop: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphContext:
|
||||
"""Shared state for one graph execution run.
|
||||
|
||||
Consolidates the 20+ constructor params on ``GraphExecutor.__init__``
|
||||
into a single object shared by reference across all workers.
|
||||
"""
|
||||
|
||||
graph: GraphSpec
|
||||
goal: Goal
|
||||
buffer: DataBuffer
|
||||
runtime: Runtime
|
||||
llm: Any # LLMProvider
|
||||
tools: list[Any] # list[Tool]
|
||||
tool_executor: Any # Callable
|
||||
event_bus: Any # GraphScopedEventBus
|
||||
execution_id: str
|
||||
stream_id: str
|
||||
run_id: str
|
||||
storage_path: Any # Path | None
|
||||
runtime_logger: Any = None
|
||||
node_registry: dict[str, NodeProtocol] = field(default_factory=dict)
|
||||
node_spec_registry: dict[str, NodeSpec] = field(default_factory=dict)
|
||||
# Parallel execution config
|
||||
parallel_config: Any = None # ParallelExecutionConfig | None
|
||||
# Continuous mode
|
||||
is_continuous: bool = False
|
||||
continuous_conversation: Any = None
|
||||
cumulative_tools: list[Any] = field(default_factory=list)
|
||||
cumulative_tool_names: set[str] = field(default_factory=set)
|
||||
cumulative_output_keys: list[str] = field(default_factory=list)
|
||||
# Accounts / skills / dynamic providers
|
||||
accounts_prompt: str = ""
|
||||
accounts_data: list[dict] | None = None
|
||||
tool_provider_map: dict[str, str] | None = None
|
||||
skills_catalog_prompt: str = ""
|
||||
protocols_prompt: str = ""
|
||||
skill_dirs: list[str] = field(default_factory=list)
|
||||
context_warn_ratio: float | None = None
|
||||
batch_init_nudge: str | None = None
|
||||
dynamic_tools_provider: Any = None
|
||||
dynamic_prompt_provider: Any = None
|
||||
iteration_metadata_provider: Any = None
|
||||
# Loop config for EventLoopNode creation
|
||||
loop_config: dict[str, Any] = field(default_factory=dict)
|
||||
# Thread-safe execution state
|
||||
path: list[str] = field(default_factory=list)
|
||||
node_visit_counts: dict[str, int] = field(default_factory=dict)
|
||||
_path_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
_visits_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkerAgent
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -379,9 +325,12 @@ class WorkerAgent:
|
||||
)
|
||||
await self._publish_completion(completion)
|
||||
else:
|
||||
# Evaluate outgoing edges even on failure (ON_FAILURE edges)
|
||||
activations = await self._evaluate_outgoing_edges(result)
|
||||
|
||||
self.lifecycle = WorkerLifecycle.FAILED
|
||||
self._last_result = result
|
||||
self._last_activations = []
|
||||
self._last_activations = activations
|
||||
await self._publish_failure(result.error or "Unknown error")
|
||||
except Exception as exc:
|
||||
error = str(exc) or type(exc).__name__
|
||||
@@ -396,9 +345,16 @@ class WorkerAgent:
|
||||
) -> NodeResult:
|
||||
"""Execute node with exponential backoff retry."""
|
||||
gc = self._gc
|
||||
max_retries = 0 if self.retry_state.is_event_loop else self.retry_state.max_retries
|
||||
# Only skip retries for actual EventLoopNode instances (they handle
|
||||
# retries internally). Custom NodeProtocol impls registered via
|
||||
# register_node should be retried by the executor.
|
||||
from framework.graph.event_loop_node import EventLoopNode as _ELN
|
||||
if isinstance(node_impl, _ELN):
|
||||
max_retries = 0
|
||||
else:
|
||||
max_retries = self.retry_state.max_retries
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
for attempt in range(max(1, max_retries)):
|
||||
# Check pause
|
||||
await self._run_gate.wait()
|
||||
|
||||
@@ -413,13 +369,13 @@ class WorkerAgent:
|
||||
return result
|
||||
|
||||
# Failure
|
||||
if attempt < max_retries:
|
||||
if attempt + 1 < max(1, max_retries):
|
||||
delay = 1.0 * (2**attempt)
|
||||
logger.warning(
|
||||
"Worker %s failed (attempt %d/%d), retrying in %.1fs: %s",
|
||||
self.node_spec.id,
|
||||
attempt + 1,
|
||||
max_retries + 1,
|
||||
max_retries,
|
||||
delay,
|
||||
result.error,
|
||||
)
|
||||
@@ -435,24 +391,33 @@ class WorkerAgent:
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
return result
|
||||
return NodeResult(
|
||||
success=False,
|
||||
error=f"failed after {attempt + 1} attempts: {result.error}",
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
if attempt < max_retries:
|
||||
if attempt < max(1, max_retries) - 1:
|
||||
delay = 1.0 * (2**attempt)
|
||||
logger.warning(
|
||||
"Worker %s raised %s (attempt %d/%d), retrying in %.1fs",
|
||||
self.node_spec.id,
|
||||
type(exc).__name__,
|
||||
attempt + 1,
|
||||
max_retries + 1,
|
||||
max(1, max_retries),
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
return NodeResult(success=False, error=str(exc))
|
||||
return NodeResult(
|
||||
success=False,
|
||||
error=f"failed after {attempt + 1} attempts: {exc}",
|
||||
)
|
||||
|
||||
return NodeResult(success=False, error="Max retries exceeded")
|
||||
return NodeResult(
|
||||
success=False,
|
||||
error=f"failed after {max(1, max_retries)} attempts",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Edge evaluation (source-side)
|
||||
@@ -624,88 +589,10 @@ class WorkerAgent:
|
||||
|
||||
def _build_node_context(self) -> NodeContext:
|
||||
"""Build NodeContext for this worker's execution."""
|
||||
gc = self._gc
|
||||
node_spec = self.node_spec
|
||||
|
||||
# Filter tools
|
||||
if gc.is_continuous and gc.cumulative_tools:
|
||||
available_tools = list(gc.cumulative_tools)
|
||||
else:
|
||||
available_tools = []
|
||||
if node_spec.tools:
|
||||
available_tools = [t for t in gc.tools if t.name in node_spec.tools]
|
||||
|
||||
# Scoped buffer
|
||||
read_keys = list(node_spec.input_keys)
|
||||
write_keys = list(node_spec.output_keys)
|
||||
if read_keys or write_keys:
|
||||
from framework.skills.defaults import DATA_BUFFER_KEYS as _skill_keys
|
||||
|
||||
existing_underscore = [k for k in gc.buffer._data if k.startswith("_")]
|
||||
extra_keys = set(_skill_keys) | set(existing_underscore)
|
||||
for k in extra_keys:
|
||||
if read_keys and k not in read_keys:
|
||||
read_keys.append(k)
|
||||
if write_keys and k not in write_keys:
|
||||
write_keys.append(k)
|
||||
|
||||
scoped_buffer = gc.buffer.with_permissions(read_keys=read_keys, write_keys=write_keys)
|
||||
|
||||
# Per-node accounts prompt
|
||||
node_accounts_prompt = gc.accounts_prompt
|
||||
if gc.accounts_data and gc.tool_provider_map:
|
||||
from framework.graph.prompt_composer import build_accounts_prompt
|
||||
|
||||
node_accounts_prompt = build_accounts_prompt(
|
||||
gc.accounts_data,
|
||||
gc.tool_provider_map,
|
||||
node_tool_names=node_spec.tools,
|
||||
) or gc.accounts_prompt
|
||||
|
||||
# Input data from buffer
|
||||
input_data: dict[str, Any] = {}
|
||||
for key in node_spec.input_keys:
|
||||
val = gc.buffer.read(key)
|
||||
if val is not None:
|
||||
input_data[key] = val
|
||||
|
||||
# Continuous mode: thread conversation
|
||||
inherited_conversation = None
|
||||
if gc.is_continuous and gc.continuous_conversation:
|
||||
inherited_conversation = gc.continuous_conversation
|
||||
|
||||
return NodeContext(
|
||||
runtime=gc.runtime,
|
||||
node_id=node_spec.id,
|
||||
node_spec=node_spec,
|
||||
buffer=scoped_buffer,
|
||||
input_data=input_data,
|
||||
llm=gc.llm,
|
||||
available_tools=available_tools,
|
||||
goal_context=gc.goal.to_prompt_context(),
|
||||
goal=gc.goal,
|
||||
max_tokens=gc.graph.max_tokens,
|
||||
runtime_logger=gc.runtime_logger,
|
||||
return build_node_context_from_graph_context(
|
||||
self._gc,
|
||||
node_spec=self.node_spec,
|
||||
pause_event=self._pause_requested,
|
||||
continuous_mode=gc.is_continuous,
|
||||
inherited_conversation=inherited_conversation,
|
||||
cumulative_output_keys=list(gc.cumulative_output_keys) if gc.is_continuous else [],
|
||||
accounts_prompt=node_accounts_prompt,
|
||||
identity_prompt=getattr(gc.graph, "identity_prompt", "") or "",
|
||||
execution_id=gc.execution_id,
|
||||
run_id=gc.run_id,
|
||||
stream_id=gc.stream_id,
|
||||
node_registry=gc.node_spec_registry,
|
||||
all_tools=list(gc.tools),
|
||||
shared_node_registry=gc.node_registry,
|
||||
dynamic_tools_provider=gc.dynamic_tools_provider,
|
||||
dynamic_prompt_provider=gc.dynamic_prompt_provider,
|
||||
iteration_metadata_provider=gc.iteration_metadata_provider,
|
||||
skills_catalog_prompt=gc.skills_catalog_prompt,
|
||||
protocols_prompt=gc.protocols_prompt,
|
||||
skill_dirs=list(gc.skill_dirs),
|
||||
default_skill_warn_ratio=gc.context_warn_ratio,
|
||||
default_skill_batch_nudge=gc.batch_init_nudge,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -398,25 +398,39 @@ async def _smoke_test_provider_async(provider: dict, timeout_seconds: float = 25
|
||||
"branch execution did not reach the expected terminal path: "
|
||||
f"{result.path}"
|
||||
)
|
||||
if result.output.get("result") != "BRANCH_OK":
|
||||
if not result.output.get("result"):
|
||||
raise RuntimeError(
|
||||
"branch execution completed but did not produce result='BRANCH_OK'"
|
||||
"branch execution reached the expected terminal path but did not "
|
||||
f"produce a non-empty result output: path={result.path} "
|
||||
f"output={result.output}"
|
||||
)
|
||||
|
||||
current_step = "plain completion"
|
||||
current_timeout = timeout_seconds
|
||||
worker_timeout = max(
|
||||
timeout_seconds,
|
||||
float(os.environ.get("DUMMY_AGENT_SMOKE_WORKER_TIMEOUT_SECS", "30")),
|
||||
)
|
||||
branch_timeout = max(
|
||||
timeout_seconds,
|
||||
float(os.environ.get("DUMMY_AGENT_SMOKE_BRANCH_TIMEOUT_SECS", "60")),
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(_run_plain_completion(), timeout=timeout_seconds)
|
||||
await asyncio.wait_for(_run_plain_completion(), timeout=current_timeout)
|
||||
current_step = "tool calling"
|
||||
await asyncio.wait_for(_run_tool_completion(), timeout=timeout_seconds)
|
||||
current_timeout = timeout_seconds
|
||||
await asyncio.wait_for(_run_tool_completion(), timeout=current_timeout)
|
||||
current_step = "single-node worker execution"
|
||||
await asyncio.wait_for(_run_worker_execution(), timeout=timeout_seconds)
|
||||
current_timeout = worker_timeout
|
||||
await asyncio.wait_for(_run_worker_execution(), timeout=current_timeout)
|
||||
current_step = "branch worker execution"
|
||||
await asyncio.wait_for(_run_branch_execution(), timeout=timeout_seconds)
|
||||
current_timeout = branch_timeout
|
||||
await asyncio.wait_for(_run_branch_execution(), timeout=current_timeout)
|
||||
except TimeoutError as exc:
|
||||
raise RuntimeError(
|
||||
f"provider smoke test timed out during {current_step} "
|
||||
f"after {timeout_seconds:.0f}s"
|
||||
f"after {current_timeout:.0f}s"
|
||||
) from exc
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,190 @@
|
||||
"""Unit tests for shared graph context helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from framework.graph.context import (
|
||||
GraphContext,
|
||||
build_node_accounts_prompt,
|
||||
build_node_context,
|
||||
build_node_context_from_graph_context,
|
||||
build_scoped_buffer,
|
||||
)
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import DataBuffer, NodeSpec
|
||||
from framework.llm.provider import Tool
|
||||
from framework.skills.defaults import DATA_BUFFER_KEYS
|
||||
|
||||
|
||||
class DummyRuntime:
|
||||
execution_id = ""
|
||||
|
||||
|
||||
def _make_tool(name: str) -> Tool:
|
||||
return Tool(
|
||||
name=name,
|
||||
description=f"Tool {name}",
|
||||
parameters={"type": "object", "properties": {}},
|
||||
)
|
||||
|
||||
|
||||
def _make_goal() -> Goal:
|
||||
return Goal(id="goal-1", name="Goal", description="Test goal")
|
||||
|
||||
|
||||
def _make_graph(node_spec: NodeSpec) -> GraphSpec:
|
||||
return GraphSpec(
|
||||
id="graph-1",
|
||||
goal_id="goal-1",
|
||||
nodes=[node_spec],
|
||||
edges=[],
|
||||
entry_node=node_spec.id,
|
||||
terminal_nodes=[node_spec.id],
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
|
||||
def test_build_scoped_buffer_includes_skill_and_existing_internal_keys():
|
||||
buffer = DataBuffer()
|
||||
buffer.write("task", "draft")
|
||||
buffer.write("_worker_state", "active")
|
||||
|
||||
node_spec = NodeSpec(
|
||||
id="writer",
|
||||
name="Writer",
|
||||
description="Writes output",
|
||||
node_type="event_loop",
|
||||
input_keys=["task"],
|
||||
output_keys=["result"],
|
||||
)
|
||||
|
||||
scoped = build_scoped_buffer(buffer, node_spec)
|
||||
|
||||
assert "task" in scoped._allowed_read
|
||||
assert "result" in scoped._allowed_write
|
||||
assert "_worker_state" in scoped._allowed_read
|
||||
assert "_worker_state" in scoped._allowed_write
|
||||
for key in DATA_BUFFER_KEYS:
|
||||
assert key in scoped._allowed_read
|
||||
assert key in scoped._allowed_write
|
||||
|
||||
|
||||
def test_build_scoped_buffer_keeps_empty_permissions_unrestricted():
|
||||
buffer = DataBuffer()
|
||||
buffer.write("task", "draft")
|
||||
buffer.write("_worker_state", "active")
|
||||
|
||||
node_spec = NodeSpec(
|
||||
id="reader",
|
||||
name="Reader",
|
||||
description="Reads everything",
|
||||
node_type="event_loop",
|
||||
input_keys=[],
|
||||
output_keys=[],
|
||||
)
|
||||
|
||||
scoped = build_scoped_buffer(buffer, node_spec)
|
||||
|
||||
assert scoped._allowed_read == set()
|
||||
assert scoped._allowed_write == set()
|
||||
assert scoped.read_all()["task"] == "draft"
|
||||
assert scoped.read_all()["_worker_state"] == "active"
|
||||
|
||||
|
||||
def test_accounts_prompt_falls_back_when_filtered_prompt_is_empty():
|
||||
prompt = build_node_accounts_prompt(
|
||||
accounts_prompt="DEFAULT_ACCOUNTS",
|
||||
accounts_data=[{"provider": "google", "alias": "personal", "identity": {}}],
|
||||
tool_provider_map={"gmail_list_messages": "google"},
|
||||
node_tool_names=["slack_post_message"],
|
||||
fallback_to_default=True,
|
||||
)
|
||||
|
||||
assert prompt == "DEFAULT_ACCOUNTS"
|
||||
|
||||
|
||||
def test_build_node_context_from_graph_context_preserves_continuous_state():
|
||||
node_spec = NodeSpec(
|
||||
id="writer",
|
||||
name="Writer",
|
||||
description="Writes output",
|
||||
node_type="event_loop",
|
||||
input_keys=["task"],
|
||||
output_keys=["draft"],
|
||||
tools=["save_data"],
|
||||
)
|
||||
buffer = DataBuffer()
|
||||
buffer.write("task", "write the draft")
|
||||
conversation = object()
|
||||
save_data = _make_tool("save_data")
|
||||
fallback_tool = _make_tool("web_search")
|
||||
graph = _make_graph(node_spec)
|
||||
|
||||
graph_context = GraphContext(
|
||||
graph=graph,
|
||||
goal=_make_goal(),
|
||||
buffer=buffer,
|
||||
runtime=DummyRuntime(),
|
||||
llm=None,
|
||||
tools=[save_data, fallback_tool],
|
||||
tool_executor=None,
|
||||
event_bus=None,
|
||||
execution_id="exec-1",
|
||||
stream_id="stream-1",
|
||||
run_id="run-1",
|
||||
storage_path=None,
|
||||
is_continuous=True,
|
||||
continuous_conversation=conversation,
|
||||
cumulative_tools=[fallback_tool],
|
||||
cumulative_output_keys=["outline", "draft"],
|
||||
accounts_prompt="ACCOUNTS",
|
||||
skills_catalog_prompt="SKILLS",
|
||||
protocols_prompt="PROTOCOLS",
|
||||
)
|
||||
|
||||
ctx = build_node_context_from_graph_context(
|
||||
graph_context,
|
||||
node_spec=node_spec,
|
||||
pause_event="pause-signal",
|
||||
)
|
||||
|
||||
assert ctx.input_data == {"task": "write the draft"}
|
||||
assert ctx.inherited_conversation is conversation
|
||||
assert ctx.cumulative_output_keys == ["outline", "draft"]
|
||||
assert [tool.name for tool in ctx.available_tools] == ["web_search"]
|
||||
assert ctx.pause_event == "pause-signal"
|
||||
assert ctx.accounts_prompt == "ACCOUNTS"
|
||||
assert ctx.skills_catalog_prompt == "SKILLS"
|
||||
assert ctx.protocols_prompt == "PROTOCOLS"
|
||||
|
||||
|
||||
def test_build_node_context_uses_override_tools_for_legacy_executor_path():
|
||||
node_spec = NodeSpec(
|
||||
id="branch",
|
||||
name="Branch",
|
||||
description="Legacy branch execution",
|
||||
node_type="event_loop",
|
||||
input_keys=["task"],
|
||||
output_keys=["result"],
|
||||
tools=["save_data"],
|
||||
)
|
||||
buffer = DataBuffer()
|
||||
save_data = _make_tool("save_data")
|
||||
web_search = _make_tool("web_search")
|
||||
|
||||
ctx = build_node_context(
|
||||
runtime=DummyRuntime(),
|
||||
node_spec=node_spec,
|
||||
buffer=buffer,
|
||||
goal=_make_goal(),
|
||||
llm=None,
|
||||
tools=[save_data],
|
||||
max_tokens=1024,
|
||||
input_data={"task": "run branch"},
|
||||
override_tools=[save_data, web_search],
|
||||
node_registry={"branch": node_spec},
|
||||
)
|
||||
|
||||
assert ctx.input_data == {"task": "run branch"}
|
||||
assert [tool.name for tool in ctx.available_tools] == ["save_data", "web_search"]
|
||||
assert ctx.node_registry == {"branch": node_spec}
|
||||
Reference in New Issue
Block a user