"""NodeConversation: Message history management for graph nodes."""
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Protocol, runtime_checkable
LEGACY_RUN_ID = "__legacy_run__"
def is_legacy_run_id(run_id: str | None) -> bool:
"""True when run_id represents pre-migration (no run boundary) data."""
return run_id is None or run_id == LEGACY_RUN_ID
@dataclass
class Message:
"""A single message in a conversation.
Attributes:
seq: Monotonic sequence number.
role: One of "user", "assistant", or "tool".
content: Message text.
tool_use_id: Internal tool-use identifier (output as ``tool_call_id`` in LLM dicts).
tool_calls: OpenAI-format tool call list for assistant messages.
is_error: When True and role is "tool", ``to_llm_dict`` prepends "ERROR: " to content.
"""
seq: int
role: Literal["user", "assistant", "tool"]
content: str
tool_use_id: str | None = None
tool_calls: list[dict[str, Any]] | None = None
is_error: bool = False
# Phase-aware compaction metadata (continuous mode)
phase_id: str | None = None
is_transition_marker: bool = False
# True when this message is real human input (from /chat), not a system prompt
is_client_input: bool = False
# Optional image content blocks (e.g. from browser_screenshot)
image_content: list[dict[str, Any]] | None = None
# True when message contains an activated skill body (AS-10: never prune)
is_skill_content: bool = False
# Logical worker run identifier for shared-session persistence
run_id: str | None = None
def to_llm_dict(self) -> dict[str, Any]:
"""Convert to OpenAI-format message dict."""
if self.role == "user":
if self.image_content:
blocks: list[dict[str, Any]] = []
if self.content:
blocks.append({"type": "text", "text": self.content})
blocks.extend(self.image_content)
return {"role": "user", "content": blocks}
return {"role": "user", "content": self.content}
if self.role == "assistant":
d: dict[str, Any] = {"role": "assistant", "content": self.content}
if self.tool_calls:
d["tool_calls"] = self.tool_calls
return d
# role == "tool"
content = f"ERROR: {self.content}" if self.is_error else self.content
if self.image_content:
# Multimodal tool result: text + image content blocks
blocks: list[dict[str, Any]] = [{"type": "text", "text": content}]
blocks.extend(self.image_content)
return {
"role": "tool",
"tool_call_id": self.tool_use_id,
"content": blocks,
}
return {
"role": "tool",
"tool_call_id": self.tool_use_id,
"content": content,
}
def to_storage_dict(self) -> dict[str, Any]:
"""Serialize all fields for persistence. Omits None/default-False fields."""
d: dict[str, Any] = {
"seq": self.seq,
"role": self.role,
"content": self.content,
}
if self.tool_use_id is not None:
d["tool_use_id"] = self.tool_use_id
if self.tool_calls is not None:
d["tool_calls"] = self.tool_calls
if self.is_error:
d["is_error"] = self.is_error
if self.phase_id is not None:
d["phase_id"] = self.phase_id
if self.is_transition_marker:
d["is_transition_marker"] = self.is_transition_marker
        if self.is_client_input:
            d["is_client_input"] = self.is_client_input
        if self.is_skill_content:
            d["is_skill_content"] = self.is_skill_content
        if self.image_content is not None:
            d["image_content"] = self.image_content
        if self.run_id is not None:
            d["run_id"] = self.run_id
return d
@classmethod
def from_storage_dict(cls, data: dict[str, Any]) -> Message:
"""Deserialize from a storage dict."""
return cls(
seq=data["seq"],
role=data["role"],
content=data["content"],
tool_use_id=data.get("tool_use_id"),
tool_calls=data.get("tool_calls"),
is_error=data.get("is_error", False),
phase_id=data.get("phase_id"),
is_transition_marker=data.get("is_transition_marker", False),
            is_client_input=data.get("is_client_input", False),
            is_skill_content=data.get("is_skill_content", False),
            image_content=data.get("image_content"),
            run_id=data.get("run_id"),
)
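# Round-trip sketch (illustrative, not executed): error tool results gain an
# "ERROR: " prefix in the LLM view but persist and restore unchanged.
#
#     msg = Message(seq=0, role="tool", content="timeout", tool_use_id="t1", is_error=True)
#     msg.to_llm_dict()
#     # -> {"role": "tool", "tool_call_id": "t1", "content": "ERROR: timeout"}
#     Message.from_storage_dict(msg.to_storage_dict()) == msg
#     # -> True (dataclass equality; storage omits default-valued fields)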
def _normalize_cursor(cursor: dict[str, Any] | None) -> dict[str, Any]:
"""Normalize legacy and run-scoped cursor formats into one shape."""
if not cursor:
return {}
if isinstance(cursor.get("runs"), dict):
normalized = dict(cursor)
normalized["runs"] = dict(cursor["runs"])
return normalized
normalized: dict[str, Any] = {}
if "next_seq" in cursor:
normalized["next_seq"] = cursor["next_seq"]
legacy_run = {k: v for k, v in cursor.items() if k != "next_seq"}
if legacy_run:
normalized["runs"] = {LEGACY_RUN_ID: legacy_run}
return normalized
def get_cursor_next_seq(cursor: dict[str, Any] | None) -> int | None:
    """Return the global ``next_seq`` from a cursor, or None when unset."""
    normalized = _normalize_cursor(cursor)
    next_seq = normalized.get("next_seq")
    return next_seq if isinstance(next_seq, int) else None
def update_cursor_next_seq(cursor: dict[str, Any] | None, next_seq: int) -> dict[str, Any]:
    """Return a normalized copy of *cursor* with the global ``next_seq`` set."""
    updated = _normalize_cursor(cursor)
    updated["next_seq"] = next_seq
    return updated
def get_run_cursor(cursor: dict[str, Any] | None, run_id: str | None) -> dict[str, Any] | None:
    """Return the cursor entry for *run_id*, or the whole cursor when run_id is None."""
    if run_id is None:
        return dict(cursor) if cursor else None
    normalized = _normalize_cursor(cursor)
    runs = normalized.get("runs", {})
    value = runs.get(run_id)
    return dict(value) if isinstance(value, dict) else None
def update_run_cursor(
    cursor: dict[str, Any] | None,
    run_id: str | None,
    values: dict[str, Any],
) -> dict[str, Any]:
    """Merge *values* into the cursor entry for *run_id* and return the result."""
    if run_id is None:
        updated = dict(cursor or {})
        updated.update(values)
        return updated
normalized = _normalize_cursor(cursor)
runs = dict(normalized.get("runs", {}))
existing = dict(runs.get(run_id, {}))
existing.update(values)
runs[run_id] = existing
normalized["runs"] = runs
return normalized
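# Cursor shape sketch (illustrative): a legacy flat cursor is wrapped in a
# "runs" envelope keyed by LEGACY_RUN_ID, while "next_seq" stays global.
#
#     _normalize_cursor({"next_seq": 7, "phase": "research"})
#     # -> {"next_seq": 7, "runs": {"__legacy_run__": {"phase": "research"}}}
#     get_run_cursor({"next_seq": 7, "runs": {"r1": {"phase": "a"}}}, "r1")
#     # -> {"phase": "a"}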
def _extract_spillover_filename(content: str) -> str | None:
"""Extract spillover filename from a tool result annotation.
Matches patterns produced by EventLoopNode._truncate_tool_result():
- Large result: "saved to 'web_search_1.txt'"
- Small result: "[Saved to 'web_search_1.txt']"
"""
match = re.search(r"[Ss]aved to '([^']+)'", content)
return match.group(1) if match else None
_TC_ARG_LIMIT = 200 # max chars per tool_call argument after compaction
def _compact_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Truncate tool_call arguments to save context tokens during compaction.
Preserves ``id``, ``type``, and ``function.name`` exactly. When arguments
exceed ``_TC_ARG_LIMIT``, replaces the full JSON string with a compact
**valid** JSON summary. The Anthropic API parses tool_call arguments and
rejects requests with malformed JSON (e.g. unterminated strings), so we
must never produce broken JSON here.
"""
    compact: list[dict[str, Any]] = []
for tc in tool_calls:
func = tc.get("function", {})
args = func.get("arguments", "")
if len(args) > _TC_ARG_LIMIT:
# Build a valid JSON summary instead of slicing mid-string.
# Try to extract top-level keys for a meaningful preview.
try:
parsed = json.loads(args)
if isinstance(parsed, dict):
# Preserve key names, truncate values
summary_parts = []
for k, v in parsed.items():
v_str = str(v)
if len(v_str) > 60:
v_str = v_str[:60] + "..."
summary_parts.append(f"{k}={v_str}")
summary = ", ".join(summary_parts)
if len(summary) > _TC_ARG_LIMIT:
summary = summary[:_TC_ARG_LIMIT] + "..."
args = json.dumps({"_compacted": summary})
else:
args = json.dumps({"_compacted": str(parsed)[:_TC_ARG_LIMIT]})
except (json.JSONDecodeError, TypeError):
# Args were already invalid JSON — wrap the preview safely
args = json.dumps({"_compacted": args[:_TC_ARG_LIMIT]})
compact.append(
{
"id": tc.get("id", ""),
"type": tc.get("type", "function"),
"function": {
"name": func.get("name", ""),
"arguments": args,
},
}
)
return compact
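# Argument-compaction sketch (illustrative values): an oversized "arguments"
# JSON string becomes a short preview that is still valid JSON.
#
#     tc = {"id": "t1", "type": "function", "function": {
#         "name": "web_scrape", "arguments": json.dumps({"url": "x" * 500})}}
#     _compact_tool_calls([tc])[0]["function"]["arguments"]
#     # -> '{"_compacted": "url=xxx..."}'  (value truncated to ~60 chars, parseable)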
def extract_tool_call_history(messages: list[Message], max_entries: int = 30) -> str:
"""Build a compact tool call history from a list of messages.
Used in compaction summaries to prevent the LLM from re-calling
tools it already called. Extracts tool call details, files saved,
outputs set, and errors encountered.
"""
tool_calls_detail: dict[str, list[str]] = {}
files_saved: list[str] = []
outputs_set: list[str] = []
errors: list[str] = []
def _summarize_input(name: str, args: dict) -> str:
if name == "web_search":
return args.get("query", "")
if name == "web_scrape":
return args.get("url", "")
if name in ("load_data", "save_data"):
return args.get("filename", "")
return ""
for msg in messages:
if msg.role == "assistant" and msg.tool_calls:
for tc in msg.tool_calls:
func = tc.get("function", {})
name = func.get("name", "unknown")
try:
args = json.loads(func.get("arguments", "{}"))
except (json.JSONDecodeError, TypeError):
args = {}
summary = _summarize_input(name, args)
tool_calls_detail.setdefault(name, []).append(summary)
if name == "save_data" and args.get("filename"):
files_saved.append(args["filename"])
if name == "set_output" and args.get("key"):
outputs_set.append(args["key"])
if msg.role == "tool" and msg.is_error:
preview = msg.content[:120].replace("\n", " ")
errors.append(preview)
parts: list[str] = []
if tool_calls_detail:
lines: list[str] = []
for name, inputs in list(tool_calls_detail.items())[:max_entries]:
count = len(inputs)
non_empty = [s for s in inputs if s]
if non_empty:
detail_lines = [f" - {s[:120]}" for s in non_empty[:8]]
lines.append(f" {name} ({count}x):\n" + "\n".join(detail_lines))
else:
lines.append(f" {name} ({count}x)")
parts.append("TOOLS ALREADY CALLED:\n" + "\n".join(lines))
if files_saved:
unique = list(dict.fromkeys(files_saved))
parts.append("FILES SAVED: " + ", ".join(unique))
if outputs_set:
unique = list(dict.fromkeys(outputs_set))
parts.append("OUTPUTS SET: " + ", ".join(unique))
if errors:
parts.append("ERRORS (do NOT retry these):\n" + "\n".join(f" - {e}" for e in errors[:10]))
return "\n\n".join(parts)
# ---------------------------------------------------------------------------
# ConversationStore protocol (Phase 2)
# ---------------------------------------------------------------------------
@runtime_checkable
class ConversationStore(Protocol):
"""Protocol for conversation persistence backends."""
async def write_part(self, seq: int, data: dict[str, Any]) -> None: ...
async def read_parts(self) -> list[dict[str, Any]]: ...
async def write_meta(self, data: dict[str, Any]) -> None: ...
async def read_meta(self) -> dict[str, Any] | None: ...
async def write_cursor(self, data: dict[str, Any]) -> None: ...
async def read_cursor(self) -> dict[str, Any] | None: ...
async def delete_parts_before(self, seq: int, run_id: str | None = None) -> None: ...
async def close(self) -> None: ...
async def destroy(self) -> None: ...
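# Minimal in-memory backend satisfying this protocol (hypothetical test-only
# sketch; the run_id filter of delete_parts_before is ignored here):
#
#     class DictStore:
#         def __init__(self) -> None:
#             self.parts: dict[int, dict[str, Any]] = {}
#             self.meta: dict[str, Any] | None = None
#             self.cursor: dict[str, Any] | None = None
#         async def write_part(self, seq, data): self.parts[seq] = data
#         async def read_parts(self): return [self.parts[k] for k in sorted(self.parts)]
#         async def write_meta(self, data): self.meta = data
#         async def read_meta(self): return self.meta
#         async def write_cursor(self, data): self.cursor = data
#         async def read_cursor(self): return self.cursor
#         async def delete_parts_before(self, seq, run_id=None):
#             self.parts = {k: v for k, v in self.parts.items() if k >= seq}
#         async def close(self): pass
#         async def destroy(self): self.parts.clear()
#
#     isinstance(DictStore(), ConversationStore)  # -> True (runtime_checkable)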
# ---------------------------------------------------------------------------
# NodeConversation
# ---------------------------------------------------------------------------
def _try_extract_key(content: str, key: str) -> str | None:
"""Try 4 strategies to extract a *key*'s value from message content.
Strategies (in order):
1. Whole message is JSON — ``json.loads``, check for key.
2. Embedded JSON via ``find_json_object`` helper.
3. Colon format: ``key: value``.
4. Equals format: ``key = value``.
"""
from framework.graph.node import find_json_object
# 1. Whole message is JSON
try:
parsed = json.loads(content)
if isinstance(parsed, dict) and key in parsed:
val = parsed[key]
return json.dumps(val) if not isinstance(val, str) else val
except (json.JSONDecodeError, TypeError):
pass
# 2. Embedded JSON via find_json_object
json_str = find_json_object(content)
if json_str:
try:
parsed = json.loads(json_str)
if isinstance(parsed, dict) and key in parsed:
val = parsed[key]
return json.dumps(val) if not isinstance(val, str) else val
except (json.JSONDecodeError, TypeError):
pass
# 3. Colon format: key: value
match = re.search(rf"\b{re.escape(key)}\s*:\s*(.+)", content)
if match:
return match.group(1).strip()
# 4. Equals format: key = value
match = re.search(rf"\b{re.escape(key)}\s*=\s*(.+)", content)
if match:
return match.group(1).strip()
return None
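# Extraction sketch (illustrative; assumes find_json_object finds no JSON in
# the plain-text case):
#
#     _try_extract_key('{"report_url": "https://example.test/r1"}', "report_url")
#     # -> "https://example.test/r1"        (strategy 1: whole message is JSON)
#     _try_extract_key("Done. report_url: https://example.test/r1", "report_url")
#     # -> "https://example.test/r1"        (strategy 3: colon format)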
class NodeConversation:
"""Message history for a graph node with optional write-through persistence.
When *store* is ``None`` the conversation works purely in-memory.
When a :class:`ConversationStore` is supplied every mutation is
persisted via write-through (meta is lazily written on the first
``_persist`` call).
"""
def __init__(
self,
system_prompt: str = "",
max_context_tokens: int = 32000,
compaction_threshold: float = 0.8,
output_keys: list[str] | None = None,
store: ConversationStore | None = None,
run_id: str | None = None,
) -> None:
self._system_prompt = system_prompt
self._max_context_tokens = max_context_tokens
self._compaction_threshold = compaction_threshold
self._output_keys = output_keys
self._store = store
self._messages: list[Message] = []
self._next_seq: int = 0
self._meta_persisted: bool = False
self._last_api_input_tokens: int | None = None
self._current_phase: str | None = None
self._run_id: str | None = run_id
# --- Properties --------------------------------------------------------
@property
def system_prompt(self) -> str:
return self._system_prompt
def update_system_prompt(self, new_prompt: str) -> None:
"""Update the system prompt.
Used in continuous conversation mode at phase transitions to swap
Layer 3 (focus) while preserving the conversation history.
"""
self._system_prompt = new_prompt
self._meta_persisted = False # re-persist with new prompt
def set_current_phase(self, phase_id: str) -> None:
"""Set the current phase ID. Subsequent messages will be stamped with it."""
self._current_phase = phase_id
@property
def current_phase(self) -> str | None:
return self._current_phase
@property
def messages(self) -> list[Message]:
"""Return a defensive copy of the message list."""
return list(self._messages)
@property
def turn_count(self) -> int:
"""Number of conversational turns (one turn = one user message)."""
return sum(1 for m in self._messages if m.role == "user")
@property
def message_count(self) -> int:
"""Total number of messages (all roles)."""
return len(self._messages)
@property
def next_seq(self) -> int:
return self._next_seq
# --- Add messages ------------------------------------------------------
async def add_user_message(
self,
content: str,
*,
is_transition_marker: bool = False,
is_client_input: bool = False,
image_content: list[dict[str, Any]] | None = None,
) -> Message:
msg = Message(
seq=self._next_seq,
role="user",
content=content,
phase_id=self._current_phase,
run_id=self._run_id,
is_transition_marker=is_transition_marker,
is_client_input=is_client_input,
image_content=image_content,
)
self._messages.append(msg)
self._next_seq += 1
await self._persist(msg)
return msg
async def add_assistant_message(
self,
content: str,
tool_calls: list[dict[str, Any]] | None = None,
) -> Message:
msg = Message(
seq=self._next_seq,
role="assistant",
content=content,
tool_calls=tool_calls,
phase_id=self._current_phase,
run_id=self._run_id,
)
self._messages.append(msg)
self._next_seq += 1
await self._persist(msg)
return msg
async def add_tool_result(
self,
tool_use_id: str,
content: str,
is_error: bool = False,
image_content: list[dict[str, Any]] | None = None,
is_skill_content: bool = False,
) -> Message:
msg = Message(
seq=self._next_seq,
role="tool",
content=content,
tool_use_id=tool_use_id,
is_error=is_error,
phase_id=self._current_phase,
image_content=image_content,
is_skill_content=is_skill_content,
run_id=self._run_id,
)
self._messages.append(msg)
self._next_seq += 1
await self._persist(msg)
return msg
# --- Query -------------------------------------------------------------
def to_llm_messages(self) -> list[dict[str, Any]]:
"""Return messages as OpenAI-format dicts (system prompt excluded).
Automatically repairs orphaned tool_use blocks (assistant messages
with tool_calls that lack corresponding tool-result messages). This
can happen when a loop is cancelled mid-tool-execution.
"""
msgs = [m.to_llm_dict() for m in self._messages]
return self._repair_orphaned_tool_calls(msgs)
@staticmethod
def _repair_orphaned_tool_calls(
msgs: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Ensure tool_call / tool_result pairs are consistent.
1. **Orphaned tool results** (tool_result with no preceding tool_use)
are dropped. This happens when compaction removes an assistant
message but leaves its tool-result messages behind.
2. **Orphaned tool calls** (tool_use with no following tool_result)
get a synthetic error result appended. This happens when a loop
is cancelled mid-tool-execution.
"""
# Pass 1: collect all tool_call IDs from assistant messages so we
# can identify orphaned tool-result messages.
all_tool_call_ids: set[str] = set()
for m in msgs:
if m.get("role") == "assistant":
for tc in m.get("tool_calls") or []:
tc_id = tc.get("id")
if tc_id:
all_tool_call_ids.add(tc_id)
# Pass 2: build repaired list — drop orphaned tool results, patch
# missing tool results.
repaired: list[dict[str, Any]] = []
for i, m in enumerate(msgs):
# Drop tool-result messages whose tool_call_id has no matching
# tool_use in any assistant message (orphaned by compaction).
if m.get("role") == "tool":
tid = m.get("tool_call_id")
if tid and tid not in all_tool_call_ids:
continue # skip orphaned result
repaired.append(m)
tool_calls = m.get("tool_calls")
if m.get("role") != "assistant" or not tool_calls:
continue
# Collect IDs of tool results that follow this assistant message
answered: set[str] = set()
for j in range(i + 1, len(msgs)):
if msgs[j].get("role") == "tool":
tid = msgs[j].get("tool_call_id")
if tid:
answered.add(tid)
else:
break # stop at first non-tool message
# Patch any missing results
for tc in tool_calls:
tc_id = tc.get("id")
if tc_id and tc_id not in answered:
repaired.append(
{
"role": "tool",
"tool_call_id": tc_id,
"content": "ERROR: Tool execution was interrupted.",
}
)
return repaired
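    # Repair sketch (illustrative): an interrupted tool call receives a
    # synthetic error result so call/result pairs stay consistent.
    #
    #     msgs = [{"role": "assistant", "content": "",
    #              "tool_calls": [{"id": "t1", "type": "function",
    #                              "function": {"name": "web_search", "arguments": "{}"}}]}]
    #     NodeConversation._repair_orphaned_tool_calls(msgs)[-1]
    #     # -> {"role": "tool", "tool_call_id": "t1",
    #     #     "content": "ERROR: Tool execution was interrupted."}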
def estimate_tokens(self) -> int:
"""Best available token estimate.
Uses actual API input token count when available (set via
:meth:`update_token_count`), otherwise falls back to a
``total_chars / 4`` heuristic that includes both message content
AND tool_call argument sizes.
"""
if self._last_api_input_tokens is not None:
return self._last_api_input_tokens
total_chars = 0
for m in self._messages:
total_chars += len(m.content)
if m.tool_calls:
for tc in m.tool_calls:
func = tc.get("function", {})
total_chars += len(func.get("arguments", ""))
total_chars += len(func.get("name", ""))
return total_chars // 4
def update_token_count(self, actual_input_tokens: int) -> None:
"""Store actual API input token count for more accurate compaction.
Called by EventLoopNode after each LLM call with the ``input_tokens``
value from the API response. This value includes system prompt and
tool definitions, so it may be higher than a message-only estimate.
"""
self._last_api_input_tokens = actual_input_tokens
def usage_ratio(self) -> float:
"""Current token usage as a fraction of *max_context_tokens*.
Returns 0.0 when ``max_context_tokens`` is zero (unlimited).
"""
if self._max_context_tokens <= 0:
return 0.0
return self.estimate_tokens() / self._max_context_tokens
    def needs_compaction(self) -> bool:
        """True when the token estimate crosses the compaction threshold.

        Always False when ``max_context_tokens`` is zero or negative
        (unlimited context, matching :meth:`usage_ratio`).
        """
        if self._max_context_tokens <= 0:
            return False
        return self.estimate_tokens() >= self._max_context_tokens * self._compaction_threshold
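    # Threshold arithmetic with the defaults: max_context_tokens=32000 and
    # compaction_threshold=0.8 trigger compaction once the estimate reaches
    # 32000 * 0.8 = 25600 tokens (~102,400 chars under the chars/4 fallback
    # heuristic when no API token count has been recorded).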
# --- Output-key extraction ---------------------------------------------
def _extract_protected_values(self, messages: list[Message]) -> dict[str, str]:
"""Scan assistant messages for output_key values before compaction.
Iterates most-recent-first. Once a key is found, it's skipped for
older messages (latest value wins).
"""
if not self._output_keys:
return {}
found: dict[str, str] = {}
remaining_keys = set(self._output_keys)
for msg in reversed(messages):
if msg.role != "assistant" or not remaining_keys:
continue
for key in list(remaining_keys):
value = self._try_extract_key(msg.content, key)
if value is not None:
found[key] = value
remaining_keys.discard(key)
return found
def _try_extract_key(self, content: str, key: str) -> str | None:
"""Try 4 strategies to extract a key's value from message content."""
return _try_extract_key(content, key)
# --- Lifecycle ---------------------------------------------------------
async def prune_old_tool_results(
self,
protect_tokens: int = 5000,
min_prune_tokens: int = 2000,
) -> int:
"""Replace old tool result content with compact placeholders.
Walks backward through messages. Recent tool results (within
*protect_tokens*) are kept intact. Older tool results have their
content replaced with a ~100-char placeholder that preserves the
spillover filename reference (if any). Message structure (role,
seq, tool_use_id) stays valid for the LLM API.
Phase-aware behavior (continuous mode): when messages have ``phase_id``
metadata, all messages in the current phase are protected regardless of
token budget. Transition markers are never pruned. Older phases' tool
results are pruned more aggressively.
Error tool results are never pruned — they prevent re-calling
failing tools.
Returns the number of messages pruned (0 if nothing was pruned).
"""
if not self._messages:
return 0
# Walk backward, classify tool results as protected vs pruneable
protected_tokens = 0
pruneable: list[int] = [] # indices into self._messages
pruneable_tokens = 0
for i in range(len(self._messages) - 1, -1, -1):
msg = self._messages[i]
# Transition markers are never pruned (any role)
if msg.is_transition_marker:
continue
if msg.role != "tool":
continue
if msg.is_error:
continue # never prune errors
if msg.is_skill_content:
continue # never prune activated skill instructions (AS-10)
if msg.content.startswith("[Pruned tool result"):
continue # already pruned
# Tiny results (set_output acks, confirmations) — pruning
# saves negligible space but makes the LLM think the call
# failed, causing costly retries.
if len(msg.content) < 100:
continue
# Phase-aware: protect current phase messages
if self._current_phase and msg.phase_id == self._current_phase:
continue
est = len(msg.content) // 4
if protected_tokens < protect_tokens:
protected_tokens += est
else:
pruneable.append(i)
pruneable_tokens += est
# Only prune if enough to be worthwhile
if pruneable_tokens < min_prune_tokens:
return 0
# Replace content with compact placeholder
count = 0
for i in pruneable:
msg = self._messages[i]
orig_len = len(msg.content)
spillover = _extract_spillover_filename(msg.content)
if spillover:
placeholder = (
f"[Pruned tool result: {orig_len} chars. "
f"Full data in '{spillover}'. "
f"Use load_data('{spillover}') to retrieve.]"
)
else:
placeholder = f"[Pruned tool result: {orig_len} chars cleared from context.]"
            # Rebuild without image_content so pruning also clears any
            # attached images from context.
            self._messages[i] = Message(
seq=msg.seq,
role=msg.role,
content=placeholder,
tool_use_id=msg.tool_use_id,
tool_calls=msg.tool_calls,
is_error=msg.is_error,
phase_id=msg.phase_id,
is_transition_marker=msg.is_transition_marker,
run_id=msg.run_id,
)
count += 1
if self._store:
await self._store.write_part(msg.seq, self._messages[i].to_storage_dict())
# Reset token estimate — content lengths changed
self._last_api_input_tokens = None
return count
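    # Pruning sketch (illustrative): an old 12,000-char web_search result whose
    # truncation note says "saved to 'web_search_1.txt'" is replaced by:
    #
    #     [Pruned tool result: 12000 chars. Full data in 'web_search_1.txt'.
    #     Use load_data('web_search_1.txt') to retrieve.]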
async def compact(
self,
summary: str,
keep_recent: int = 2,
phase_graduated: bool = False,
) -> None:
"""Replace old messages with a summary, optionally keeping recent ones.
Args:
summary: Caller-provided summary text.
keep_recent: Number of recent messages to preserve (default 2).
Clamped to [0, len(messages) - 1].
phase_graduated: When True and messages have phase_id metadata,
split at phase boundaries instead of using keep_recent.
Keeps current + previous phase intact; compacts older phases.
"""
if not self._messages:
return
total = len(self._messages)
# Phase-graduated: find the split point based on phase boundaries.
# Keeps current phase + previous phase intact, compacts older phases.
if phase_graduated and self._current_phase:
split = self._find_phase_graduated_split()
else:
split = None
if split is None:
# Fallback: use keep_recent (non-phase or single-phase conversation)
keep_recent = max(0, min(keep_recent, total - 1))
split = total - keep_recent if keep_recent > 0 else total
# Advance split past orphaned tool results at the boundary.
# Tool-role messages reference a tool_use from the preceding
# assistant message; if that assistant message falls into the
# compacted (old) portion the tool_result becomes invalid.
while split < total and self._messages[split].role == "tool":
split += 1
# Nothing to compact
if split == 0:
return
old_messages = list(self._messages[:split])
recent_messages = list(self._messages[split:])
# Extract protected values from messages being discarded
if self._output_keys:
protected = self._extract_protected_values(old_messages)
if protected:
lines = ["PRESERVED VALUES (do not lose these):"]
for k, v in protected.items():
lines.append(f"- {k}: {v}")
lines.append("")
lines.append("CONVERSATION SUMMARY:")
lines.append(summary)
summary = "\n".join(lines)
# Determine summary seq
if recent_messages:
summary_seq = recent_messages[0].seq - 1
else:
summary_seq = self._next_seq
self._next_seq += 1
summary_msg = Message(seq=summary_seq, role="user", content=summary, run_id=self._run_id)
# Persist
if self._store:
delete_before = recent_messages[0].seq if recent_messages else self._next_seq
await self._store.delete_parts_before(delete_before, run_id=self._run_id)
await self._store.write_part(summary_msg.seq, summary_msg.to_storage_dict())
await self._write_next_seq()
self._messages = [summary_msg] + recent_messages
self._last_api_input_tokens = None # reset; next LLM call will recalibrate
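    # Boundary sketch (illustrative): with 10 messages and keep_recent=2 the
    # split starts at index 8; if messages[8] is a tool result the split
    # advances past it so the result is not separated from its assistant
    # tool_use. The summary then lands at seq = recent_messages[0].seq - 1.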
async def compact_preserving_structure(
self,
spillover_dir: str,
keep_recent: int = 4,
phase_graduated: bool = False,
aggressive: bool = False,
) -> None:
"""Structure-preserving compaction: save freeform text to file, keep tool messages.
Unlike ``compact()`` which replaces ALL old messages with a single LLM
summary, this method preserves the tool call structure (assistant
messages with tool_calls + tool result messages) that are already tiny
after pruning. Only freeform text exchanges (user messages,
text-only assistant messages) are saved to a file and removed.
When *aggressive* is True, non-essential tool call pairs are also
collapsed into a compact summary instead of being kept individually.
Only ``set_output`` calls and error results are preserved; all other
old tool pairs are replaced by a tool-call history summary.
The result: the agent retains exact knowledge of what tools it called,
where each result is stored, and can load the conversation text if
needed. No LLM summary call. No heuristics. Nothing lost.
"""
if not self._messages:
return
total = len(self._messages)
# Determine split point (same logic as compact)
if phase_graduated and self._current_phase:
split = self._find_phase_graduated_split()
else:
split = None
if split is None:
keep_recent = max(0, min(keep_recent, total - 1))
split = total - keep_recent if keep_recent > 0 else total
# Advance split past orphaned tool results at the boundary
while split < total and self._messages[split].role == "tool":
split += 1
if split == 0:
return
old_messages = self._messages[:split]
# Classify old messages: structural (keep) vs freeform (save to file)
kept_structural: list[Message] = []
freeform_lines: list[str] = []
collapsed_msgs: list[Message] = []
if aggressive:
# Aggressive: only keep set_output tool pairs and error results.
# Everything else is collapsed into a tool-call history summary.
# We need to track tool_call IDs to pair assistant messages with
# their tool results.
protected_tc_ids: set[str] = set()
collapsible_tc_ids: set[str] = set()
# First pass: classify assistant messages
for msg in old_messages:
if msg.role != "assistant" or not msg.tool_calls:
continue
has_protected = any(
tc.get("function", {}).get("name") == "set_output" for tc in msg.tool_calls
)
tc_ids = {tc.get("id", "") for tc in msg.tool_calls}
if has_protected:
protected_tc_ids |= tc_ids
else:
collapsible_tc_ids |= tc_ids
# Second pass: classify all messages
for msg in old_messages:
if msg.role == "tool":
tc_id = msg.tool_use_id or ""
if tc_id in protected_tc_ids:
kept_structural.append(msg)
elif msg.is_error:
# Error results are always protected
kept_structural.append(msg)
# Protect the parent assistant message too
protected_tc_ids.add(tc_id)
else:
collapsed_msgs.append(msg)
elif msg.role == "assistant" and msg.tool_calls:
tc_ids = {tc.get("id", "") for tc in msg.tool_calls}
if tc_ids & protected_tc_ids:
# Has at least one protected tool call — keep entire msg
compact_tcs = _compact_tool_calls(msg.tool_calls)
kept_structural.append(
Message(
seq=msg.seq,
role=msg.role,
content="",
tool_calls=compact_tcs,
is_error=msg.is_error,
phase_id=msg.phase_id,
is_transition_marker=msg.is_transition_marker,
run_id=msg.run_id,
)
)
else:
collapsed_msgs.append(msg)
else:
# Freeform text — save to file
role_label = msg.role
text = msg.content
if len(text) > 2000:
                        text = text[:2000] + "..."
freeform_lines.append(f"[{role_label}] (seq={msg.seq}): {text}")
else:
# Standard mode: keep all tool call pairs as structural
for msg in old_messages:
if msg.role == "tool":
kept_structural.append(msg)
elif msg.role == "assistant" and msg.tool_calls:
compact_tcs = _compact_tool_calls(msg.tool_calls)
kept_structural.append(
Message(
seq=msg.seq,
role=msg.role,
content="",
tool_calls=compact_tcs,
is_error=msg.is_error,
phase_id=msg.phase_id,
is_transition_marker=msg.is_transition_marker,
run_id=msg.run_id,
)
)
else:
role_label = msg.role
text = msg.content
if len(text) > 2000:
                        text = text[:2000] + "..."
freeform_lines.append(f"[{role_label}] (seq={msg.seq}): {text}")
# Write freeform text to a numbered conversation file
spill_path = Path(spillover_dir)
spill_path.mkdir(parents=True, exist_ok=True)
# Find next conversation file number
existing = sorted(spill_path.glob("conversation_*.md"))
next_n = len(existing) + 1
conv_filename = f"conversation_{next_n}.md"
if freeform_lines:
header = f"## Compacted conversation (messages 1-{split})\n\n"
conv_text = header + "\n\n".join(freeform_lines)
(spill_path / conv_filename).write_text(conv_text, encoding="utf-8")
else:
# Nothing to save — skip file creation
conv_filename = ""
# Build reference message
ref_parts: list[str] = []
if conv_filename:
full_path = str((spill_path / conv_filename).resolve())
ref_parts.append(
f"[Previous conversation saved to '{full_path}'. "
f"Use load_data('{conv_filename}') to review if needed.]"
)
elif not collapsed_msgs:
ref_parts.append("[Previous freeform messages compacted.]")
# Aggressive: add collapsed tool-call history to the reference
if collapsed_msgs:
tool_history = extract_tool_call_history(collapsed_msgs)
if tool_history:
ref_parts.append(tool_history)
elif not ref_parts:
ref_parts.append("[Previous tool calls compacted.]")
ref_content = "\n\n".join(ref_parts)
# Use a seq just before the first kept message
recent_messages = list(self._messages[split:])
if kept_structural:
ref_seq = kept_structural[0].seq - 1
elif recent_messages:
ref_seq = recent_messages[0].seq - 1
else:
ref_seq = self._next_seq
self._next_seq += 1
ref_msg = Message(seq=ref_seq, role="user", content=ref_content, run_id=self._run_id)
# Persist: delete old messages from store, write reference + kept structural.
# In aggressive mode, collapsed messages may be interspersed with kept
# messages, so we delete everything before the recent boundary and
# rewrite only what we want to keep.
if self._store:
recent_boundary = recent_messages[0].seq if recent_messages else self._next_seq
await self._store.delete_parts_before(recent_boundary, run_id=self._run_id)
# Write the reference message
await self._store.write_part(ref_msg.seq, ref_msg.to_storage_dict())
# Write kept structural messages (they may have been modified)
for msg in kept_structural:
await self._store.write_part(msg.seq, msg.to_storage_dict())
await self._write_next_seq()
# Reassemble: reference + kept structural (in original order) + recent
self._messages = [ref_msg] + kept_structural + recent_messages
self._last_api_input_tokens = None
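    # Reference-message sketch (illustrative): after standard-mode compaction
    # the history starts with a user message roughly like
    #
    #     [Previous conversation saved to '/.../spillover/conversation_1.md'.
    #     Use load_data('conversation_1.md') to review if needed.]
    #
    # followed by the kept tool-call pairs (arguments compacted via
    # _compact_tool_calls) and then the untouched recent messages.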
def _find_phase_graduated_split(self) -> int | None:
"""Find split point that preserves current + previous phase.
Returns the index of the first message in the protected set,
or None if phase graduation doesn't apply (< 3 phases).
"""
# Collect distinct phases in order of first appearance
phases_seen: list[str] = []
for msg in self._messages:
if msg.phase_id and msg.phase_id not in phases_seen:
phases_seen.append(msg.phase_id)
# Need at least 3 phases for graduation to be meaningful
# (current + previous are protected, older get compacted)
if len(phases_seen) < 3:
return None
# Protect: current phase + previous phase
protected_phases = {phases_seen[-1], phases_seen[-2]}
# Find split: first message belonging to a protected phase
for i, msg in enumerate(self._messages):
if msg.phase_id in protected_phases:
return i
return None
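    # Split sketch (illustrative): with message phases [p1, p1, p2, p3] and
    # current phase p3, protected phases are {p2, p3}; the split is index 2
    # (first p2 message), so only the p1 prefix gets compacted. With fewer
    # than three distinct phases the method returns None and callers fall
    # back to keep_recent.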
async def clear(self) -> None:
"""Remove all messages, keep system prompt, preserve ``_next_seq``."""
if self._store:
await self._store.delete_parts_before(self._next_seq, run_id=self._run_id)
await self._write_next_seq()
self._messages.clear()
self._last_api_input_tokens = None
def export_summary(self) -> str:
"""Structured summary with [STATS], [CONFIG], [RECENT_MESSAGES] sections."""
prompt_preview = (
self._system_prompt[:80] + "..."
if len(self._system_prompt) > 80
else self._system_prompt
)
lines = [
"[STATS]",
f"turns: {self.turn_count}",
f"messages: {self.message_count}",
f"estimated_tokens: {self.estimate_tokens()}",
"",
"[CONFIG]",
f"system_prompt: {prompt_preview!r}",
]
if self._output_keys:
lines.append(f"output_keys: {', '.join(self._output_keys)}")
lines.append("")
lines.append("[RECENT_MESSAGES]")
for m in self._messages[-5:]:
preview = m.content[:60] + "..." if len(m.content) > 60 else m.content
lines.append(f" [{m.role}] {preview}")
return "\n".join(lines)
# --- Persistence internals ---------------------------------------------
async def _persist(self, message: Message) -> None:
"""Write-through a single message. No-op when store is None."""
if self._store is None:
return
if not self._meta_persisted:
await self._persist_meta()
await self._store.write_part(message.seq, message.to_storage_dict())
await self._write_next_seq()
async def _persist_meta(self) -> None:
"""Lazily write conversation metadata to the store (called once).
When ``self._run_id`` is set, metadata is keyed under
``meta["runs"][run_id]`` so multiple runs can coexist in the same
session. Legacy (no run_id) sessions write flat for backward compat.
"""
if self._store is None:
return
run_meta = {
"system_prompt": self._system_prompt,
"max_context_tokens": self._max_context_tokens,
"compaction_threshold": self._compaction_threshold,
"output_keys": self._output_keys,
}
if self._run_id:
existing = await self._store.read_meta() or {}
runs = dict(existing.get("runs", {}))
runs[self._run_id] = run_meta
existing["runs"] = runs
await self._store.write_meta(existing)
else:
await self._store.write_meta(run_meta)
self._meta_persisted = True
async def _write_next_seq(self) -> None:
if self._store is None:
return
cursor = await self._store.read_cursor()
await self._store.write_cursor(update_cursor_next_seq(cursor, self._next_seq))
# --- Restore -----------------------------------------------------------
@classmethod
async def restore(
cls,
store: ConversationStore,
phase_id: str | None = None,
run_id: str | None = None,
) -> NodeConversation | None:
"""Reconstruct a NodeConversation from a store.
Args:
store: The conversation store to read from.
phase_id: If set, only load parts matching this phase_id.
Used in isolated mode so a node only sees its own
messages in the shared flat store. In continuous mode
pass ``None`` to load all parts.
Returns ``None`` if the store contains no metadata (i.e. the
conversation was never persisted).
"""
meta = await store.read_meta()
if meta is None:
return None
# Extract run-scoped metadata when available
if run_id and isinstance(meta.get("runs"), dict):
run_meta = meta["runs"].get(run_id)
if run_meta is not None:
meta = run_meta
conv = cls(
system_prompt=meta.get("system_prompt", ""),
max_context_tokens=meta.get("max_context_tokens", 32000),
compaction_threshold=meta.get("compaction_threshold", 0.8),
output_keys=meta.get("output_keys"),
store=store,
run_id=run_id,
)
conv._meta_persisted = True
parts = await store.read_parts()
if run_id is not None:
if is_legacy_run_id(run_id):
parts = [p for p in parts if is_legacy_run_id(p.get("run_id"))]
else:
parts = [p for p in parts if p.get("run_id") == run_id]
if phase_id:
parts = [p for p in parts if p.get("phase_id") == phase_id]
conv._messages = [Message.from_storage_dict(p) for p in parts]
cursor = await store.read_cursor()
next_seq = get_cursor_next_seq(cursor)
if next_seq is not None:
conv._next_seq = next_seq
elif conv._messages:
conv._next_seq = conv._messages[-1].seq + 1
return conv
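# End-to-end persistence sketch (illustrative; DictStore is the test-only
# backend sketched near the ConversationStore protocol above):
#
#     store = DictStore()
#     conv = NodeConversation(system_prompt="You are a worker.", store=store, run_id="r1")
#     await conv.add_user_message("hello", is_client_input=True)
#     restored = await NodeConversation.restore(store, run_id="r1")
#     assert restored is not None
#     assert restored.messages[0].content == "hello"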