Files
deer-flow/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py
T
ppyt db82b59254 fix(middleware): handle list-type AIMessage.content in LoopDetectionMiddleware (#1823)
* fix: inject longTermBackground into memory prompt

The format_memory_for_injection function only processed recentMonths and
earlierContext from the history section, silently dropping longTermBackground.

The LLM writes longTermBackground correctly and it persists to memory.json,
but it was never injected into the system prompt — making the user's
long-term background invisible to the AI.

Add the missing field handling and a regression test.

* fix(middleware): handle list-type AIMessage.content in LoopDetectionMiddleware

LangChain AIMessage.content can be str | list. When using providers that
return structured content blocks (e.g. Anthropic thinking mode, certain
OpenAI-compatible gateways), content is a list of dicts like
[{"type": "text", "text": "..."}].

The hard_limit branch in _apply() concatenated content with a string via
(last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}", which raises
TypeError when content is a non-empty list (list + str is invalid).

Add _append_text() static method that:
- Returns the text directly when content is None
- Appends a {"type": "text"} block when content is a list
- Falls back to string concatenation when content is a str

This is consistent with how other modules in the project already handle
list content (client.py._extract_text, memory_middleware, executor.py).

* test(middleware): add unit tests for _append_text and list content hard stop

Add regression tests to verify LoopDetectionMiddleware handles list-type
AIMessage.content correctly during hard stop:

- TestAppendText: unit tests for the new _append_text() static method
  covering None, str, list (including empty list) content types
- TestHardStopWithListContent: integration tests verifying hard stop
  works correctly with list content (Anthropic thinking mode), None
  content, and str content

Requested by reviewer in PR #1823.

* fix(middleware): improve _append_text robustness and test isolation

- Add explicit isinstance(content, str) check with fallback for
  unexpected types (coerce to str) to prevent TypeError on edge cases
- Deep-copy list content in _make_state() test helper to prevent
  shared mutable references across test iterations
- Add test_unexpected_type_coerced_to_str: verify fallback for
  non-str/list/None content types
- Add test_list_content_not_mutated_in_place: verify _append_text
  does not modify the original list

* style: fix ruff format whitespace in test file

---------

Co-authored-by: ppyt <14163465+ppyt@users.noreply.github.com>
2026-04-04 10:38:22 +08:00

245 lines
9.3 KiB
Python

"""Middleware to detect and break repetitive tool call loops.
P0 safety: prevents the agent from calling the same tool with the same
arguments indefinitely until the recursion limit kills the run.
Detection strategy:
1. After each model response, hash the tool calls (name + args).
2. Track recent hashes in a sliding window.
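3. If the same hash appears >= warn_threshold times, inject a
"you are repeating yourself — wrap up" message (once per hash per thread).
It is sent as a HumanMessage, not a SystemMessage, because some providers
reject system messages in the middle of a conversation.
4. If it appears >= hard_limit times, strip all tool_calls from that
response and append a forced-stop notice, so it becomes the final text answer.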
"""
import hashlib
import json
import logging
import threading
from collections import OrderedDict, defaultdict
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
_DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
def _hash_tool_calls(tool_calls: list[dict]) -> str:
"""Deterministic hash of a set of tool calls (name + args).
This is intended to be order-independent: the same multiset of tool calls
should always produce the same hash, regardless of their input order.
"""
# First normalize each tool call to a minimal (name, args) structure.
normalized: list[dict] = []
for tc in tool_calls:
normalized.append(
{
"name": tc.get("name", ""),
"args": tc.get("args", {}),
}
)
# Sort by both name and a deterministic serialization of args so that
# permutations of the same multiset of calls yield the same ordering.
normalized.sort(
key=lambda tc: (
tc["name"],
json.dumps(tc["args"], sort_keys=True, default=str),
)
)
blob = json.dumps(normalized, sort_keys=True, default=str)
return hashlib.md5(blob.encode()).hexdigest()[:12]
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
    """Detects and breaks repetitive tool call loops.

    State is tracked per thread: the thread id is read from
    ``runtime.context`` (``"default"`` when missing), and at most
    ``max_tracked_threads`` threads are remembered, evicting the least
    recently used first. All tracking state is guarded by an internal lock.

    Args:
        warn_threshold: Number of identical tool call sets before injecting
            a warning message. Default: 3.
        hard_limit: Number of identical tool call sets before stripping
            tool_calls entirely. Default: 5. Should not exceed
            ``window_size``, otherwise the hard stop can never fire.
        window_size: Size of the per-thread sliding window of recent
            tool-call hashes. Default: 20.
        max_tracked_threads: Maximum number of threads to track before
            evicting the least recently used. Default: 100.
    """

    # Content-block ``type`` values that encode a tool invocation inside
    # list-valued AIMessage content (Anthropic ``tool_use`` blocks, LangChain
    # standard ``tool_call`` blocks). They are dropped on hard stop together
    # with ``tool_calls`` so no call is left behind without a result.
    _TOOL_CALL_BLOCK_TYPES = frozenset({"tool_use", "tool_call"})

    # ``additional_kwargs`` keys that hold raw provider tool/function calls
    # (OpenAI-style). Also dropped on hard stop.
    _TOOL_CALL_KWARG_KEYS = ("tool_calls", "function_call")

    def __init__(
        self,
        warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
        hard_limit: int = _DEFAULT_HARD_LIMIT,
        window_size: int = _DEFAULT_WINDOW_SIZE,
        max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
    ):
        super().__init__()
        self.warn_threshold = warn_threshold
        self.hard_limit = hard_limit
        self.window_size = window_size
        self.max_tracked_threads = max_tracked_threads
        if hard_limit > window_size:
            # A hash can occur at most window_size times inside the window, so
            # the hard stop (the P0 safety net) would be unreachable.
            logger.warning(
                "LoopDetectionMiddleware: hard_limit (%s) exceeds window_size (%s); the hard stop can never trigger",
                hard_limit,
                window_size,
            )
        self._lock = threading.Lock()
        # Per-thread hash history; OrderedDict order doubles as LRU order.
        self._history: OrderedDict[str, list[str]] = OrderedDict()
        # Per-thread set of call hashes that have already been warned about.
        self._warned: dict[str, set[str]] = defaultdict(set)

    def _get_thread_id(self, runtime: Runtime) -> str:
        """Return the thread id used to key per-thread tracking.

        Reads ``thread_id`` from ``runtime.context``. Mapping-like contexts
        (anything with a callable ``get``) and attribute-style contexts are
        both supported. Returns ``"default"`` when there is no context or the
        thread id is missing/falsy.
        """
        context = runtime.context
        if not context:
            return "default"
        getter = getattr(context, "get", None)
        if callable(getter):
            thread_id = getter("thread_id")
        else:
            thread_id = getattr(context, "thread_id", None)
        return thread_id or "default"

    def _evict_if_needed(self) -> None:
        """Evict least recently used threads while over ``max_tracked_threads``.

        Must be called while holding self._lock.
        """
        while len(self._history) > self.max_tracked_threads:
            evicted_id, _ = self._history.popitem(last=False)
            self._warned.pop(evicted_id, None)
            logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)

    def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
        """Record the latest tool-call set and decide whether to intervene.

        Only acts when the last message has type ``"ai"`` and carries tool
        calls. The warning for a given hash is returned at most once per
        thread; later repeats below ``hard_limit`` are suppressed.

        Returns:
            ``(_HARD_STOP_MSG, True)`` when the hash occurs ``hard_limit`` or
            more times in the window, ``(_WARNING_MSG, False)`` the first time
            it reaches ``warn_threshold``, otherwise ``(None, False)``.
        """
        messages = state.get("messages", [])
        if not messages:
            return None, False
        last_msg = messages[-1]
        if getattr(last_msg, "type", None) != "ai":
            return None, False
        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            return None, False

        thread_id = self._get_thread_id(runtime)
        call_hash = _hash_tool_calls(tool_calls)

        with self._lock:
            # Touch / create the entry (move to end for LRU). Keep a local
            # reference so the update below never depends on the entry
            # surviving eviction.
            history = self._history.get(thread_id)
            if history is None:
                history = self._history[thread_id] = []
            else:
                self._history.move_to_end(thread_id)

            history.append(call_hash)
            if len(history) > self.window_size:
                history[:] = history[-self.window_size :]
            count = history.count(call_hash)

            hard_stop = count >= self.hard_limit
            first_warning = False
            if not hard_stop and count >= self.warn_threshold:
                warned = self._warned[thread_id]
                if call_hash not in warned:
                    warned.add(call_hash)
                    first_warning = True
                # Otherwise the warning was already injected for this hash — suppress.

            # Evict last: the current thread is the most recently used, so it
            # is only dropped when max_tracked_threads < 1, and then its
            # _warned entry is removed with it instead of raising KeyError.
            self._evict_if_needed()

        tool_names = [tc.get("name", "?") for tc in tool_calls]
        log_extra = {
            "thread_id": thread_id,
            "call_hash": call_hash,
            "count": count,
            "tools": tool_names,
        }
        if hard_stop:
            logger.error("Loop hard limit reached — forcing stop", extra=log_extra)
            return _HARD_STOP_MSG, True
        if first_warning:
            logger.warning("Repetitive tool calls detected — injecting warning", extra=log_extra)
            return _WARNING_MSG, False
        return None, False

    @staticmethod
    def _append_text(content: str | list | None, text: str) -> str | list:
        """Append *text* to AIMessage content, handling str, list, and None.

        When content is a list of content blocks (e.g. Anthropic thinking mode),
        a new ``{"type": "text", ...}`` block is appended instead of
        concatenating a string to a list, which would raise ``TypeError``.
        The input list is not mutated.
        """
        if content is None:
            return text
        if isinstance(content, list):
            return [*content, {"type": "text", "text": f"\n\n{text}"}]
        if isinstance(content, str):
            return content + f"\n\n{text}"
        # Fallback: coerce unexpected types to str to avoid TypeError
        return str(content) + f"\n\n{text}"

    @classmethod
    def _strip_tool_call_blocks(cls, content: str | list | None) -> str | list | None:
        """Return *content* without tool-invocation content blocks.

        Only list content is filtered; a new list is returned and the input is
        not mutated. Non-list content is returned unchanged.
        """
        if not isinstance(content, list):
            return content
        return [
            block
            for block in content
            if not (isinstance(block, dict) and block.get("type") in cls._TOOL_CALL_BLOCK_TYPES)
        ]

    def _build_hard_stop_message(self, last_msg):
        """Copy *last_msg* with every trace of its tool calls removed.

        Clears ``tool_calls`` (and ``invalid_tool_calls`` when non-empty),
        drops tool-invocation blocks from list content and raw tool/function
        calls from ``additional_kwargs``, then appends ``_HARD_STOP_MSG`` to
        the content. All other fields are carried over by ``model_copy``.
        """
        update: dict = {
            "tool_calls": [],
            "content": self._append_text(self._strip_tool_call_blocks(last_msg.content), _HARD_STOP_MSG),
        }
        if getattr(last_msg, "invalid_tool_calls", None):
            update["invalid_tool_calls"] = []
        additional_kwargs = getattr(last_msg, "additional_kwargs", None)
        if isinstance(additional_kwargs, dict) and any(key in additional_kwargs for key in self._TOOL_CALL_KWARG_KEYS):
            # Provider adapters may re-serialize these raw entries as tool calls
            # when the history is sent again, even though tool_calls is empty.
            update["additional_kwargs"] = {
                key: value for key, value in additional_kwargs.items() if key not in self._TOOL_CALL_KWARG_KEYS
            }
        return last_msg.model_copy(update=update)

    def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Shared implementation of the sync and async after-model hooks.

        Returns a state update with a tool-call-free copy of the last AI
        message (hard stop), a state update with a warning message, or
        ``None`` when no action is needed.
        """
        warning, hard_stop = self._track_and_check(state, runtime)
        if hard_stop:
            # hard_stop implies the last message is an AI message with tool calls.
            last_msg = state.get("messages", [])[-1]
            return {"messages": [self._build_hard_stop_message(last_msg)]}
        if warning:
            # Inject as HumanMessage instead of SystemMessage to avoid
            # Anthropic's "multiple non-consecutive system messages" error.
            # Anthropic models require system messages only at the start of
            # the conversation; injecting one mid-conversation crashes
            # langchain_anthropic's _format_messages(). HumanMessage works
            # with all providers. See #1299.
            # NOTE(review): this message is appended right after an AIMessage
            # that still carries tool_calls, i.e. before the matching tool
            # results. Confirm every provider in use accepts that ordering.
            return {"messages": [HumanMessage(content=warning)]}
        return None

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._apply(state, runtime)

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._apply(state, runtime)

    def reset(self, thread_id: str | None = None) -> None:
        """Clear tracking state.

        If a truthy thread_id is given, clear only that thread; a falsy value
        (``None`` or ``""``) clears every thread.
        """
        with self._lock:
            if thread_id:
                self._history.pop(thread_id, None)
                self._warned.pop(thread_id, None)
            else:
                self._history.clear()
                self._warned.clear()