feat: consolidate context building

This commit is contained in:
Richard Tang
2026-04-02 15:54:16 -07:00
parent e1911b3684
commit c5052ade34
6 changed files with 605 additions and 272 deletions
+1 -1
View File
@@ -1,5 +1,6 @@
"""Graph structures: Goals, Nodes, Edges, and Execution."""
from framework.graph.context import GraphContext
from framework.graph.context_handoff import ContextHandoff, HandoffContext
from framework.graph.conversation import ConversationStore, Message, NodeConversation
from framework.graph.edge import DEFAULT_MAX_TOKENS, EdgeCondition, EdgeSpec, GraphSpec
@@ -17,7 +18,6 @@ from framework.graph.worker_agent import (
Activation,
FanOutTag,
FanOutTracker,
GraphContext,
WorkerAgent,
WorkerCompletion,
WorkerLifecycle,
+309
View File
@@ -0,0 +1,309 @@
"""Shared graph execution context helpers.
This module centralizes:
- Graph-run shared state (`GraphContext`)
- Scoped buffer permission shaping for a node
- Per-node accounts prompt resolution
- Canonical `NodeContext` construction
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import Any
from framework.graph.edge import GraphSpec
from framework.graph.goal import Goal
from framework.graph.node import DataBuffer, NodeContext, NodeProtocol, NodeSpec
from framework.runtime.core import Runtime
@dataclass
class GraphContext:
    """Shared state for one graph execution run.

    One instance is shared by reference across all workers of a run; the
    trailing asyncio locks guard the mutable ``path`` and
    ``node_visit_counts`` state when nodes execute in parallel.
    """
    # --- Core run wiring ---
    graph: GraphSpec
    goal: Goal
    buffer: DataBuffer
    runtime: Runtime
    llm: Any # LLMProvider
    tools: list[Any] # list[Tool]
    tool_executor: Any # Callable
    event_bus: Any # GraphScopedEventBus
    execution_id: str
    stream_id: str
    run_id: str
    storage_path: Any # Path | None
    runtime_logger: Any = None
    node_registry: dict[str, NodeProtocol] = field(default_factory=dict)
    node_spec_registry: dict[str, NodeSpec] = field(default_factory=dict)
    # --- Parallel execution configuration ---
    parallel_config: Any = None # ParallelExecutionConfig | None
    # --- Continuous mode (threaded conversation + cumulative tool/key sets) ---
    is_continuous: bool = False
    continuous_conversation: Any = None
    cumulative_tools: list[Any] = field(default_factory=list)
    cumulative_tool_names: set[str] = field(default_factory=set)
    cumulative_output_keys: list[str] = field(default_factory=list)
    # --- Accounts / skills / dynamic providers ---
    accounts_prompt: str = ""
    accounts_data: list[dict] | None = None
    tool_provider_map: dict[str, str] | None = None
    skills_catalog_prompt: str = ""
    protocols_prompt: str = ""
    skill_dirs: list[str] = field(default_factory=list)
    context_warn_ratio: float | None = None
    batch_init_nudge: str | None = None
    dynamic_tools_provider: Any = None
    dynamic_prompt_provider: Any = None
    iteration_metadata_provider: Any = None
    # --- Loop config for EventLoopNode creation ---
    loop_config: dict[str, Any] = field(default_factory=dict)
    # --- Mutable execution state, guarded by the locks below ---
    path: list[str] = field(default_factory=list)
    node_visit_counts: dict[str, int] = field(default_factory=dict)
    _path_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    _visits_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
def build_scoped_buffer(buffer: DataBuffer, node_spec: NodeSpec) -> DataBuffer:
    """Create a node-scoped buffer view.

    Empty key lists mean "allow all", so they are returned untouched —
    injecting keys would accidentally activate the permission check. When
    the node declares restricted permissions, framework-managed
    ``_``-prefixed keys (default skill protocol keys plus any underscore
    keys already present in the buffer) are appended so agents keep access
    to operational state.
    """
    allowed_reads = list(node_spec.input_keys)
    allowed_writes = list(node_spec.output_keys)
    if not allowed_reads and not allowed_writes:
        # Unrestricted node: leave permissions empty (= allow everything).
        return buffer.with_permissions(read_keys=allowed_reads, write_keys=allowed_writes)
    from framework.skills.defaults import DATA_BUFFER_KEYS as _skill_keys
    internal_keys = {k for k in buffer._data if k.startswith("_")}
    for managed_key in set(_skill_keys) | internal_keys:
        # Only extend lists that were already restricted (non-empty).
        if allowed_reads and managed_key not in allowed_reads:
            allowed_reads.append(managed_key)
        if allowed_writes and managed_key not in allowed_writes:
            allowed_writes.append(managed_key)
    return buffer.with_permissions(read_keys=allowed_reads, write_keys=allowed_writes)
def build_node_accounts_prompt(
*,
accounts_prompt: str,
accounts_data: list[dict] | None,
tool_provider_map: dict[str, str] | None,
node_tool_names: list[str] | None,
fallback_to_default: bool = False,
) -> str:
"""Resolve the accounts prompt for one node."""
resolved = accounts_prompt
if accounts_data and tool_provider_map:
from framework.graph.prompt_composer import build_accounts_prompt
filtered = build_accounts_prompt(
accounts_data,
tool_provider_map,
node_tool_names=node_tool_names,
)
if filtered or not fallback_to_default:
resolved = filtered
return resolved
def _resolve_available_tools(
    *,
    node_spec: NodeSpec,
    tools: list[Any],
    override_tools: list[Any] | None,
) -> list[Any]:
    """Select tools available to the current node.

    An explicit override (e.g. continuous-mode cumulative tool set) wins
    outright; otherwise the run-wide catalog is filtered down to the names
    the node declares. A node declaring no tools gets none.
    """
    if override_tools is not None:
        return list(override_tools)
    declared_names = node_spec.tools
    if not declared_names:
        return []
    return [candidate for candidate in tools if candidate.name in declared_names]
def _derive_input_data(buffer: DataBuffer, input_keys: list[str]) -> dict[str, Any]:
    """Collect node inputs from the shared buffer.

    Keys whose buffered value is ``None`` are omitted entirely rather than
    passed through as explicit ``None`` entries.
    """
    fetched = {key: buffer.read(key) for key in input_keys}
    return {key: value for key, value in fetched.items() if value is not None}
def build_node_context(
    *,
    runtime: Runtime,
    node_spec: NodeSpec,
    buffer: DataBuffer,
    goal: Goal,
    llm: Any,
    tools: list[Any],
    max_tokens: int,
    input_data: dict[str, Any] | None = None,
    derive_input_data_from_buffer: bool = False,
    runtime_logger: Any = None,
    pause_event: Any = None,
    continuous_mode: bool = False,
    inherited_conversation: Any = None,
    override_tools: list[Any] | None = None,
    cumulative_output_keys: list[str] | None = None,
    event_triggered: bool = False,
    accounts_prompt: str = "",
    accounts_data: list[dict] | None = None,
    tool_provider_map: dict[str, str] | None = None,
    fallback_to_default_accounts_prompt: bool = False,
    identity_prompt: str = "",
    narrative: str = "",
    execution_id: str = "",
    run_id: str = "",
    stream_id: str = "",
    node_registry: dict[str, NodeSpec] | None = None,
    all_tools: list[Any] | None = None,
    shared_node_registry: dict[str, NodeProtocol] | None = None,
    dynamic_tools_provider: Any = None,
    dynamic_prompt_provider: Any = None,
    iteration_metadata_provider: Any = None,
    skills_catalog_prompt: str = "",
    protocols_prompt: str = "",
    skill_dirs: list[str] | None = None,
    default_skill_warn_ratio: float | None = None,
    default_skill_batch_nudge: str | None = None,
) -> NodeContext:
    """Build a canonical `NodeContext` for graph execution.

    Resolves the node's tool selection, scoped buffer permissions, and
    per-node accounts prompt, determines the input payload (explicit dict
    or derived from the buffer), then assembles the context object.
    """
    selected_tools = _resolve_available_tools(
        node_spec=node_spec,
        tools=tools,
        override_tools=override_tools,
    )
    permissioned_buffer = build_scoped_buffer(buffer, node_spec)
    resolved_accounts_prompt = build_node_accounts_prompt(
        accounts_prompt=accounts_prompt,
        accounts_data=accounts_data,
        tool_provider_map=tool_provider_map,
        node_tool_names=node_spec.tools,
        fallback_to_default=fallback_to_default_accounts_prompt,
    )
    if input_data is None and derive_input_data_from_buffer:
        # No explicit payload: pull declared inputs from the shared buffer.
        payload = _derive_input_data(buffer, node_spec.input_keys)
    else:
        payload = dict(input_data or {})
    return NodeContext(
        runtime=runtime,
        node_id=node_spec.id,
        node_spec=node_spec,
        buffer=permissioned_buffer,
        input_data=payload,
        llm=llm,
        available_tools=selected_tools,
        goal_context=goal.to_prompt_context(),
        goal=goal,
        max_tokens=max_tokens,
        runtime_logger=runtime_logger,
        pause_event=pause_event,
        continuous_mode=continuous_mode,
        inherited_conversation=inherited_conversation,
        cumulative_output_keys=cumulative_output_keys or [],
        event_triggered=event_triggered,
        accounts_prompt=resolved_accounts_prompt,
        identity_prompt=identity_prompt,
        narrative=narrative,
        execution_id=execution_id,
        run_id=run_id,
        stream_id=stream_id,
        node_registry=node_registry or {},
        all_tools=list(all_tools or tools),
        shared_node_registry=shared_node_registry or {},
        dynamic_tools_provider=dynamic_tools_provider,
        dynamic_prompt_provider=dynamic_prompt_provider,
        iteration_metadata_provider=iteration_metadata_provider,
        skills_catalog_prompt=skills_catalog_prompt,
        protocols_prompt=protocols_prompt,
        skill_dirs=list(skill_dirs or []),
        default_skill_warn_ratio=default_skill_warn_ratio,
        default_skill_batch_nudge=default_skill_batch_nudge,
    )
def build_node_context_from_graph_context(
    graph_context: GraphContext,
    *,
    node_spec: NodeSpec,
    pause_event: Any = None,
    input_data: dict[str, Any] | None = None,
    derive_input_data_from_buffer: bool = True,
    override_tools: list[Any] | None = None,
    inherited_conversation: Any = None,
    cumulative_output_keys: list[str] | None = None,
    event_triggered: bool = False,
    identity_prompt: str | None = None,
    narrative: str = "",
    node_registry: dict[str, NodeSpec] | None = None,
    fallback_to_default_accounts_prompt: bool = True,
) -> NodeContext:
    """Build `NodeContext` using shared graph-run state.

    Fills in continuous-mode defaults (cumulative tool set, threaded
    conversation, cumulative output keys) and the graph identity prompt
    from the context whenever the caller did not supply explicit values,
    then delegates to :func:`build_node_context`.
    """
    gc = graph_context
    if override_tools is None and gc.is_continuous and gc.cumulative_tools:
        override_tools = list(gc.cumulative_tools)
    if inherited_conversation is None and gc.is_continuous:
        inherited_conversation = gc.continuous_conversation
    if cumulative_output_keys is None and gc.is_continuous:
        cumulative_output_keys = list(gc.cumulative_output_keys)
    if identity_prompt is None:
        identity_prompt = getattr(gc.graph, "identity_prompt", "") or ""
    return build_node_context(
        runtime=gc.runtime,
        node_spec=node_spec,
        buffer=gc.buffer,
        goal=gc.goal,
        llm=gc.llm,
        tools=gc.tools,
        max_tokens=gc.graph.max_tokens,
        input_data=input_data,
        derive_input_data_from_buffer=derive_input_data_from_buffer,
        runtime_logger=gc.runtime_logger,
        pause_event=pause_event,
        continuous_mode=gc.is_continuous,
        inherited_conversation=inherited_conversation,
        override_tools=override_tools,
        cumulative_output_keys=cumulative_output_keys,
        event_triggered=event_triggered,
        accounts_prompt=gc.accounts_prompt,
        accounts_data=gc.accounts_data,
        tool_provider_map=gc.tool_provider_map,
        fallback_to_default_accounts_prompt=fallback_to_default_accounts_prompt,
        identity_prompt=identity_prompt,
        narrative=narrative,
        execution_id=gc.execution_id,
        run_id=gc.run_id,
        stream_id=gc.stream_id,
        node_registry=node_registry or gc.node_spec_registry,
        all_tools=gc.tools,
        shared_node_registry=gc.node_registry,
        dynamic_tools_provider=gc.dynamic_tools_provider,
        dynamic_prompt_provider=gc.dynamic_prompt_provider,
        iteration_metadata_provider=gc.iteration_metadata_provider,
        skills_catalog_prompt=gc.skills_catalog_prompt,
        protocols_prompt=gc.protocols_prompt,
        skill_dirs=gc.skill_dirs,
        default_skill_warn_ratio=gc.context_warn_ratio,
        default_skill_batch_nudge=gc.batch_init_nudge,
    )
+50 -117
View File
@@ -17,11 +17,11 @@ from pathlib import Path
from typing import Any
from framework.graph.checkpoint_config import CheckpointConfig
from framework.graph.context import GraphContext, build_node_context
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.goal import Goal
from framework.graph.conversation import LEGACY_RUN_ID, get_run_cursor
from framework.graph.node import (
NodeContext,
NodeProtocol,
NodeResult,
NodeSpec,
@@ -707,112 +707,6 @@ class GraphExecutor:
ToolRegistry.reset_execution_context(_ctx_token)
def _build_context(
self,
node_spec: NodeSpec,
buffer: DataBuffer,
goal: Goal,
input_data: dict[str, Any],
max_tokens: int = 4096,
continuous_mode: bool = False,
inherited_conversation: Any = None,
override_tools: list | None = None,
cumulative_output_keys: list[str] | None = None,
event_triggered: bool = False,
identity_prompt: str = "",
narrative: str = "",
node_registry: dict[str, NodeSpec] | None = None,
graph: "GraphSpec | None" = None,
) -> NodeContext:
"""Build execution context for a node."""
# Filter tools to those available to this node
if override_tools is not None:
# Continuous mode: use cumulative tool set
available_tools = list(override_tools)
else:
available_tools = []
if node_spec.tools:
available_tools = [t for t in self.tools if t.name in node_spec.tools]
# Create scoped buffer view.
# When permissions are restricted (non-empty key lists), auto-include
# _-prefixed keys used by default skill protocols so agents can read/write
# operational state (e.g. _working_notes, _batch_ledger) regardless of
# what the node declares. When key lists are empty (unrestricted), leave
# unchanged — empty means "allow all".
read_keys = list(node_spec.input_keys)
write_keys = list(node_spec.output_keys)
# Only extend lists that were already restricted (non-empty).
# Empty means "allow all" — adding keys would accidentally
# activate the permission check and block legitimate reads/writes.
if read_keys or write_keys:
from framework.skills.defaults import DATA_BUFFER_KEYS as _skill_keys
existing_underscore = [k for k in buffer._data if k.startswith("_")]
extra_keys = set(_skill_keys) | set(existing_underscore)
# Only inject into read_keys when it was already non-empty — an empty
# read_keys means "allow all reads" and injecting skill keys would
# inadvertently restrict reads to skill keys only.
for k in extra_keys:
if read_keys and k not in read_keys:
read_keys.append(k)
if write_keys and k not in write_keys:
write_keys.append(k)
scoped_buffer = buffer.with_permissions(
read_keys=read_keys,
write_keys=write_keys,
)
# Build per-node accounts prompt (filtered to this node's tools)
node_accounts_prompt = self.accounts_prompt
if self.accounts_data and self.tool_provider_map:
from framework.graph.prompt_composer import build_accounts_prompt
node_accounts_prompt = build_accounts_prompt(
self.accounts_data,
self.tool_provider_map,
node_tool_names=node_spec.tools,
)
goal_context = goal.to_prompt_context()
return NodeContext(
runtime=self.runtime,
node_id=node_spec.id,
node_spec=node_spec,
buffer=scoped_buffer,
input_data=input_data,
llm=self.llm,
available_tools=available_tools,
goal_context=goal_context,
goal=goal, # Pass Goal object for LLM-powered routers
max_tokens=max_tokens,
runtime_logger=self.runtime_logger,
pause_event=self._pause_requested, # Pass pause event for granular control
continuous_mode=continuous_mode,
inherited_conversation=inherited_conversation,
cumulative_output_keys=cumulative_output_keys or [],
event_triggered=event_triggered,
accounts_prompt=node_accounts_prompt,
identity_prompt=identity_prompt,
narrative=narrative,
execution_id=self._execution_id,
run_id=self._run_id,
stream_id=self._stream_id,
node_registry=node_registry or {},
all_tools=list(self.tools), # Full catalog for subagent tool resolution
shared_node_registry=self.node_registry, # For subagent escalation routing
dynamic_tools_provider=self.dynamic_tools_provider,
dynamic_prompt_provider=self.dynamic_prompt_provider,
iteration_metadata_provider=self.iteration_metadata_provider,
skills_catalog_prompt=self.skills_catalog_prompt,
protocols_prompt=self.protocols_prompt,
skill_dirs=self.skill_dirs,
default_skill_warn_ratio=self.context_warn_ratio,
default_skill_batch_nudge=self.batch_init_nudge,
)
VALID_NODE_TYPES = {
"event_loop",
"gcu",
@@ -1103,14 +997,36 @@ class GraphExecutor:
branch.retry_count = attempt
# Build context for this branch
ctx = self._build_context(
node_spec,
buffer,
goal,
mapped,
graph.max_tokens,
ctx = build_node_context(
runtime=self.runtime,
node_spec=node_spec,
buffer=buffer,
goal=goal,
llm=self.llm,
tools=self.tools,
max_tokens=graph.max_tokens,
input_data=mapped,
runtime_logger=self.runtime_logger,
pause_event=self._pause_requested,
accounts_prompt=self.accounts_prompt,
accounts_data=self.accounts_data,
tool_provider_map=self.tool_provider_map,
identity_prompt="",
narrative="",
execution_id=self._execution_id,
run_id=self._run_id,
stream_id=self._stream_id,
node_registry=node_registry,
graph=graph,
all_tools=self.tools,
shared_node_registry=self.node_registry,
dynamic_tools_provider=self.dynamic_tools_provider,
dynamic_prompt_provider=self.dynamic_prompt_provider,
iteration_metadata_provider=self.iteration_metadata_provider,
skills_catalog_prompt=self.skills_catalog_prompt,
protocols_prompt=self.protocols_prompt,
skill_dirs=self.skill_dirs,
default_skill_warn_ratio=self.context_warn_ratio,
default_skill_batch_nudge=self.batch_init_nudge,
)
node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
@@ -1353,7 +1269,6 @@ class GraphExecutor:
from framework.graph.worker_agent import (
Activation,
FanOutTag,
GraphContext,
WorkerAgent,
WorkerCompletion,
WorkerLifecycle,
@@ -1399,8 +1314,9 @@ class GraphExecutor:
for node_spec in graph.nodes:
workers[node_spec.id] = WorkerAgent(node_spec=node_spec, graph_context=gc)
# Identify entry workers (zero incoming edges) and terminal workers
entry_worker_ids = [wid for wid, w in workers.items() if w.is_entry]
# Identify entry workers (graph entry node, not based on edge count)
# A node can be the entry point AND have incoming feedback edges.
entry_worker_ids = [graph.entry_node]
terminal_worker_ids = set(graph.terminal_nodes or [])
self.logger.info(
@@ -1457,6 +1373,9 @@ class GraphExecutor:
def _check_graph_done() -> bool:
"""Check whether active graph work has reached a terminal state."""
# Step-limit guard (equivalent to old while-loop's max_steps)
if len(gc.path) >= graph.max_steps:
return True
if not terminal_worker_ids:
# No terminals: check if all workers are done
return all(
@@ -1774,6 +1693,15 @@ class GraphExecutor:
error = last_result.error
elif task_error is not None:
error = str(task_error)
# Route ON_FAILURE activations
outgoing_activations = worker._last_activations
if outgoing_activations:
for activation in outgoing_activations:
_route_activation(
activation, workers, pending_tasks,
has_event_subscription=False,
)
elif task_error is not None:
error = str(task_error)
else:
@@ -1800,6 +1728,11 @@ class GraphExecutor:
# Quality assessment
has_failures = bool(failed_workers) or execution_error is not None
# If all terminal workers completed successfully, intermediate failures
# (handled by ON_FAILURE edges) don't count against overall success.
if terminal_worker_ids and completed_terminals >= terminal_worker_ids:
terminal_failures = terminal_worker_ids & set(failed_workers.keys())
has_failures = bool(terminal_failures) or execution_error is not None
exec_quality = "failed" if has_failures else "clean"
saved_buffer = buffer.read_all()
+34 -147
View File
@@ -19,17 +19,15 @@ from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.goal import Goal
from framework.graph.context import GraphContext, build_node_context_from_graph_context
from framework.graph.edge import EdgeCondition, EdgeSpec
from framework.graph.node import (
DataBuffer,
NodeContext,
NodeProtocol,
NodeResult,
NodeSpec,
)
from framework.graph.validator import OutputValidator
from framework.runtime.core import Runtime
logger = logging.getLogger(__name__)
@@ -106,58 +104,6 @@ class RetryState:
is_event_loop: bool = False
@dataclass
class GraphContext:
"""Shared state for one graph execution run.
Consolidates the 20+ constructor params on ``GraphExecutor.__init__``
into a single object shared by reference across all workers.
"""
graph: GraphSpec
goal: Goal
buffer: DataBuffer
runtime: Runtime
llm: Any # LLMProvider
tools: list[Any] # list[Tool]
tool_executor: Any # Callable
event_bus: Any # GraphScopedEventBus
execution_id: str
stream_id: str
run_id: str
storage_path: Any # Path | None
runtime_logger: Any = None
node_registry: dict[str, NodeProtocol] = field(default_factory=dict)
node_spec_registry: dict[str, NodeSpec] = field(default_factory=dict)
# Parallel execution config
parallel_config: Any = None # ParallelExecutionConfig | None
# Continuous mode
is_continuous: bool = False
continuous_conversation: Any = None
cumulative_tools: list[Any] = field(default_factory=list)
cumulative_tool_names: set[str] = field(default_factory=set)
cumulative_output_keys: list[str] = field(default_factory=list)
# Accounts / skills / dynamic providers
accounts_prompt: str = ""
accounts_data: list[dict] | None = None
tool_provider_map: dict[str, str] | None = None
skills_catalog_prompt: str = ""
protocols_prompt: str = ""
skill_dirs: list[str] = field(default_factory=list)
context_warn_ratio: float | None = None
batch_init_nudge: str | None = None
dynamic_tools_provider: Any = None
dynamic_prompt_provider: Any = None
iteration_metadata_provider: Any = None
# Loop config for EventLoopNode creation
loop_config: dict[str, Any] = field(default_factory=dict)
# Thread-safe execution state
path: list[str] = field(default_factory=list)
node_visit_counts: dict[str, int] = field(default_factory=dict)
_path_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_visits_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
# ---------------------------------------------------------------------------
# WorkerAgent
# ---------------------------------------------------------------------------
@@ -379,9 +325,12 @@ class WorkerAgent:
)
await self._publish_completion(completion)
else:
# Evaluate outgoing edges even on failure (ON_FAILURE edges)
activations = await self._evaluate_outgoing_edges(result)
self.lifecycle = WorkerLifecycle.FAILED
self._last_result = result
self._last_activations = []
self._last_activations = activations
await self._publish_failure(result.error or "Unknown error")
except Exception as exc:
error = str(exc) or type(exc).__name__
@@ -396,9 +345,16 @@ class WorkerAgent:
) -> NodeResult:
"""Execute node with exponential backoff retry."""
gc = self._gc
max_retries = 0 if self.retry_state.is_event_loop else self.retry_state.max_retries
# Only skip retries for actual EventLoopNode instances (they handle
# retries internally). Custom NodeProtocol impls registered via
# register_node should be retried by the executor.
from framework.graph.event_loop_node import EventLoopNode as _ELN
if isinstance(node_impl, _ELN):
max_retries = 0
else:
max_retries = self.retry_state.max_retries
for attempt in range(max_retries + 1):
for attempt in range(max(1, max_retries)):
# Check pause
await self._run_gate.wait()
@@ -413,13 +369,13 @@ class WorkerAgent:
return result
# Failure
if attempt < max_retries:
if attempt + 1 < max(1, max_retries):
delay = 1.0 * (2**attempt)
logger.warning(
"Worker %s failed (attempt %d/%d), retrying in %.1fs: %s",
self.node_spec.id,
attempt + 1,
max_retries + 1,
max_retries,
delay,
result.error,
)
@@ -435,24 +391,33 @@ class WorkerAgent:
await asyncio.sleep(delay)
continue
else:
return result
return NodeResult(
success=False,
error=f"failed after {attempt + 1} attempts: {result.error}",
)
except Exception as exc:
if attempt < max_retries:
if attempt < max(1, max_retries) - 1:
delay = 1.0 * (2**attempt)
logger.warning(
"Worker %s raised %s (attempt %d/%d), retrying in %.1fs",
self.node_spec.id,
type(exc).__name__,
attempt + 1,
max_retries + 1,
max(1, max_retries),
delay,
)
await asyncio.sleep(delay)
continue
return NodeResult(success=False, error=str(exc))
return NodeResult(
success=False,
error=f"failed after {attempt + 1} attempts: {exc}",
)
return NodeResult(success=False, error="Max retries exceeded")
return NodeResult(
success=False,
error=f"failed after {max(1, max_retries)} attempts",
)
# ------------------------------------------------------------------
# Edge evaluation (source-side)
@@ -624,88 +589,10 @@ class WorkerAgent:
def _build_node_context(self) -> NodeContext:
"""Build NodeContext for this worker's execution."""
gc = self._gc
node_spec = self.node_spec
# Filter tools
if gc.is_continuous and gc.cumulative_tools:
available_tools = list(gc.cumulative_tools)
else:
available_tools = []
if node_spec.tools:
available_tools = [t for t in gc.tools if t.name in node_spec.tools]
# Scoped buffer
read_keys = list(node_spec.input_keys)
write_keys = list(node_spec.output_keys)
if read_keys or write_keys:
from framework.skills.defaults import DATA_BUFFER_KEYS as _skill_keys
existing_underscore = [k for k in gc.buffer._data if k.startswith("_")]
extra_keys = set(_skill_keys) | set(existing_underscore)
for k in extra_keys:
if read_keys and k not in read_keys:
read_keys.append(k)
if write_keys and k not in write_keys:
write_keys.append(k)
scoped_buffer = gc.buffer.with_permissions(read_keys=read_keys, write_keys=write_keys)
# Per-node accounts prompt
node_accounts_prompt = gc.accounts_prompt
if gc.accounts_data and gc.tool_provider_map:
from framework.graph.prompt_composer import build_accounts_prompt
node_accounts_prompt = build_accounts_prompt(
gc.accounts_data,
gc.tool_provider_map,
node_tool_names=node_spec.tools,
) or gc.accounts_prompt
# Input data from buffer
input_data: dict[str, Any] = {}
for key in node_spec.input_keys:
val = gc.buffer.read(key)
if val is not None:
input_data[key] = val
# Continuous mode: thread conversation
inherited_conversation = None
if gc.is_continuous and gc.continuous_conversation:
inherited_conversation = gc.continuous_conversation
return NodeContext(
runtime=gc.runtime,
node_id=node_spec.id,
node_spec=node_spec,
buffer=scoped_buffer,
input_data=input_data,
llm=gc.llm,
available_tools=available_tools,
goal_context=gc.goal.to_prompt_context(),
goal=gc.goal,
max_tokens=gc.graph.max_tokens,
runtime_logger=gc.runtime_logger,
return build_node_context_from_graph_context(
self._gc,
node_spec=self.node_spec,
pause_event=self._pause_requested,
continuous_mode=gc.is_continuous,
inherited_conversation=inherited_conversation,
cumulative_output_keys=list(gc.cumulative_output_keys) if gc.is_continuous else [],
accounts_prompt=node_accounts_prompt,
identity_prompt=getattr(gc.graph, "identity_prompt", "") or "",
execution_id=gc.execution_id,
run_id=gc.run_id,
stream_id=gc.stream_id,
node_registry=gc.node_spec_registry,
all_tools=list(gc.tools),
shared_node_registry=gc.node_registry,
dynamic_tools_provider=gc.dynamic_tools_provider,
dynamic_prompt_provider=gc.dynamic_prompt_provider,
iteration_metadata_provider=gc.iteration_metadata_provider,
skills_catalog_prompt=gc.skills_catalog_prompt,
protocols_prompt=gc.protocols_prompt,
skill_dirs=list(gc.skill_dirs),
default_skill_warn_ratio=gc.context_warn_ratio,
default_skill_batch_nudge=gc.batch_init_nudge,
)
# ------------------------------------------------------------------
+21 -7
View File
@@ -398,25 +398,39 @@ async def _smoke_test_provider_async(provider: dict, timeout_seconds: float = 25
"branch execution did not reach the expected terminal path: "
f"{result.path}"
)
if result.output.get("result") != "BRANCH_OK":
if not result.output.get("result"):
raise RuntimeError(
"branch execution completed but did not produce result='BRANCH_OK'"
"branch execution reached the expected terminal path but did not "
f"produce a non-empty result output: path={result.path} "
f"output={result.output}"
)
current_step = "plain completion"
current_timeout = timeout_seconds
worker_timeout = max(
timeout_seconds,
float(os.environ.get("DUMMY_AGENT_SMOKE_WORKER_TIMEOUT_SECS", "30")),
)
branch_timeout = max(
timeout_seconds,
float(os.environ.get("DUMMY_AGENT_SMOKE_BRANCH_TIMEOUT_SECS", "60")),
)
try:
await asyncio.wait_for(_run_plain_completion(), timeout=timeout_seconds)
await asyncio.wait_for(_run_plain_completion(), timeout=current_timeout)
current_step = "tool calling"
await asyncio.wait_for(_run_tool_completion(), timeout=timeout_seconds)
current_timeout = timeout_seconds
await asyncio.wait_for(_run_tool_completion(), timeout=current_timeout)
current_step = "single-node worker execution"
await asyncio.wait_for(_run_worker_execution(), timeout=timeout_seconds)
current_timeout = worker_timeout
await asyncio.wait_for(_run_worker_execution(), timeout=current_timeout)
current_step = "branch worker execution"
await asyncio.wait_for(_run_branch_execution(), timeout=timeout_seconds)
current_timeout = branch_timeout
await asyncio.wait_for(_run_branch_execution(), timeout=current_timeout)
except TimeoutError as exc:
raise RuntimeError(
f"provider smoke test timed out during {current_step} "
f"after {timeout_seconds:.0f}s"
f"after {current_timeout:.0f}s"
) from exc
+190
View File
@@ -0,0 +1,190 @@
"""Unit tests for shared graph context helpers."""
from __future__ import annotations
from framework.graph.context import (
GraphContext,
build_node_accounts_prompt,
build_node_context,
build_node_context_from_graph_context,
build_scoped_buffer,
)
from framework.graph.edge import GraphSpec
from framework.graph.goal import Goal
from framework.graph.node import DataBuffer, NodeSpec
from framework.llm.provider import Tool
from framework.skills.defaults import DATA_BUFFER_KEYS
class DummyRuntime:
    # Minimal stand-in for framework.runtime.core.Runtime; the context
    # helpers under test only pass it through, so a bare attribute suffices.
    execution_id = ""
def _make_tool(name: str) -> Tool:
    """Build a minimal Tool stub with an empty parameter schema."""
    empty_schema = {"type": "object", "properties": {}}
    return Tool(name=name, description=f"Tool {name}", parameters=empty_schema)
def _make_goal() -> Goal:
    """Return a minimal Goal fixture with fixed id/name/description."""
    return Goal(id="goal-1", name="Goal", description="Test goal")
def _make_graph(node_spec: NodeSpec) -> GraphSpec:
    """Wrap a single node in a minimal one-node GraphSpec fixture."""
    only_node = node_spec.id
    return GraphSpec(
        id="graph-1",
        goal_id="goal-1",
        nodes=[node_spec],
        edges=[],
        entry_node=only_node,
        terminal_nodes=[only_node],
        max_tokens=2048,
    )
def test_build_scoped_buffer_includes_skill_and_existing_internal_keys():
    """Restricted nodes still get skill keys and pre-existing `_` keys."""
    buffer = DataBuffer()
    buffer.write("task", "draft")
    buffer.write("_worker_state", "active")
    spec = NodeSpec(
        id="writer",
        name="Writer",
        description="Writes output",
        node_type="event_loop",
        input_keys=["task"],
        output_keys=["result"],
    )
    scoped = build_scoped_buffer(buffer, spec)
    # Declared keys survive the permission shaping.
    assert "task" in scoped._allowed_read
    assert "result" in scoped._allowed_write
    # Framework-managed keys are injected into both restricted lists.
    injected = set(DATA_BUFFER_KEYS) | {"_worker_state"}
    for key in injected:
        assert key in scoped._allowed_read
        assert key in scoped._allowed_write
def test_build_scoped_buffer_keeps_empty_permissions_unrestricted():
    """Empty key lists mean "allow all" and must not be augmented."""
    buffer = DataBuffer()
    buffer.write("task", "draft")
    buffer.write("_worker_state", "active")
    spec = NodeSpec(
        id="reader",
        name="Reader",
        description="Reads everything",
        node_type="event_loop",
        input_keys=[],
        output_keys=[],
    )
    scoped = build_scoped_buffer(buffer, spec)
    # Permissions stay empty (unrestricted) ...
    assert scoped._allowed_read == set()
    assert scoped._allowed_write == set()
    # ... so every buffered value, including internal keys, remains readable.
    snapshot = scoped.read_all()
    assert snapshot["task"] == "draft"
    assert snapshot["_worker_state"] == "active"
def test_accounts_prompt_falls_back_when_filtered_prompt_is_empty():
    """Node tools that match no account provider fall back to the default."""
    accounts = [{"provider": "google", "alias": "personal", "identity": {}}]
    provider_map = {"gmail_list_messages": "google"}
    resolved = build_node_accounts_prompt(
        accounts_prompt="DEFAULT_ACCOUNTS",
        accounts_data=accounts,
        tool_provider_map=provider_map,
        node_tool_names=["slack_post_message"],
        fallback_to_default=True,
    )
    assert resolved == "DEFAULT_ACCOUNTS"
def test_build_node_context_from_graph_context_preserves_continuous_state():
    """Continuous-mode defaults (cumulative tools, conversation, output keys)
    flow from GraphContext into the built NodeContext."""
    node_spec = NodeSpec(
        id="writer",
        name="Writer",
        description="Writes output",
        node_type="event_loop",
        input_keys=["task"],
        output_keys=["draft"],
        tools=["save_data"],
    )
    buffer = DataBuffer()
    buffer.write("task", "write the draft")
    # Sentinel conversation object: checked by identity below.
    conversation = object()
    save_data = _make_tool("save_data")
    fallback_tool = _make_tool("web_search")
    graph = _make_graph(node_spec)
    graph_context = GraphContext(
        graph=graph,
        goal=_make_goal(),
        buffer=buffer,
        runtime=DummyRuntime(),
        llm=None,
        tools=[save_data, fallback_tool],
        tool_executor=None,
        event_bus=None,
        execution_id="exec-1",
        stream_id="stream-1",
        run_id="run-1",
        storage_path=None,
        is_continuous=True,
        continuous_conversation=conversation,
        cumulative_tools=[fallback_tool],
        cumulative_output_keys=["outline", "draft"],
        accounts_prompt="ACCOUNTS",
        skills_catalog_prompt="SKILLS",
        protocols_prompt="PROTOCOLS",
    )
    ctx = build_node_context_from_graph_context(
        graph_context,
        node_spec=node_spec,
        pause_event="pause-signal",
    )
    # Input derived from the buffer using the node's declared input_keys.
    assert ctx.input_data == {"task": "write the draft"}
    # Continuous-mode state threaded through unchanged.
    assert ctx.inherited_conversation is conversation
    assert ctx.cumulative_output_keys == ["outline", "draft"]
    # Cumulative tools override the node's declared tool list.
    assert [tool.name for tool in ctx.available_tools] == ["web_search"]
    assert ctx.pause_event == "pause-signal"
    # Run-wide prompts forwarded verbatim.
    assert ctx.accounts_prompt == "ACCOUNTS"
    assert ctx.skills_catalog_prompt == "SKILLS"
    assert ctx.protocols_prompt == "PROTOCOLS"
def test_build_node_context_uses_override_tools_for_legacy_executor_path():
    """Explicit override_tools and explicit input_data win over the node's
    declared tools and buffer-derived inputs (legacy executor behavior)."""
    node_spec = NodeSpec(
        id="branch",
        name="Branch",
        description="Legacy branch execution",
        node_type="event_loop",
        input_keys=["task"],
        output_keys=["result"],
        tools=["save_data"],
    )
    buffer = DataBuffer()
    save_data = _make_tool("save_data")
    web_search = _make_tool("web_search")
    ctx = build_node_context(
        runtime=DummyRuntime(),
        node_spec=node_spec,
        buffer=buffer,
        goal=_make_goal(),
        llm=None,
        tools=[save_data],
        max_tokens=1024,
        input_data={"task": "run branch"},
        override_tools=[save_data, web_search],
        node_registry={"branch": node_spec},
    )
    # Explicit input_data is used as-is (buffer is empty here).
    assert ctx.input_data == {"task": "run branch"}
    # override_tools bypasses filtering against the run-wide catalog.
    assert [tool.name for tool in ctx.available_tools] == ["save_data", "web_search"]
    assert ctx.node_registry == {"branch": node_spec}