Compare commits

...

1 Commits

Author SHA1 Message Date
bryan 36eaf56786 WP-11 2026-01-30 20:42:46 -08:00
5 changed files with 567 additions and 34 deletions
+4
View File
@@ -1,6 +1,7 @@
"""Graph structures: Goals, Nodes, Edges, and Flexible Execution."""
from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec
from framework.graph.context_handoff import ContextHandoff, HandoffContext
from framework.graph.conversation import ConversationStore, Message, NodeConversation
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.executor import GraphExecutor
@@ -77,4 +78,7 @@ __all__ = [
"NodeConversation",
"ConversationStore",
"Message",
# Context Handoff
"ContextHandoff",
"HandoffContext",
]
+191
View File
@@ -0,0 +1,191 @@
"""Context handoff: summarize a completed NodeConversation for the next graph node."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from framework.graph.conversation import _try_extract_key
if TYPE_CHECKING:
from framework.graph.conversation import NodeConversation
from framework.llm.provider import LLMProvider
logger = logging.getLogger(__name__)
_TRUNCATE_CHARS = 500
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
@dataclass
class HandoffContext:
    """Compact, structured record handed from a finished node to its successor."""

    # ID of the node whose conversation was summarized.
    source_node_id: str
    # Abstractive (LLM) or extractive summary of the conversation.
    summary: str
    # Values pulled out of assistant messages, keyed by requested output key.
    key_outputs: dict[str, Any]
    # Number of turns in the source conversation.
    turn_count: int
    # Estimated token usage of the source conversation.
    total_tokens_used: int
# ---------------------------------------------------------------------------
# ContextHandoff
# ---------------------------------------------------------------------------
class ContextHandoff:
    """Summarize a completed NodeConversation into a HandoffContext.

    Parameters
    ----------
    llm : LLMProvider | None
        Optional LLM provider for abstractive summarization.
        When *None*, all summarization uses the extractive fallback.
    """

    def __init__(self, llm: LLMProvider | None = None) -> None:
        self.llm = llm

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def summarize_conversation(
        self,
        conversation: NodeConversation,
        node_id: str,
        output_keys: list[str] | None = None,
    ) -> HandoffContext:
        """Produce a HandoffContext from *conversation*.

        1. Extracts turn_count & total_tokens_used (sync properties).
        2. Extracts key_outputs by scanning assistant messages
           most-recent-first, stopping as soon as every key is found.
        3. Builds a summary via the LLM (if available) or extractive fallback.
        """
        turn_count = conversation.turn_count
        total_tokens_used = conversation.estimate_tokens()
        messages = conversation.messages  # defensive copy

        # --- key outputs ---------------------------------------------------
        key_outputs: dict[str, Any] = {}
        if output_keys:
            remaining = set(output_keys)
            for msg in reversed(messages):
                if not remaining:
                    # Every requested key extracted; no need to scan older turns.
                    break
                if msg.role != "assistant":
                    continue
                for key in list(remaining):
                    value = _try_extract_key(msg.content, key)
                    if value is not None:
                        # Most-recent-first scan: first hit wins for each key.
                        key_outputs[key] = value
                        remaining.discard(key)

        # --- summary -------------------------------------------------------
        if self.llm is not None:
            try:
                summary = self._llm_summary(messages, output_keys or [])
            except Exception:
                # Best-effort: a failing LLM must never break the handoff.
                logger.warning(
                    "LLM summarization failed; falling back to extractive.",
                    exc_info=True,
                )
                summary = self._extractive_summary(messages)
        else:
            summary = self._extractive_summary(messages)

        return HandoffContext(
            source_node_id=node_id,
            summary=summary,
            key_outputs=key_outputs,
            turn_count=turn_count,
            total_tokens_used=total_tokens_used,
        )

    @staticmethod
    def format_as_input(handoff: HandoffContext) -> str:
        """Render *handoff* as structured plain text for the next node's input."""
        header = (
            f"--- CONTEXT FROM: {handoff.source_node_id} "
            f"({handoff.turn_count} turns, ~{handoff.total_tokens_used} tokens) ---"
        )
        sections: list[str] = [header, ""]
        if handoff.key_outputs:
            sections.append("KEY OUTPUTS:")
            for k, v in handoff.key_outputs.items():
                sections.append(f"- {k}: {v}")
            sections.append("")
        summary_text = handoff.summary or "No summary available."
        sections.append("SUMMARY:")
        sections.append(summary_text)
        sections.append("")
        sections.append("--- END CONTEXT ---")
        return "\n".join(sections)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extractive_summary(messages: list) -> str:
        """Build a summary from key assistant messages without an LLM.

        Strategy:
        - Include the first assistant message (initial assessment).
        - Include the last assistant message (final conclusion).
        - Truncate each to ~500 chars.
        """
        if not messages:
            return "Empty conversation."
        assistant_msgs = [m for m in messages if m.role == "assistant"]
        if not assistant_msgs:
            return "No assistant responses."
        parts: list[str] = [assistant_msgs[0].content[:_TRUNCATE_CHARS]]
        if len(assistant_msgs) > 1:
            parts.append(assistant_msgs[-1].content[:_TRUNCATE_CHARS])
        return "\n\n".join(parts)

    def _llm_summary(self, messages: list, output_keys: list[str]) -> str:
        """Produce a summary by calling the LLM provider.

        Raises
        ------
        ValueError
            If called while no LLM provider is configured.
        """
        if self.llm is None:
            raise ValueError("_llm_summary called without an LLM provider")
        conversation_text = "\n".join(f"[{m.role}]: {m.content}" for m in messages)
        key_hint = ""
        if output_keys:
            # Steer the summarizer toward the keys the next node will consume.
            key_hint = (
                "\nThe following output keys are especially important: "
                + ", ".join(output_keys)
                + ".\n"
            )
        system_prompt = (
            "You are a concise summarizer. Given the conversation below, "
            "produce a brief summary (at most ~500 tokens) that captures the "
            "key decisions, findings, and outcomes. Focus on what was concluded "
            "rather than the back-and-forth process." + key_hint
        )
        response = self.llm.complete(
            messages=[{"role": "user", "content": conversation_text}],
            system=system_prompt,
            max_tokens=500,
        )
        return response.content.strip()
+45 -33
View File
@@ -108,6 +108,50 @@ class ConversationStore(Protocol):
# ---------------------------------------------------------------------------
def _try_extract_key(content: str, key: str) -> str | None:
    """Extract *key*'s value from message *content*, JSON first, then plain text.

    Strategies (in order):
    1. Whole message is JSON ``json.loads``, check for key.
    2. Embedded JSON via ``find_json_object`` helper.
    3. Colon format: ``key: value``.
    4. Equals format: ``key = value``.

    Returns the value as a string, or ``None`` when no strategy matches.
    """
    from framework.graph.node import find_json_object

    def _lookup_in_json(text: str) -> str | None:
        # Shared dict-lookup used by strategies 1 and 2; non-string values
        # are re-serialized so the caller always receives a string.
        try:
            parsed = json.loads(text)
            if isinstance(parsed, dict) and key in parsed:
                val = parsed[key]
                return val if isinstance(val, str) else json.dumps(val)
        except (json.JSONDecodeError, TypeError):
            pass
        return None

    # 1. Whole message is JSON.
    found = _lookup_in_json(content)
    if found is not None:
        return found

    # 2. A JSON object embedded somewhere in the message.
    embedded = find_json_object(content)
    if embedded:
        found = _lookup_in_json(embedded)
        if found is not None:
            return found

    # 3 & 4. Plain-text forms, colon tried before equals (original order).
    for separator in (":", "="):
        match = re.search(rf"\b{re.escape(key)}\s*{re.escape(separator)}\s*(.+)", content)
        if match:
            return match.group(1).strip()
    return None
class NodeConversation:
"""Message history for a graph node with optional write-through persistence.
@@ -244,39 +288,7 @@ class NodeConversation:
# NOTE(review): diff-view fragment — the inline strategy code below is the
# removed pre-image; the method's new body is the final one-line delegation
# to the module-level `_try_extract_key`. The inline copy is dead text in
# this rendering and should not be read as live logic.
def _try_extract_key(self, content: str, key: str) -> str | None:
    """Try 4 strategies to extract a key's value from message content."""
    from framework.graph.node import find_json_object
    # 1. Whole message is JSON
    try:
        parsed = json.loads(content)
        if isinstance(parsed, dict) and key in parsed:
            val = parsed[key]
            return json.dumps(val) if not isinstance(val, str) else val
    except (json.JSONDecodeError, TypeError):
        pass
    # 2. Embedded JSON via find_json_object
    json_str = find_json_object(content)
    if json_str:
        try:
            parsed = json.loads(json_str)
            if isinstance(parsed, dict) and key in parsed:
                val = parsed[key]
                return json.dumps(val) if not isinstance(val, str) else val
        except (json.JSONDecodeError, TypeError):
            pass
    # 3. Colon format: key: value
    match = re.search(rf"\b{re.escape(key)}\s*:\s*(.+)", content)
    if match:
        return match.group(1).strip()
    # 4. Equals format: key = value
    match = re.search(rf"\b{re.escape(key)}\s*=\s*(.+)", content)
    if match:
        return match.group(1).strip()
    return None
    return _try_extract_key(content, key)
# --- Lifecycle ---------------------------------------------------------
+326
View File
@@ -0,0 +1,326 @@
"""Tests for ContextHandoff and HandoffContext."""
from __future__ import annotations
from typing import Any
import pytest
from framework.graph.context_handoff import ContextHandoff, HandoffContext
from framework.graph.conversation import NodeConversation
from framework.llm.mock import MockLLMProvider
from framework.llm.provider import LLMProvider, LLMResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class SpyLLMProvider(MockLLMProvider):
    """MockLLMProvider that records whether complete() was called."""

    def __init__(self) -> None:
        super().__init__()
        # Spy state: a flag plus the arguments of the most recent call.
        self.complete_called = False
        self.complete_call_args: dict[str, Any] | None = None

    def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
        """Record the call, then defer to the mock implementation."""
        self.complete_called = True
        recorded: dict[str, Any] = {"messages": messages}
        recorded.update(kwargs)
        self.complete_call_args = recorded
        return super().complete(messages, **kwargs)
class FailingLLMProvider(LLMProvider):
    """LLM provider whose every entry point raises, for exercising fallbacks."""

    def _fail(self) -> LLMResponse:
        # Single failure point shared by both API methods.
        raise RuntimeError("LLM unavailable")

    def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
        return self._fail()

    def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list,
        tool_executor: Any,
        max_iterations: int = 10,
    ) -> LLMResponse:
        return self._fail()
async def _build_conversation(*pairs: tuple[str, str]) -> NodeConversation:
    """Create a NodeConversation populated with (user, assistant) message pairs."""
    conversation = NodeConversation()
    for user_text, assistant_text in pairs:
        await conversation.add_user_message(user_text)
        await conversation.add_assistant_message(assistant_text)
    return conversation
# ---------------------------------------------------------------------------
# TestHandoffContext
# ---------------------------------------------------------------------------
class TestHandoffContext:
    """Construction and field access of the HandoffContext dataclass."""

    def test_instantiation(self) -> None:
        hc = HandoffContext(
            source_node_id="node_A",
            summary="Summary text",
            key_outputs={"result": "42"},
            turn_count=3,
            total_tokens_used=1200,
        )
        # Every field round-trips unchanged.
        assert (hc.source_node_id, hc.summary) == ("node_A", "Summary text")
        assert hc.key_outputs == {"result": "42"}
        assert (hc.turn_count, hc.total_tokens_used) == (3, 1200)

    def test_field_access(self) -> None:
        hc = HandoffContext(
            source_node_id="n1",
            summary="s",
            key_outputs={},
            turn_count=0,
            total_tokens_used=0,
        )
        assert hc.key_outputs == {}
# ---------------------------------------------------------------------------
# TestExtractiveSummary
# ---------------------------------------------------------------------------
class TestExtractiveSummary:
    """summarize_conversation() without an LLM: the extractive fallback path."""

    @pytest.mark.asyncio
    async def test_extractive_summary_includes_first_last(self) -> None:
        # First and last assistant messages survive; the middle one is dropped.
        conv = await _build_conversation(
            ("hello", "First response here."),
            ("continue", "Middle response."),
            ("finish", "Final conclusion."),
        )
        ch = ContextHandoff()
        hc = ch.summarize_conversation(conv, node_id="test_node")
        assert "First response here." in hc.summary
        assert "Final conclusion." in hc.summary

    @pytest.mark.asyncio
    async def test_extractive_summary_metadata(self) -> None:
        # Node id, turn count, and token estimate propagate into the handoff.
        conv = await _build_conversation(
            ("hi", "hello"),
            ("bye", "goodbye"),
        )
        ch = ContextHandoff()
        hc = ch.summarize_conversation(conv, node_id="node_42")
        assert hc.source_node_id == "node_42"
        assert hc.turn_count == 2
        assert hc.total_tokens_used > 0

    @pytest.mark.asyncio
    async def test_extractive_with_output_keys_colon(self) -> None:
        # Plain-text "key: value" extraction.
        conv = await _build_conversation(
            ("what is the answer?", "answer: 42"),
        )
        ch = ContextHandoff()
        hc = ch.summarize_conversation(conv, node_id="n", output_keys=["answer"])
        assert hc.key_outputs["answer"] == "42"

    @pytest.mark.asyncio
    async def test_extractive_with_output_keys_equals(self) -> None:
        # Plain-text "key = value" extraction.
        conv = await _build_conversation(
            ("compute", "result = success"),
        )
        ch = ContextHandoff()
        hc = ch.summarize_conversation(conv, node_id="n", output_keys=["result"])
        assert hc.key_outputs["result"] == "success"

    @pytest.mark.asyncio
    async def test_extractive_json_output_keys(self) -> None:
        # Non-string JSON values come back JSON-serialized ("95"), strings as-is.
        conv = await _build_conversation(
            ("give me json", '{"score": 95, "grade": "A"}'),
        )
        ch = ContextHandoff()
        hc = ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"])
        assert hc.key_outputs["score"] == "95"
        assert hc.key_outputs["grade"] == "A"

    @pytest.mark.asyncio
    async def test_extractive_empty_conversation(self) -> None:
        # Degenerate input: no messages at all.
        conv = NodeConversation()
        ch = ContextHandoff()
        hc = ch.summarize_conversation(conv, node_id="empty")
        assert hc.summary == "Empty conversation."
        assert hc.turn_count == 0
        assert hc.key_outputs == {}

    @pytest.mark.asyncio
    async def test_extractive_no_assistant_messages(self) -> None:
        # User-only history yields the dedicated fallback sentence.
        conv = NodeConversation()
        await conv.add_user_message("hello?")
        await conv.add_user_message("anyone there?")
        ch = ContextHandoff()
        hc = ch.summarize_conversation(conv, node_id="silent")
        assert hc.summary == "No assistant responses."

    @pytest.mark.asyncio
    async def test_extractive_most_recent_wins(self) -> None:
        # Key scan is most-recent-first, so the newer value is kept.
        conv = await _build_conversation(
            ("first", "status: old_value"),
            ("second", "status: new_value"),
        )
        ch = ContextHandoff()
        hc = ch.summarize_conversation(conv, node_id="n", output_keys=["status"])
        assert hc.key_outputs["status"] == "new_value"

    @pytest.mark.asyncio
    async def test_extractive_truncation(self) -> None:
        long_text = "x" * 1000
        conv = await _build_conversation(
            ("go", long_text),
        )
        ch = ContextHandoff()
        hc = ch.summarize_conversation(conv, node_id="n")
        # Summary should be truncated to ~500 chars
        assert len(hc.summary) <= 500
# ---------------------------------------------------------------------------
# TestLLMSummary
# ---------------------------------------------------------------------------
class TestLLMSummary:
    """summarize_conversation() with an LLM provider configured."""

    @pytest.mark.asyncio
    async def test_llm_summary_calls_provider(self) -> None:
        # With a provider present, the summary comes from complete(), not the
        # extractive fallback.
        llm = SpyLLMProvider()
        conv = await _build_conversation(
            ("hi", "hello back"),
            ("what now?", "we are done"),
        )
        ch = ContextHandoff(llm=llm)
        hc = ch.summarize_conversation(conv, node_id="llm_node")
        assert llm.complete_called, "LLM complete() was never invoked"
        assert hc.summary == "This is a mock response for testing purposes."

    @pytest.mark.asyncio
    async def test_llm_summary_includes_output_key_hint(self) -> None:
        # Requested output keys are surfaced to the LLM via the system prompt.
        llm = SpyLLMProvider()
        conv = await _build_conversation(
            ("compute", '{"score": 95}'),
        )
        ch = ContextHandoff(llm=llm)
        ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"])
        assert llm.complete_call_args is not None
        system = llm.complete_call_args.get("system", "")
        assert "score" in system
        assert "grade" in system

    @pytest.mark.asyncio
    async def test_llm_fallback_on_error(self) -> None:
        # A raising provider must not propagate; extractive fallback is used.
        llm = FailingLLMProvider()
        conv = await _build_conversation(
            ("start", "First assistant message."),
            ("end", "Last assistant message."),
        )
        ch = ContextHandoff(llm=llm)
        hc = ch.summarize_conversation(conv, node_id="fallback_node")
        # Should fall back to extractive (first + last assistant messages)
        assert "First assistant message." in hc.summary
        assert "Last assistant message." in hc.summary
# ---------------------------------------------------------------------------
# TestFormatAsInput
# ---------------------------------------------------------------------------
class TestFormatAsInput:
    """ContextHandoff.format_as_input(): rendering a handoff as plain text."""

    def test_format_structure(self) -> None:
        # All four structural markers appear in the rendered text.
        hc = HandoffContext(
            source_node_id="analyzer",
            summary="Analysis complete.",
            key_outputs={"score": "95"},
            turn_count=5,
            total_tokens_used=2000,
        )
        output = ContextHandoff.format_as_input(hc)
        assert "--- CONTEXT FROM: analyzer" in output
        assert "KEY OUTPUTS:" in output
        assert "SUMMARY:" in output
        assert "--- END CONTEXT ---" in output

    def test_format_no_key_outputs(self) -> None:
        # The KEY OUTPUTS section is omitted entirely when there are no keys.
        hc = HandoffContext(
            source_node_id="simple",
            summary="Done.",
            key_outputs={},
            turn_count=1,
            total_tokens_used=100,
        )
        output = ContextHandoff.format_as_input(hc)
        assert "KEY OUTPUTS:" not in output
        assert "SUMMARY:" in output

    def test_format_content_values(self) -> None:
        # Header metadata and per-key bullet lines are rendered verbatim.
        hc = HandoffContext(
            source_node_id="node_X",
            summary="Found 3 bugs.",
            key_outputs={"bugs": "3", "severity": "high"},
            turn_count=7,
            total_tokens_used=5000,
        )
        output = ContextHandoff.format_as_input(hc)
        assert "node_X" in output
        assert "7 turns" in output
        assert "~5000 tokens" in output
        assert "- bugs: 3" in output
        assert "- severity: high" in output
        assert "Found 3 bugs." in output

    def test_format_empty_summary(self) -> None:
        # An empty summary string is replaced by a placeholder sentence.
        hc = HandoffContext(
            source_node_id="n",
            summary="",
            key_outputs={},
            turn_count=0,
            total_tokens_used=0,
        )
        output = ContextHandoff.format_as_input(hc)
        assert "No summary available." in output

    @pytest.mark.asyncio
    async def test_format_as_input_usable_as_message(self) -> None:
        """Formatted output can be fed into a NodeConversation as a user message."""
        hc = HandoffContext(
            source_node_id="prev_node",
            summary="Completed analysis.",
            key_outputs={"result": "42"},
            turn_count=3,
            total_tokens_used=900,
        )
        text = ContextHandoff.format_as_input(hc)
        conv = NodeConversation()
        msg = await conv.add_user_message(text)
        assert msg.role == "user"
        assert "CONTEXT FROM: prev_node" in msg.content
        assert conv.turn_count == 1
@@ -34,7 +34,7 @@ def register_tools(mcp: FastMCP) -> None:
Returns:
dict with success status, data, and metadata
"""
if (offset < 0 or (limit is not None and limit < 0)):
if offset < 0 or (limit is not None and limit < 0):
return {"error": "offset and limit must be non-negative"}
try:
secure_path = get_secure_path(path, workspace_id, agent_id, session_id)