feature: improve micro compaction

This commit is contained in:
Timothy
2026-04-01 16:06:35 -07:00
parent c8a25a0287
commit 137162eada
3 changed files with 244 additions and 35 deletions
+15 -3
View File
@@ -408,6 +408,9 @@ class NodeConversation:
)
self._messages.append(msg)
self._next_seq += 1
# Invalidate stale API token count so estimate_tokens() uses
# the char-based heuristic which reflects the new message.
self._last_api_input_tokens = None
await self._persist(msg)
return msg
@@ -425,6 +428,7 @@ class NodeConversation:
)
self._messages.append(msg)
self._next_seq += 1
self._last_api_input_tokens = None
await self._persist(msg)
return msg
@@ -448,6 +452,7 @@ class NodeConversation:
)
self._messages.append(msg)
self._next_seq += 1
self._last_api_input_tokens = None
await self._persist(msg)
return msg
@@ -528,12 +533,15 @@ class NodeConversation:
Uses actual API input token count when available (set via
:meth:`update_token_count`), otherwise falls back to a
``total_chars / 4`` heuristic that includes both message content
AND tool_call argument sizes.
character-based heuristic that includes message content, tool_call
arguments, and image blocks. The heuristic applies a 4/3 safety
margin to avoid under-counting (inspired by Claude Code's compact
service).
"""
if self._last_api_input_tokens is not None:
return self._last_api_input_tokens
total_chars = 0
image_tokens = 0
for m in self._messages:
total_chars += len(m.content)
if m.tool_calls:
@@ -541,7 +549,11 @@ class NodeConversation:
func = tc.get("function", {})
total_chars += len(func.get("arguments", ""))
total_chars += len(func.get("name", ""))
return total_chars // 4
if m.image_content:
# Images/documents have a fixed token cost per block
image_tokens += len(m.image_content) * 2000
# Apply 4/3 safety margin to character-based estimate
return (total_chars * 4) // (3 * 4) + image_tokens
def update_token_count(self, actual_input_tokens: int) -> None:
"""Store actual API input token count for more accurate compaction.
+222 -27
View File
@@ -1,7 +1,8 @@
"""Conversation compaction pipeline.
Implements the multi-level compaction strategy:
1. Prune old tool results
0. Microcompaction (count-based tool result clearing cheapest)
1. Prune old tool results (token-budget based)
2. Structure-preserving compaction (spillover)
3. LLM summary compaction (with recursive splitting)
4. Emergency deterministic summary (no LLM)
@@ -13,11 +14,12 @@ import json
import logging
import os
import re
import time
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from framework.graph.conversation import NodeConversation
from framework.graph.conversation import Message, NodeConversation
from framework.graph.event_loop.event_publishing import publish_context_usage
from framework.graph.event_loop.types import LoopConfig, OutputAccumulator
from framework.graph.node import NodeContext
@@ -29,6 +31,108 @@ logger = logging.getLogger(__name__)
LLM_COMPACT_CHAR_LIMIT: int = 240_000
LLM_COMPACT_MAX_DEPTH: int = 10
# Microcompaction: tools whose results can be safely cleared
COMPACTABLE_TOOLS: frozenset[str] = frozenset({
"read_file", "run_command", "web_search", "web_fetch",
"grep_search", "glob_search", "write_file", "edit_file",
"browser_screenshot", "list_directory",
})
# Keep at most this many compactable tool results; clear older ones
MICROCOMPACT_KEEP_RECENT: int = 8
# Circuit-breaker: stop auto-compacting after this many consecutive failures
MAX_CONSECUTIVE_FAILURES: int = 3
# Track consecutive compaction failures per conversation (module-level)
_failure_counts: dict[int, int] = {}
# Track last compaction time per conversation for recompaction detection
_last_compact_times: dict[int, float] = {}
def microcompact(conversation: NodeConversation, *, keep_recent: int = MICROCOMPACT_KEEP_RECENT) -> int:
"""Clear old compactable tool results by count, keeping only the most recent.
This is the cheapest possible compaction no LLM call, no structural
changes, just replaces old tool result content with a short placeholder.
Inspired by Claude Code's cached-microcompact strategy.
Returns the number of tool results cleared.
"""
# Collect indices of compactable tool results (newest first)
compactable_indices: list[int] = []
messages = conversation.messages
for i in range(len(messages) - 1, -1, -1):
msg = messages[i]
if msg.role != "tool" or msg.is_error or msg.is_skill_content:
continue
if msg.content.startswith("[Pruned tool result") or msg.content.startswith("[Old tool result"):
continue
if len(msg.content) < 100:
continue
# Check if the tool that produced this result is compactable
tool_name = _find_tool_name_for_result(messages, msg)
if tool_name and tool_name in COMPACTABLE_TOOLS:
compactable_indices.append(i)
# Keep the most recent N, clear the rest
to_clear = compactable_indices[keep_recent:]
if not to_clear:
return 0
cleared = 0
for i in to_clear:
msg = messages[i]
spillover = _extract_spillover_filename_inline(msg.content)
orig_len = len(msg.content)
if spillover:
placeholder = (
f"[Old tool result cleared: {orig_len} chars. "
f"Full data in '{spillover}'. "
f"Use load_data('{spillover}') to retrieve.]"
)
else:
placeholder = f"[Old tool result cleared: {orig_len} chars.]"
# Mutate in-place (microcompact is synchronous, no store writes)
conversation._messages[i] = Message(
seq=msg.seq,
role=msg.role,
content=placeholder,
tool_use_id=msg.tool_use_id,
tool_calls=msg.tool_calls,
is_error=msg.is_error,
phase_id=msg.phase_id,
is_transition_marker=msg.is_transition_marker,
)
cleared += 1
if cleared > 0:
# Invalidate cached token count
conversation._last_api_input_tokens = None
return cleared
def _find_tool_name_for_result(messages: list[Message], tool_msg: Message) -> str | None:
"""Find the tool name from the assistant message that triggered this tool result."""
if not tool_msg.tool_use_id:
return None
for msg in messages:
if msg.tool_calls:
for tc in msg.tool_calls:
if tc.get("id") == tool_msg.tool_use_id:
return tc.get("function", {}).get("name")
return None
def _extract_spillover_filename_inline(content: str) -> str | None:
"""Quick inline check for spillover filename in tool result content."""
match = re.search(r"saved to '([^']+)'", content, re.IGNORECASE)
return match.group(1) if match else None
async def compact(
ctx: NodeContext,
@@ -43,11 +147,31 @@ async def compact(
"""Run the full compaction pipeline if conversation needs compaction.
Pipeline stages (in order, short-circuits when budget is restored):
1. Prune old tool results
0. Microcompaction (count-based tool result clearing cheapest)
1. Prune old tool results (token-budget based)
2. Structure-preserving compaction (free, no LLM)
3. LLM summary compaction (recursive split if too large)
4. Emergency deterministic summary (fallback)
"""
conv_id = id(conversation)
# Circuit breaker: stop auto-compacting after repeated failures
if _failure_counts.get(conv_id, 0) >= MAX_CONSECUTIVE_FAILURES:
logger.warning(
"Circuit breaker: skipping compaction after %d consecutive failures",
_failure_counts[conv_id],
)
return
# Recompaction detection
now = time.monotonic()
last_time = _last_compact_times.get(conv_id)
if last_time is not None and (now - last_time) < 30:
logger.warning(
"Recompaction chain detected: only %.1fs since last compaction",
now - last_time,
)
ratio_before = conversation.usage_ratio()
phase_grad = getattr(ctx, "continuous_mode", False)
pre_inventory: list[dict[str, Any]] | None = None
@@ -55,6 +179,23 @@ async def compact(
if ratio_before >= 1.0:
pre_inventory = build_message_inventory(conversation)
# --- Step 0: Microcompaction (count-based, cheapest) ---
mc_cleared = microcompact(conversation)
if mc_cleared > 0:
logger.info(
"Microcompact cleared %d old tool results: %.0f%% -> %.0f%%",
mc_cleared,
ratio_before * 100,
conversation.usage_ratio() * 100,
)
if not conversation.needs_compaction():
_record_success(conv_id, now)
await log_compaction(
ctx, conversation, ratio_before, event_bus,
pre_inventory=pre_inventory,
)
return
# --- Step 1: Prune old tool results (free, fast) ---
protect = max(2000, config.max_context_tokens // 12)
pruned = await conversation.prune_old_tool_results(
@@ -69,11 +210,9 @@ async def compact(
conversation.usage_ratio() * 100,
)
if not conversation.needs_compaction():
_record_success(conv_id, now)
await log_compaction(
ctx,
conversation,
ratio_before,
event_bus,
ctx, conversation, ratio_before, event_bus,
pre_inventory=pre_inventory,
)
return
@@ -87,11 +226,9 @@ async def compact(
phase_graduated=phase_grad,
)
if not conversation.needs_compaction():
_record_success(conv_id, now)
await log_compaction(
ctx,
conversation,
ratio_before,
event_bus,
ctx, conversation, ratio_before, event_bus,
pre_inventory=pre_inventory,
)
return
@@ -118,13 +255,12 @@ async def compact(
)
except Exception as e:
logger.warning("LLM compaction failed: %s", e)
_failure_counts[conv_id] = _failure_counts.get(conv_id, 0) + 1
if not conversation.needs_compaction():
_record_success(conv_id, now)
await log_compaction(
ctx,
conversation,
ratio_before,
event_bus,
ctx, conversation, ratio_before, event_bus,
pre_inventory=pre_inventory,
)
return
@@ -140,18 +276,51 @@ async def compact(
keep_recent=1,
phase_graduated=phase_grad,
)
_record_success(conv_id, now)
await log_compaction(
ctx,
conversation,
ratio_before,
event_bus,
ctx, conversation, ratio_before, event_bus,
pre_inventory=pre_inventory,
)
def _record_success(conv_id: int, timestamp: float) -> None:
"""Reset failure counter and record compaction time on success."""
_failure_counts.pop(conv_id, None)
_last_compact_times[conv_id] = timestamp
# --- LLM compaction with binary-search splitting ----------------------
def strip_images_from_messages(messages: list[Message]) -> list[Message]:
"""Strip image_content from messages before LLM summarisation.
Images/documents are replaced with ``[image]`` markers so the summary
notes they existed without wasting tokens sending binary data to the
compaction LLM. Returns a new list (original messages are not mutated).
"""
stripped: list[Message] = []
for msg in messages:
if msg.image_content:
n_images = len(msg.image_content)
marker = " ".join("[image]" for _ in range(n_images))
content = f"{msg.content}\n{marker}" if msg.content else marker
stripped.append(Message(
seq=msg.seq,
role=msg.role,
content=content,
tool_use_id=msg.tool_use_id,
tool_calls=msg.tool_calls,
is_error=msg.is_error,
phase_id=msg.phase_id,
is_transition_marker=msg.is_transition_marker,
image_content=None, # stripped
))
else:
stripped.append(msg)
return stripped
async def llm_compact(
ctx: NodeContext,
messages: list,
@@ -175,6 +344,10 @@ async def llm_compact(
if _depth > max_depth:
raise RuntimeError(f"LLM compaction recursion limit ({max_depth})")
# Strip images before summarisation to avoid wasting tokens
if _depth == 0:
messages = strip_images_from_messages(messages)
formatted = format_messages_for_summary(messages)
# Proactive split: avoid wasting an API call on oversized input
@@ -297,7 +470,12 @@ def build_llm_compaction_prompt(
*,
max_context_tokens: int = 128_000,
) -> str:
"""Build prompt for LLM compaction targeting 50% of token budget."""
"""Build prompt for LLM compaction targeting 50% of token budget.
Uses a structured section format inspired by Claude Code's compact
service. Each section focuses on a different aspect of the conversation
so the summariser produces consistently useful, well-organised output.
"""
spec = ctx.node_spec
ctx_lines = [f"NODE: {spec.name} (id={spec.id})"]
if spec.description:
@@ -330,13 +508,30 @@ def build_llm_compaction_prompt(
f"CONVERSATION MESSAGES:\n{formatted_messages}\n\n"
"INSTRUCTIONS:\n"
f"Write a summary of approximately {target_chars} characters "
f"(~{target_tokens} tokens).\n"
"1. Preserve ALL user-stated rules, constraints, and preferences "
"verbatim.\n"
"2. Preserve key decisions made and results obtained.\n"
"3. Preserve in-progress work state so the agent can continue.\n"
"4. Be detailed enough that the agent can resume without "
"re-doing work.\n"
f"(~{target_tokens} tokens).\n\n"
"Organise the summary into these sections (omit empty ones):\n\n"
"1. **Primary Request and Intent** — What the user originally asked "
"for and the high-level goal the agent is working toward.\n"
"2. **Key Technical Concepts** — Important domain-specific terms, "
"patterns, or architectural decisions established in the conversation.\n"
"3. **Files and Code Sections** — Specific files read/written/edited "
"with brief descriptions of changes. Include short code snippets only "
"when they capture critical logic.\n"
"4. **Errors and Fixes** — Problems encountered and how they were "
"resolved. Include root causes so the agent doesn't repeat them.\n"
"5. **Problem Solving Efforts** — Approaches tried, dead ends hit, "
"and reasoning behind the current strategy.\n"
"6. **User Messages** — Preserve ALL user-stated rules, constraints, "
"identity preferences, and account details verbatim.\n"
"7. **Pending Tasks** — Work remaining, outputs still needed, and "
"any blockers.\n"
"8. **Current Work** — The most recent action taken and the immediate "
"next step the agent should perform. This section is the most important "
"for seamless resumption.\n\n"
"Additional rules:\n"
"- Be detailed enough that the agent can resume without re-doing work.\n"
"- Preserve key decisions made and results obtained.\n"
"- When in doubt, keep information rather than discard it.\n"
)
+7 -5
View File
@@ -167,14 +167,15 @@ class TestNodeConversation:
async def test_token_estimation(self):
conv = NodeConversation()
await conv.add_user_message("a" * 400)
assert conv.estimate_tokens() == 100
# chars // 3 (4/3 safety margin over chars/4 base)
assert conv.estimate_tokens() == 400 // 3
@pytest.mark.asyncio
async def test_update_token_count_overrides_estimate(self):
"""When actual API token count is provided, estimate_tokens uses it."""
conv = NodeConversation()
await conv.add_user_message("a" * 400)
assert conv.estimate_tokens() == 100 # chars/4 fallback
assert conv.estimate_tokens() == 400 // 3 # char-based fallback with safety margin
conv.update_token_count(500)
assert conv.estimate_tokens() == 500 # actual API value
@@ -188,8 +189,8 @@ class TestNodeConversation:
assert conv.estimate_tokens() == 500
await conv.compact("summary", keep_recent=0)
# Falls back to chars/4 for the summary message
assert conv.estimate_tokens() == len("summary") // 4
# Falls back to char-based heuristic with 4/3 safety margin (chars // 3)
assert conv.estimate_tokens() == len("summary") // 3
@pytest.mark.asyncio
async def test_clear_resets_token_count(self):
@@ -207,7 +208,8 @@ class TestNodeConversation:
"""usage_ratio returns estimate / max_context_tokens."""
conv = NodeConversation(max_context_tokens=1000)
await conv.add_user_message("a" * 400)
assert conv.usage_ratio() == pytest.approx(0.1) # 100/1000
# 400 // 3 = 133 tokens (with safety margin), so 133/1000
assert conv.usage_ratio() == pytest.approx(400 // 3 / 1000)
conv.update_token_count(800)
assert conv.usage_ratio() == pytest.approx(0.8) # 800/1000