db82b59254
* fix: inject longTermBackground into memory prompt
The format_memory_for_injection function only processed recentMonths and
earlierContext from the history section, silently dropping longTermBackground.
The LLM writes longTermBackground correctly and it persists to memory.json,
but it was never injected into the system prompt — making the user's
long-term background invisible to the AI.
Add the missing field handling and a regression test.
* fix(middleware): handle list-type AIMessage.content in LoopDetectionMiddleware
LangChain AIMessage.content can be str | list. When using providers that
return structured content blocks (e.g. Anthropic thinking mode, certain
OpenAI-compatible gateways), content is a list of dicts like
[{"type": "text", "text": "..."}].
The hard_limit branch in _apply() concatenated content with a string via
(last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}", which raises
TypeError when content is a non-empty list (list + str is invalid).
Add _append_text() static method that:
- Returns the text directly when content is None
- Appends a {"type": "text"} block when content is a list
- Falls back to string concatenation when content is a str
This is consistent with how other modules in the project already handle
list content (client.py._extract_text, memory_middleware, executor.py).
* test(middleware): add unit tests for _append_text and list content hard stop
Add regression tests to verify LoopDetectionMiddleware handles list-type
AIMessage.content correctly during hard stop:
- TestAppendText: unit tests for the new _append_text() static method
covering None, str, list (including empty list) content types
- TestHardStopWithListContent: integration tests verifying hard stop
works correctly with list content (Anthropic thinking mode), None
content, and str content
Requested by reviewer in PR #1823.
* fix(middleware): improve _append_text robustness and test isolation
- Add explicit isinstance(content, str) check with fallback for
unexpected types (coerce to str) to prevent TypeError on edge cases
- Deep-copy list content in _make_state() test helper to prevent
shared mutable references across test iterations
- Add test_unexpected_type_coerced_to_str: verify fallback for
non-str/list/None content types
- Add test_list_content_not_mutated_in_place: verify _append_text
does not modify the original list
* style: fix ruff format whitespace in test file
---------
Co-authored-by: ppyt <14163465+ppyt@users.noreply.github.com>
245 lines
9.3 KiB
Python
"""Middleware to detect and break repetitive tool call loops.
|
|
|
|
P0 safety: prevents the agent from calling the same tool with the same
|
|
arguments indefinitely until the recursion limit kills the run.
|
|
|
|
Detection strategy:
|
|
1. After each model response, hash the tool calls (name + args).
|
|
2. Track recent hashes in a sliding window.
|
|
3. If the same hash appears >= warn_threshold times, inject a
|
|
"you are repeating yourself — wrap up" system message (once per hash).
|
|
4. If it appears >= hard_limit times, strip all tool_calls from the
|
|
response so the agent is forced to produce a final text answer.
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import threading
|
|
from collections import OrderedDict, defaultdict
|
|
from typing import override
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain_core.messages import HumanMessage
|
|
from langgraph.runtime import Runtime
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Defaults — can be overridden via constructor
_DEFAULT_WARN_THRESHOLD = 3  # inject a "wrap up" warning after 3 identical call sets
_DEFAULT_HARD_LIMIT = 5  # strip tool_calls (force-stop) after 5 identical call sets
_DEFAULT_WINDOW_SIZE = 20  # sliding window: track the last N tool-call hashes per thread
_DEFAULT_MAX_TRACKED_THREADS = 100  # LRU eviction limit for per-thread tracking state
|
|
|
|
|
|
def _hash_tool_calls(tool_calls: list[dict]) -> str:
|
|
"""Deterministic hash of a set of tool calls (name + args).
|
|
|
|
This is intended to be order-independent: the same multiset of tool calls
|
|
should always produce the same hash, regardless of their input order.
|
|
"""
|
|
# First normalize each tool call to a minimal (name, args) structure.
|
|
normalized: list[dict] = []
|
|
for tc in tool_calls:
|
|
normalized.append(
|
|
{
|
|
"name": tc.get("name", ""),
|
|
"args": tc.get("args", {}),
|
|
}
|
|
)
|
|
|
|
# Sort by both name and a deterministic serialization of args so that
|
|
# permutations of the same multiset of calls yield the same ordering.
|
|
normalized.sort(
|
|
key=lambda tc: (
|
|
tc["name"],
|
|
json.dumps(tc["args"], sort_keys=True, default=str),
|
|
)
|
|
)
|
|
blob = json.dumps(normalized, sort_keys=True, default=str)
|
|
return hashlib.md5(blob.encode()).hexdigest()[:12]
|
|
|
|
|
|
# Injected (as a HumanMessage, see _apply) once per repeated call-hash when
# warn_threshold is reached.
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."

# Appended to the last AIMessage's content when hard_limit forces a stop.
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
|
|
|
|
|
|
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
    """Detects and breaks repetitive tool call loops.

    Args:
        warn_threshold: Number of identical tool call sets before injecting
            a warning message. Default: 3.
        hard_limit: Number of identical tool call sets before stripping
            tool_calls entirely. Default: 5.
        window_size: Size of the sliding window for tracking calls.
            Default: 20.
        max_tracked_threads: Maximum number of threads to track before
            evicting the least recently used. Default: 100.
    """

    def __init__(
        self,
        warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
        hard_limit: int = _DEFAULT_HARD_LIMIT,
        window_size: int = _DEFAULT_WINDOW_SIZE,
        max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
    ):
        super().__init__()
        self.warn_threshold = warn_threshold
        self.hard_limit = hard_limit
        self.window_size = window_size
        self.max_tracked_threads = max_tracked_threads
        # Serializes all access to _history/_warned; both after_model and
        # aafter_model funnel through _track_and_check, which takes this lock.
        self._lock = threading.Lock()
        # Per-thread tracking using OrderedDict for LRU eviction.
        # _history maps thread_id -> sliding window of recent call hashes;
        # _warned maps thread_id -> hashes that already got a warning message.
        self._history: OrderedDict[str, list[str]] = OrderedDict()
        self._warned: dict[str, set[str]] = defaultdict(set)

    def _get_thread_id(self, runtime: Runtime) -> str:
        """Extract thread_id from runtime context for per-thread tracking."""
        # NOTE(review): the context value is returned as-is; annotated -> str,
        # but nothing here enforces it — confirm producers only store strings.
        thread_id = runtime.context.get("thread_id") if runtime.context else None
        if thread_id:
            return thread_id
        # No context (or no thread_id key): all such runs share one bucket.
        return "default"

    def _evict_if_needed(self) -> None:
        """Evict least recently used threads if over the limit.

        Must be called while holding self._lock.
        """
        while len(self._history) > self.max_tracked_threads:
            # popitem(last=False) pops the oldest (least recently touched) entry.
            evicted_id, _ = self._history.popitem(last=False)
            self._warned.pop(evicted_id, None)
            logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)

    def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
        """Track tool calls and check for loops.

        Returns:
            (warning_message_or_none, should_hard_stop)
        """
        messages = state.get("messages", [])
        if not messages:
            return None, False

        # Only the most recent AI message carrying tool calls is relevant;
        # anything else (human/tool messages, text-only AI turns) is ignored.
        last_msg = messages[-1]
        if getattr(last_msg, "type", None) != "ai":
            return None, False

        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            return None, False

        thread_id = self._get_thread_id(runtime)
        call_hash = _hash_tool_calls(tool_calls)

        with self._lock:
            # Touch / create entry (move to end for LRU)
            if thread_id in self._history:
                self._history.move_to_end(thread_id)
            else:
                self._history[thread_id] = []
                self._evict_if_needed()

            history = self._history[thread_id]
            history.append(call_hash)
            # Trim in place (slice assignment) so the list object stored in
            # _history keeps reflecting the window.
            if len(history) > self.window_size:
                history[:] = history[-self.window_size :]

            # Occurrences of this exact call set within the window.
            count = history.count(call_hash)
            tool_names = [tc.get("name", "?") for tc in tool_calls]

            if count >= self.hard_limit:
                logger.error(
                    "Loop hard limit reached — forcing stop",
                    extra={
                        "thread_id": thread_id,
                        "call_hash": call_hash,
                        "count": count,
                        "tools": tool_names,
                    },
                )
                return _HARD_STOP_MSG, True

            if count >= self.warn_threshold:
                warned = self._warned[thread_id]
                # Warn at most once per hash per thread to avoid spamming the
                # conversation with identical warnings.
                if call_hash not in warned:
                    warned.add(call_hash)
                    logger.warning(
                        "Repetitive tool calls detected — injecting warning",
                        extra={
                            "thread_id": thread_id,
                            "call_hash": call_hash,
                            "count": count,
                            "tools": tool_names,
                        },
                    )
                    return _WARNING_MSG, False
                # Warning already injected for this hash — suppress
                return None, False

        # Below warn_threshold: nothing to do.
        return None, False

    @staticmethod
    def _append_text(content: str | list | None, text: str) -> str | list:
        """Append *text* to AIMessage content, handling str, list, and None.

        When content is a list of content blocks (e.g. Anthropic thinking mode),
        we append a new ``{"type": "text", ...}`` block instead of concatenating
        a string to a list, which would raise ``TypeError``.
        """
        if content is None:
            return text
        if isinstance(content, list):
            # New list via unpacking — the caller's original list is not mutated.
            return [*content, {"type": "text", "text": f"\n\n{text}"}]
        if isinstance(content, str):
            return content + f"\n\n{text}"
        # Fallback: coerce unexpected types to str to avoid TypeError
        return str(content) + f"\n\n{text}"

    def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
        # Shared sync/async implementation; returns a state update or None.
        warning, hard_stop = self._track_and_check(state, runtime)

        if hard_stop:
            # Strip tool_calls from the last AIMessage to force text output.
            # messages is non-empty here: _track_and_check only reports a hard
            # stop after inspecting messages[-1].
            messages = state.get("messages", [])
            last_msg = messages[-1]
            stripped_msg = last_msg.model_copy(
                update={
                    "tool_calls": [],
                    # _append_text handles str | list | None content shapes.
                    "content": self._append_text(last_msg.content, _HARD_STOP_MSG),
                }
            )
            return {"messages": [stripped_msg]}

        if warning:
            # Inject as HumanMessage instead of SystemMessage to avoid
            # Anthropic's "multiple non-consecutive system messages" error.
            # Anthropic models require system messages only at the start of
            # the conversation; injecting one mid-conversation crashes
            # langchain_anthropic's _format_messages(). HumanMessage works
            # with all providers. See #1299.
            return {"messages": [HumanMessage(content=warning)]}

        return None

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        # Sync hook: runs after every model response.
        return self._apply(state, runtime)

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        # Async hook: same logic; _apply itself is synchronous and lock-guarded.
        return self._apply(state, runtime)

    def reset(self, thread_id: str | None = None) -> None:
        """Clear tracking state. If thread_id given, clear only that thread."""
        with self._lock:
            if thread_id:
                self._history.pop(thread_id, None)
                self._warned.pop(thread_id, None)
            else:
                self._history.clear()
                self._warned.clear()
|