From 36eaf5678662d3d1b9e9613758599c7bbf7df05b Mon Sep 17 00:00:00 2001 From: bryan Date: Fri, 30 Jan 2026 20:42:46 -0800 Subject: [PATCH] WP-11 --- core/framework/graph/__init__.py | 4 + core/framework/graph/context_handoff.py | 191 ++++++++++ core/framework/graph/conversation.py | 78 +++-- core/tests/test_context_handoff.py | 326 ++++++++++++++++++ .../src/aden_tools/tools/csv_tool/csv_tool.py | 2 +- 5 files changed, 567 insertions(+), 34 deletions(-) create mode 100644 core/framework/graph/context_handoff.py create mode 100644 core/tests/test_context_handoff.py diff --git a/core/framework/graph/__init__.py b/core/framework/graph/__init__.py index fcad344c..d404c41f 100644 --- a/core/framework/graph/__init__.py +++ b/core/framework/graph/__init__.py @@ -1,6 +1,7 @@ """Graph structures: Goals, Nodes, Edges, and Flexible Execution.""" from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec +from framework.graph.context_handoff import ContextHandoff, HandoffContext from framework.graph.conversation import ConversationStore, Message, NodeConversation from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec from framework.graph.executor import GraphExecutor @@ -77,4 +78,7 @@ __all__ = [ "NodeConversation", "ConversationStore", "Message", + # Context Handoff + "ContextHandoff", + "HandoffContext", ] diff --git a/core/framework/graph/context_handoff.py b/core/framework/graph/context_handoff.py new file mode 100644 index 00000000..69831506 --- /dev/null +++ b/core/framework/graph/context_handoff.py @@ -0,0 +1,191 @@ +"""Context handoff: summarize a completed NodeConversation for the next graph node.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from framework.graph.conversation import _try_extract_key + +if TYPE_CHECKING: + from framework.graph.conversation import NodeConversation + from framework.llm.provider import LLMProvider + +logger = logging.getLogger(__name__) + +_TRUNCATE_CHARS = 500 + + +# --------------------------------------------------------------------------- +# Data +# --------------------------------------------------------------------------- + + +@dataclass +class HandoffContext: + """Structured summary of a completed node conversation.""" + + source_node_id: str + summary: str + key_outputs: dict[str, Any] + turn_count: int + total_tokens_used: int + + +# --------------------------------------------------------------------------- +# ContextHandoff +# --------------------------------------------------------------------------- + + +class ContextHandoff: + """Summarize a completed NodeConversation into a HandoffContext. + + Parameters + ---------- + llm : LLMProvider | None + Optional LLM provider for abstractive summarization. + When *None*, all summarization uses the extractive fallback. + """ + + def __init__(self, llm: LLMProvider | None = None) -> None: + self.llm = llm + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def summarize_conversation( + self, + conversation: NodeConversation, + node_id: str, + output_keys: list[str] | None = None, + ) -> HandoffContext: + """Produce a HandoffContext from *conversation*. + + 1. Extracts turn_count & total_tokens_used (sync properties). + 2. Extracts key_outputs by scanning assistant messages most-recent-first. + 3. Builds a summary via the LLM (if available) or extractive fallback. + """ + turn_count = conversation.turn_count + total_tokens_used = conversation.estimate_tokens() + messages = conversation.messages # defensive copy + + # --- key outputs --------------------------------------------------- + key_outputs: dict[str, Any] = {} + if output_keys: + remaining = set(output_keys) + for msg in reversed(messages): + if msg.role != "assistant" or not remaining: + continue + for key in list(remaining): + value = _try_extract_key(msg.content, key) + if value is not None: + key_outputs[key] = value + remaining.discard(key) + + # --- summary ------------------------------------------------------- + if self.llm is not None: + try: + summary = self._llm_summary(messages, output_keys or []) + except Exception: + logger.warning( + "LLM summarization failed; falling back to extractive.", + exc_info=True, + ) + summary = self._extractive_summary(messages) + else: + summary = self._extractive_summary(messages) + + return HandoffContext( + source_node_id=node_id, + summary=summary, + key_outputs=key_outputs, + turn_count=turn_count, + total_tokens_used=total_tokens_used, + ) + + @staticmethod + def format_as_input(handoff: HandoffContext) -> str: + """Render *handoff* as structured plain text for the next node's input.""" + header = ( + f"--- CONTEXT FROM: {handoff.source_node_id} " + f"({handoff.turn_count} turns, ~{handoff.total_tokens_used} tokens) ---" + ) + + sections: list[str] = [header, ""] + + if handoff.key_outputs: + sections.append("KEY OUTPUTS:") + for k, v in handoff.key_outputs.items(): + sections.append(f"- {k}: {v}") + sections.append("") + + summary_text = handoff.summary or "No summary available." + sections.append("SUMMARY:") + sections.append(summary_text) + sections.append("") + sections.append("--- END CONTEXT ---") + + return "\n".join(sections) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _extractive_summary(messages: list) -> str: + """Build a summary from key assistant messages without an LLM. + + Strategy: + - Include the first assistant message (initial assessment). + - Include the last assistant message (final conclusion). + - Truncate each to ~500 chars. + """ + if not messages: + return "Empty conversation." + + assistant_msgs = [m for m in messages if m.role == "assistant"] + if not assistant_msgs: + return "No assistant responses." + + parts: list[str] = [] + + first = assistant_msgs[0].content + parts.append(first[:_TRUNCATE_CHARS]) + + if len(assistant_msgs) > 1: + last = assistant_msgs[-1].content + parts.append(last[:_TRUNCATE_CHARS]) + + return "\n\n".join(parts) + + def _llm_summary(self, messages: list, output_keys: list[str]) -> str: + """Produce a summary by calling the LLM provider.""" + if self.llm is None: + raise ValueError("_llm_summary called without an LLM provider") + + conversation_text = "\n".join(f"[{m.role}]: {m.content}" for m in messages) + + key_hint = "" + if output_keys: + key_hint = ( + "\nThe following output keys are especially important: " + + ", ".join(output_keys) + + ".\n" + ) + + system_prompt = ( + "You are a concise summarizer. Given the conversation below, " + "produce a brief summary (at most ~500 tokens) that captures the " + "key decisions, findings, and outcomes. Focus on what was concluded " + "rather than the back-and-forth process." + key_hint + ) + + response = self.llm.complete( + messages=[{"role": "user", "content": conversation_text}], + system=system_prompt, + max_tokens=500, + ) + + return response.content.strip() diff --git a/core/framework/graph/conversation.py b/core/framework/graph/conversation.py index 9239da2b..590203a3 100644 --- a/core/framework/graph/conversation.py +++ b/core/framework/graph/conversation.py @@ -108,6 +108,50 @@ class ConversationStore(Protocol): # --------------------------------------------------------------------------- +def _try_extract_key(content: str, key: str) -> str | None: + """Try 4 strategies to extract a *key*'s value from message content. + + Strategies (in order): + 1. Whole message is JSON — ``json.loads``, check for key. + 2. Embedded JSON via ``find_json_object`` helper. + 3. Colon format: ``key: value``. + 4. Equals format: ``key = value``. + """ + from framework.graph.node import find_json_object + + # 1. Whole message is JSON + try: + parsed = json.loads(content) + if isinstance(parsed, dict) and key in parsed: + val = parsed[key] + return json.dumps(val) if not isinstance(val, str) else val + except (json.JSONDecodeError, TypeError): + pass + + # 2. Embedded JSON via find_json_object + json_str = find_json_object(content) + if json_str: + try: + parsed = json.loads(json_str) + if isinstance(parsed, dict) and key in parsed: + val = parsed[key] + return json.dumps(val) if not isinstance(val, str) else val + except (json.JSONDecodeError, TypeError): + pass + + # 3. Colon format: key: value + match = re.search(rf"\b{re.escape(key)}\s*:\s*(.+)", content) + if match: + return match.group(1).strip() + + # 4. Equals format: key = value + match = re.search(rf"\b{re.escape(key)}\s*=\s*(.+)", content) + if match: + return match.group(1).strip() + + return None + + class NodeConversation: """Message history for a graph node with optional write-through persistence. @@ -244,39 +288,7 @@ class NodeConversation: def _try_extract_key(self, content: str, key: str) -> str | None: """Try 4 strategies to extract a key's value from message content.""" - from framework.graph.node import find_json_object - - # 1. Whole message is JSON - try: - parsed = json.loads(content) - if isinstance(parsed, dict) and key in parsed: - val = parsed[key] - return json.dumps(val) if not isinstance(val, str) else val - except (json.JSONDecodeError, TypeError): - pass - - # 2. Embedded JSON via find_json_object - json_str = find_json_object(content) - if json_str: - try: - parsed = json.loads(json_str) - if isinstance(parsed, dict) and key in parsed: - val = parsed[key] - return json.dumps(val) if not isinstance(val, str) else val - except (json.JSONDecodeError, TypeError): - pass - - # 3. Colon format: key: value - match = re.search(rf"\b{re.escape(key)}\s*:\s*(.+)", content) - if match: - return match.group(1).strip() - - # 4. Equals format: key = value - match = re.search(rf"\b{re.escape(key)}\s*=\s*(.+)", content) - if match: - return match.group(1).strip() - - return None + return _try_extract_key(content, key) # --- Lifecycle --------------------------------------------------------- diff --git a/core/tests/test_context_handoff.py b/core/tests/test_context_handoff.py new file mode 100644 index 00000000..4e2df2b1 --- /dev/null +++ b/core/tests/test_context_handoff.py @@ -0,0 +1,326 @@ +"""Tests for ContextHandoff and HandoffContext.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from framework.graph.context_handoff import ContextHandoff, HandoffContext +from framework.graph.conversation import NodeConversation +from framework.llm.mock import MockLLMProvider +from framework.llm.provider import LLMProvider, LLMResponse + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class SpyLLMProvider(MockLLMProvider): + """MockLLMProvider that records whether complete() was called.""" + + def __init__(self) -> None: + super().__init__() + self.complete_called = False + self.complete_call_args: dict[str, Any] | None = None + + def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse: + self.complete_called = True + self.complete_call_args = {"messages": messages, **kwargs} + return super().complete(messages, **kwargs) + + +class FailingLLMProvider(LLMProvider): + """LLM provider that always raises.""" + + def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse: + raise RuntimeError("LLM unavailable") + + def complete_with_tools( + self, + messages: list[dict[str, Any]], + system: str, + tools: list, + tool_executor: Any, + max_iterations: int = 10, + ) -> LLMResponse: + raise RuntimeError("LLM unavailable") + + +async def _build_conversation(*pairs: tuple[str, str]) -> NodeConversation: + """Build a NodeConversation from (user, assistant) message pairs.""" + conv = NodeConversation() + for user_msg, assistant_msg in pairs: + await conv.add_user_message(user_msg) + await conv.add_assistant_message(assistant_msg) + return conv + + +# --------------------------------------------------------------------------- +# TestHandoffContext +# --------------------------------------------------------------------------- + + +class TestHandoffContext: + def test_instantiation(self) -> None: + hc = HandoffContext( + source_node_id="node_A", + summary="Summary text", + key_outputs={"result": "42"}, + turn_count=3, + total_tokens_used=1200, + ) + assert hc.source_node_id == "node_A" + assert hc.summary == "Summary text" + assert hc.key_outputs == {"result": "42"} + assert hc.turn_count == 3 + assert hc.total_tokens_used == 1200 + + def test_field_access(self) -> None: + hc = HandoffContext( + source_node_id="n1", + summary="s", + key_outputs={}, + turn_count=0, + total_tokens_used=0, + ) + assert hc.key_outputs == {} + + +# --------------------------------------------------------------------------- +# TestExtractiveSummary +# --------------------------------------------------------------------------- + + +class TestExtractiveSummary: + @pytest.mark.asyncio + async def test_extractive_summary_includes_first_last(self) -> None: + conv = await _build_conversation( + ("hello", "First response here."), + ("continue", "Middle response."), + ("finish", "Final conclusion."), + ) + ch = ContextHandoff() + hc = ch.summarize_conversation(conv, node_id="test_node") + + assert "First response here." in hc.summary + assert "Final conclusion." in hc.summary + + @pytest.mark.asyncio + async def test_extractive_summary_metadata(self) -> None: + conv = await _build_conversation( + ("hi", "hello"), + ("bye", "goodbye"), + ) + ch = ContextHandoff() + hc = ch.summarize_conversation(conv, node_id="node_42") + + assert hc.source_node_id == "node_42" + assert hc.turn_count == 2 + assert hc.total_tokens_used > 0 + + @pytest.mark.asyncio + async def test_extractive_with_output_keys_colon(self) -> None: + conv = await _build_conversation( + ("what is the answer?", "answer: 42"), + ) + ch = ContextHandoff() + hc = ch.summarize_conversation(conv, node_id="n", output_keys=["answer"]) + + assert hc.key_outputs["answer"] == "42" + + @pytest.mark.asyncio + async def test_extractive_with_output_keys_equals(self) -> None: + conv = await _build_conversation( + ("compute", "result = success"), + ) + ch = ContextHandoff() + hc = ch.summarize_conversation(conv, node_id="n", output_keys=["result"]) + + assert hc.key_outputs["result"] == "success" + + @pytest.mark.asyncio + async def test_extractive_json_output_keys(self) -> None: + conv = await _build_conversation( + ("give me json", '{"score": 95, "grade": "A"}'), + ) + ch = ContextHandoff() + hc = ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"]) + + assert hc.key_outputs["score"] == "95" + assert hc.key_outputs["grade"] == "A" + + @pytest.mark.asyncio + async def test_extractive_empty_conversation(self) -> None: + conv = NodeConversation() + ch = ContextHandoff() + hc = ch.summarize_conversation(conv, node_id="empty") + + assert hc.summary == "Empty conversation." + assert hc.turn_count == 0 + assert hc.key_outputs == {} + + @pytest.mark.asyncio + async def test_extractive_no_assistant_messages(self) -> None: + conv = NodeConversation() + await conv.add_user_message("hello?") + await conv.add_user_message("anyone there?") + + ch = ContextHandoff() + hc = ch.summarize_conversation(conv, node_id="silent") + + assert hc.summary == "No assistant responses." + + @pytest.mark.asyncio + async def test_extractive_most_recent_wins(self) -> None: + conv = await _build_conversation( + ("first", "status: old_value"), + ("second", "status: new_value"), + ) + ch = ContextHandoff() + hc = ch.summarize_conversation(conv, node_id="n", output_keys=["status"]) + + assert hc.key_outputs["status"] == "new_value" + + @pytest.mark.asyncio + async def test_extractive_truncation(self) -> None: + long_text = "x" * 1000 + conv = await _build_conversation( + ("go", long_text), + ) + ch = ContextHandoff() + hc = ch.summarize_conversation(conv, node_id="n") + + # Summary should be truncated to ~500 chars + assert len(hc.summary) <= 500 + + +# --------------------------------------------------------------------------- +# TestLLMSummary +# --------------------------------------------------------------------------- + + +class TestLLMSummary: + @pytest.mark.asyncio + async def test_llm_summary_calls_provider(self) -> None: + llm = SpyLLMProvider() + conv = await _build_conversation( + ("hi", "hello back"), + ("what now?", "we are done"), + ) + ch = ContextHandoff(llm=llm) + hc = ch.summarize_conversation(conv, node_id="llm_node") + + assert llm.complete_called, "LLM complete() was never invoked" + assert hc.summary == "This is a mock response for testing purposes." + + @pytest.mark.asyncio + async def test_llm_summary_includes_output_key_hint(self) -> None: + llm = SpyLLMProvider() + conv = await _build_conversation( + ("compute", '{"score": 95}'), + ) + ch = ContextHandoff(llm=llm) + ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"]) + + assert llm.complete_call_args is not None + system = llm.complete_call_args.get("system", "") + assert "score" in system + assert "grade" in system + + @pytest.mark.asyncio + async def test_llm_fallback_on_error(self) -> None: + llm = FailingLLMProvider() + conv = await _build_conversation( + ("start", "First assistant message."), + ("end", "Last assistant message."), + ) + ch = ContextHandoff(llm=llm) + hc = ch.summarize_conversation(conv, node_id="fallback_node") + + # Should fall back to extractive (first + last assistant messages) + assert "First assistant message." in hc.summary + assert "Last assistant message." in hc.summary + + +# --------------------------------------------------------------------------- +# TestFormatAsInput +# --------------------------------------------------------------------------- + + +class TestFormatAsInput: + def test_format_structure(self) -> None: + hc = HandoffContext( + source_node_id="analyzer", + summary="Analysis complete.", + key_outputs={"score": "95"}, + turn_count=5, + total_tokens_used=2000, + ) + output = ContextHandoff.format_as_input(hc) + + assert "--- CONTEXT FROM: analyzer" in output + assert "KEY OUTPUTS:" in output + assert "SUMMARY:" in output + assert "--- END CONTEXT ---" in output + + def test_format_no_key_outputs(self) -> None: + hc = HandoffContext( + source_node_id="simple", + summary="Done.", + key_outputs={}, + turn_count=1, + total_tokens_used=100, + ) + output = ContextHandoff.format_as_input(hc) + + assert "KEY OUTPUTS:" not in output + assert "SUMMARY:" in output + + def test_format_content_values(self) -> None: + hc = HandoffContext( + source_node_id="node_X", + summary="Found 3 bugs.", + key_outputs={"bugs": "3", "severity": "high"}, + turn_count=7, + total_tokens_used=5000, + ) + output = ContextHandoff.format_as_input(hc) + + assert "node_X" in output + assert "7 turns" in output + assert "~5000 tokens" in output + assert "- bugs: 3" in output + assert "- severity: high" in output + assert "Found 3 bugs." in output + + def test_format_empty_summary(self) -> None: + hc = HandoffContext( + source_node_id="n", + summary="", + key_outputs={}, + turn_count=0, + total_tokens_used=0, + ) + output = ContextHandoff.format_as_input(hc) + + assert "No summary available." in output + + @pytest.mark.asyncio + async def test_format_as_input_usable_as_message(self) -> None: + """Formatted output can be fed into a NodeConversation as a user message.""" + hc = HandoffContext( + source_node_id="prev_node", + summary="Completed analysis.", + key_outputs={"result": "42"}, + turn_count=3, + total_tokens_used=900, + ) + text = ContextHandoff.format_as_input(hc) + + conv = NodeConversation() + msg = await conv.add_user_message(text) + + assert msg.role == "user" + assert "CONTEXT FROM: prev_node" in msg.content + assert conv.turn_count == 1 diff --git a/tools/src/aden_tools/tools/csv_tool/csv_tool.py b/tools/src/aden_tools/tools/csv_tool/csv_tool.py index 0a188b7e..ed88c187 100644 --- a/tools/src/aden_tools/tools/csv_tool/csv_tool.py +++ b/tools/src/aden_tools/tools/csv_tool/csv_tool.py @@ -34,7 +34,7 @@ def register_tools(mcp: FastMCP) -> None: Returns: dict with success status, data, and metadata """ - if (offset < 0 or (limit is not None and limit < 0)): + if offset < 0 or (limit is not None and limit < 0): return {"error": "offset and limit must be non-negative"} try: secure_path = get_secure_path(path, workspace_id, agent_id, session_id)