refactor: thread app_config through middleware factories (#2652)

* refactor: thread app_config through middleware factories

Continues the incremental config-refactor sequence (#2611 root, #2612 lead
path) one layer deeper into the middleware factories. Two ambient lookups
inside _build_runtime_middlewares are eliminated and the LLMErrorHandling
band-aid removed:

- _build_runtime_middlewares / build_lead_runtime_middlewares /
  build_subagent_runtime_middlewares now require app_config: AppConfig.
- get_guardrails_config() inside the factory is replaced with
  app_config.guardrails (semantically identical — same default-factory
  GuardrailsConfig — verified by direct equality check).
- LLMErrorHandlingMiddleware.__init__ now requires app_config and reads
  circuit_breaker fields directly. The class-level
  circuit_failure_threshold / circuit_recovery_timeout_sec defaults are
  removed along with the try/except (FileNotFoundError, RuntimeError):
  pass band-aid — the let-it-crash invariant the rest of the refactor
  enforces.

Caller chain (already-resolved app_config sources):
- _build_middlewares in lead_agent/agent.py: reorder so
  resolved_app_config = app_config or get_app_config() is computed BEFORE
  build_lead_runtime_middlewares is called, then passed as kwarg.
- SubagentExecutor: optional app_config parameter (mirrors the lead-agent
  pattern); _create_agent does the same `or get_app_config()` fallback at
  agent-build time, so task_tool callers don't need to plumb app_config
  through yet (typed-context plumbing for tool runtimes is a separate
  refactor).

Tests:
- test_llm_error_handling_middleware: _make_app_config helper using
  AppConfig(sandbox=SandboxConfig(use="test")) — same minimal-config
  pattern conftest already uses. Three direct LLMErrorHandlingMiddleware()
  calls each followed by post-construction circuit_breaker mutation fold
  cleanly into _build_middleware(circuit_failure_threshold=...,
  circuit_recovery_timeout_sec=...).

Verification:
- tests/test_llm_error_handling_middleware.py — 14 passed
- tests/test_subagent_executor.py — 28 passed
- tests/test_tool_error_handling_middleware.py — 6 passed
- tests/test_task_tool_core_logic.py — 18 passed (verifies task_tool
  unchanged behavior)
- Full suite: 2697 passed, 3 skipped. The single intermittent failure in
  tests/test_client_e2e.py::test_tool_call_produces_events is pre-existing
  LLM flakiness (the test asserts the model decided to call a tool;
  reproduces 1/3 on unchanged main as well).

* fix: address middleware app config review comments

* fix: satisfy app config annotation lint

* test: cover explicit app config middleware wiring

---------

Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
This commit is contained in:
greatmengqi
2026-04-30 12:41:09 +08:00
committed by GitHub
parent 74081a85a6
commit 38714b6ceb
8 changed files with 236 additions and 34 deletions
@@ -259,8 +259,8 @@ def _build_middlewares(
Returns: Returns:
List of middleware instances. List of middleware instances.
""" """
middlewares = build_lead_runtime_middlewares(lazy_init=True)
resolved_app_config = app_config or get_app_config() resolved_app_config = app_config or get_app_config()
middlewares = build_lead_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
# Add summarization middleware if enabled # Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config) summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
@@ -20,7 +20,7 @@ from langchain.agents.middleware.types import (
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp from langgraph.errors import GraphBubbleUp
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -70,20 +70,11 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
retry_base_delay_ms: int = 1000 retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000 retry_cap_delay_ms: int = 8000
circuit_failure_threshold: int = 5 def __init__(self, *, app_config: AppConfig, **kwargs: Any) -> None:
circuit_recovery_timeout_sec: int = 60
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
# Load Circuit Breaker configs from app config if available, fall back to defaults self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
try: self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
app_config = get_app_config()
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
except (FileNotFoundError, RuntimeError):
# Gracefully fall back to class defaults in test environments
pass
# Circuit Breaker state # Circuit Breaker state
self._circuit_lock = threading.Lock() self._circuit_lock = threading.Lock()
@@ -11,6 +11,8 @@ from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command from langgraph.types import Command
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_MISSING_TOOL_CALL_ID = "missing_tool_call_id" _MISSING_TOOL_CALL_ID = "missing_tool_call_id"
@@ -67,6 +69,7 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
def _build_runtime_middlewares( def _build_runtime_middlewares(
*, *,
app_config: AppConfig,
include_uploads: bool, include_uploads: bool,
include_dangling_tool_call_patch: bool, include_dangling_tool_call_patch: bool,
lazy_init: bool = True, lazy_init: bool = True,
@@ -91,12 +94,10 @@ def _build_runtime_middlewares(
middlewares.append(DanglingToolCallMiddleware()) middlewares.append(DanglingToolCallMiddleware())
middlewares.append(LLMErrorHandlingMiddleware()) middlewares.append(LLMErrorHandlingMiddleware(app_config=app_config))
# Guardrail middleware (if configured) # Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config guardrails_config = app_config.guardrails
guardrails_config = get_guardrails_config()
if guardrails_config.enabled and guardrails_config.provider: if guardrails_config.enabled and guardrails_config.provider:
import inspect import inspect
@@ -125,18 +126,20 @@ def _build_runtime_middlewares(
return middlewares return middlewares
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]: def build_lead_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by lead agent runtime before lead-only middlewares.""" """Middlewares shared by lead agent runtime before lead-only middlewares."""
return _build_runtime_middlewares( return _build_runtime_middlewares(
app_config=app_config,
include_uploads=True, include_uploads=True,
include_dangling_tool_call_patch=True, include_dangling_tool_call_patch=True,
lazy_init=lazy_init, lazy_init=lazy_init,
) )
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]: def build_subagent_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by subagent runtime before subagent-only middlewares.""" """Middlewares shared by subagent runtime before subagent-only middlewares."""
return _build_runtime_middlewares( return _build_runtime_middlewares(
app_config=app_config,
include_uploads=False, include_uploads=False,
include_dangling_tool_call_patch=True, include_dangling_tool_call_patch=True,
lazy_init=lazy_init, lazy_init=lazy_init,
@@ -17,6 +17,7 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
from deerflow.subagents.config import SubagentConfig from deerflow.subagents.config import SubagentConfig
@@ -132,6 +133,7 @@ class SubagentExecutor:
self, self,
config: SubagentConfig, config: SubagentConfig,
tools: list[BaseTool], tools: list[BaseTool],
app_config: AppConfig | None = None,
parent_model: str | None = None, parent_model: str | None = None,
sandbox_state: SandboxState | None = None, sandbox_state: SandboxState | None = None,
thread_data: ThreadDataState | None = None, thread_data: ThreadDataState | None = None,
@@ -143,6 +145,9 @@ class SubagentExecutor:
Args: Args:
config: Subagent configuration. config: Subagent configuration.
tools: List of all available tools (will be filtered). tools: List of all available tools (will be filtered).
app_config: Resolved AppConfig; threaded into middleware factories
at agent-build time. When None, ``_create_agent`` falls back to
``get_app_config()`` (matches the lead-agent factory's pattern).
parent_model: The parent agent's model name for inheritance. parent_model: The parent agent's model name for inheritance.
sandbox_state: Sandbox state from parent agent. sandbox_state: Sandbox state from parent agent.
thread_data: Thread data from parent agent. thread_data: Thread data from parent agent.
@@ -150,6 +155,7 @@ class SubagentExecutor:
trace_id: Trace ID from parent for distributed tracing. trace_id: Trace ID from parent for distributed tracing.
""" """
self.config = config self.config = config
self.app_config = app_config
self.parent_model = parent_model self.parent_model = parent_model
self.sandbox_state = sandbox_state self.sandbox_state = sandbox_state
self.thread_data = thread_data self.thread_data = thread_data
@@ -168,13 +174,17 @@ class SubagentExecutor:
def _create_agent(self): def _create_agent(self):
"""Create the agent instance.""" """Create the agent instance."""
# Mirror lead-agent factory pattern: prefer explicit app_config,
# fall back to ambient lookup at agent-build time.
from deerflow.config import get_app_config
resolved_app_config = self.app_config or get_app_config()
model_name = _get_model_name(self.config, self.parent_model) model_name = _get_model_name(self.config, self.parent_model)
model = create_chat_model(name=model_name, thinking_enabled=False) model = create_chat_model(name=model_name, thinking_enabled=False, app_config=resolved_app_config)
from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
# Reuse shared middleware composition with lead agent. middlewares = build_subagent_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
middlewares = build_subagent_runtime_middlewares(lazy_init=True)
return create_agent( return create_agent(
model=model, model=model,
@@ -217,6 +217,40 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock) assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
captured: dict[str, object] = {}
def _raise_get_app_config():
raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")
def _fake_build_lead_runtime_middlewares(*, app_config, lazy_init):
captured["app_config"] = app_config
captured["lazy_init"] = lazy_init
return ["base-middleware"]
monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config)
monkeypatch.setattr(
lead_agent_module,
"build_lead_runtime_middlewares",
_fake_build_lead_runtime_middlewares,
)
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda **kwargs: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
assert captured == {
"app_config": app_config,
"lazy_init": True,
}
assert middlewares[0] == "base-middleware"
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch): def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
lead_agent_module, lead_agent_module,
@@ -11,6 +11,13 @@ from langgraph.errors import GraphBubbleUp
from deerflow.agents.middlewares.llm_error_handling_middleware import ( from deerflow.agents.middlewares.llm_error_handling_middleware import (
LLMErrorHandlingMiddleware, LLMErrorHandlingMiddleware,
) )
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def _make_app_config() -> AppConfig:
"""Minimal AppConfig for middleware tests; circuit_breaker uses defaults."""
return AppConfig(sandbox=SandboxConfig(use="test"))
class FakeError(Exception): class FakeError(Exception):
@@ -31,7 +38,7 @@ class FakeError(Exception):
def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware: def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware:
middleware = LLMErrorHandlingMiddleware() middleware = LLMErrorHandlingMiddleware(app_config=_make_app_config())
for key, value in attrs.items(): for key, value in attrs.items():
setattr(middleware, key, value) setattr(middleware, key, value)
return middleware return middleware
@@ -226,9 +233,7 @@ def test_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) ->
current_time = 1000.0 current_time = 1000.0
monkeypatch.setattr("time.time", lambda: current_time) monkeypatch.setattr("time.time", lambda: current_time)
middleware = LLMErrorHandlingMiddleware() middleware = _build_middleware(circuit_failure_threshold=3, circuit_recovery_timeout_sec=10)
middleware.circuit_failure_threshold = 3
middleware.circuit_recovery_timeout_sec = 10
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable) monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
request: Any = {"messages": []} request: Any = {"messages": []}
@@ -284,8 +289,7 @@ def test_circuit_breaker_does_not_trip_on_non_retriable_errors(monkeypatch: pyte
waits: list[float] = [] waits: list[float] = []
monkeypatch.setattr("time.sleep", lambda d: waits.append(d)) monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
middleware = LLMErrorHandlingMiddleware() middleware = _build_middleware(circuit_failure_threshold=3)
middleware.circuit_failure_threshold = 3
monkeypatch.setattr(middleware, "_classify_error", mock_classify_non_retriable) monkeypatch.setattr(middleware, "_classify_error", mock_classify_non_retriable)
request: Any = {"messages": []} request: Any = {"messages": []}
@@ -386,9 +390,7 @@ async def test_async_circuit_breaker_trips_and_recovers(monkeypatch: pytest.Monk
current_time = 1000.0 current_time = 1000.0
monkeypatch.setattr("time.time", lambda: current_time) monkeypatch.setattr("time.time", lambda: current_time)
middleware = LLMErrorHandlingMiddleware() middleware = _build_middleware(circuit_failure_threshold=3, circuit_recovery_timeout_sec=10)
middleware.circuit_failure_threshold = 3
middleware.circuit_recovery_timeout_sec = 10
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable) monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
async def async_failing_handler(request: Any) -> Any: async def async_failing_handler(request: Any) -> Any:
+90
View File
@@ -17,6 +17,7 @@ import asyncio
import sys import sys
import threading import threading
from datetime import datetime from datetime import datetime
from types import ModuleType
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@@ -153,6 +154,13 @@ def mock_agent():
return agent return agent
def _module(name: str, **attrs):
module = ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
return module
# Helper to create real message objects # Helper to create real message objects
class _MsgHelper: class _MsgHelper:
"""Helper to create real message objects from fixture classes.""" """Helper to create real message objects from fixture classes."""
@@ -176,6 +184,88 @@ def msg(classes):
return _MsgHelper(classes) return _MsgHelper(classes)
# -----------------------------------------------------------------------------
# Agent Construction Tests
# -----------------------------------------------------------------------------
class TestAgentConstruction:
"""Test _create_agent() wiring before execution starts."""
def test_create_agent_threads_explicit_app_config_to_model_and_middlewares(
self,
classes,
base_config,
monkeypatch: pytest.MonkeyPatch,
):
"""Explicit app_config must flow into both model and middleware factories."""
import deerflow.config as config_module
from deerflow.subagents import executor as executor_module
SubagentExecutor = classes["SubagentExecutor"]
app_config = object()
model = object()
middlewares = [object()]
agent = object()
captured: dict[str, dict] = {}
def fake_get_app_config():
raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")
def fake_create_chat_model(**kwargs):
captured["model"] = kwargs
return model
def fake_build_subagent_runtime_middlewares(**kwargs):
captured["middlewares"] = kwargs
return middlewares
def fake_create_agent(**kwargs):
captured["agent"] = kwargs
return agent
monkeypatch.setattr(config_module, "get_app_config", fake_get_app_config)
monkeypatch.setattr(
executor_module,
"create_chat_model",
fake_create_chat_model,
)
monkeypatch.setattr(executor_module, "create_agent", fake_create_agent)
monkeypatch.setitem(
sys.modules,
"deerflow.agents.middlewares.tool_error_handling_middleware",
_module(
"deerflow.agents.middlewares.tool_error_handling_middleware",
build_subagent_runtime_middlewares=fake_build_subagent_runtime_middlewares,
),
)
executor = SubagentExecutor(
config=base_config,
tools=[],
app_config=app_config,
parent_model="parent-model",
)
result = executor._create_agent()
assert result is agent
assert captured["model"] == {
"name": "parent-model",
"thinking_enabled": False,
"app_config": app_config,
}
assert captured["middlewares"] == {
"app_config": app_config,
"lazy_init": True,
}
assert captured["agent"]["model"] is model
assert captured["agent"]["middleware"] is middlewares
assert captured["agent"]["tools"] == []
assert captured["agent"]["system_prompt"] == base_config.system_prompt
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Async Execution Path Tests # Async Execution Path Tests
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1,10 +1,32 @@
from types import SimpleNamespace import sys
from types import ModuleType, SimpleNamespace
import pytest import pytest
from langchain_core.messages import ToolMessage from langchain_core.messages import ToolMessage
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt
from deerflow.agents.middlewares.tool_error_handling_middleware import ToolErrorHandlingMiddleware from deerflow.agents.middlewares.tool_error_handling_middleware import (
ToolErrorHandlingMiddleware,
build_subagent_runtime_middlewares,
)
from deerflow.config.app_config import AppConfig, CircuitBreakerConfig
from deerflow.config.guardrails_config import GuardrailsConfig
from deerflow.config.sandbox_config import SandboxConfig
def _module(name: str, **attrs):
module = ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
return module
def _make_app_config() -> AppConfig:
return AppConfig(
sandbox=SandboxConfig(use="test"),
guardrails=GuardrailsConfig(enabled=False),
circuit_breaker=CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11),
)
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"): def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
@@ -14,6 +36,56 @@ def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
return SimpleNamespace(tool_call=tool_call) return SimpleNamespace(tool_call=tool_call)
def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch):
captured: dict[str, object] = {}
class FakeMiddleware:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
class FakeLLMErrorHandlingMiddleware:
def __init__(self, *, app_config):
captured["app_config"] = app_config
app_config = _make_app_config()
monkeypatch.setitem(
sys.modules,
"deerflow.agents.middlewares.llm_error_handling_middleware",
_module(
"deerflow.agents.middlewares.llm_error_handling_middleware",
LLMErrorHandlingMiddleware=FakeLLMErrorHandlingMiddleware,
),
)
monkeypatch.setitem(
sys.modules,
"deerflow.agents.middlewares.thread_data_middleware",
_module("deerflow.agents.middlewares.thread_data_middleware", ThreadDataMiddleware=FakeMiddleware),
)
monkeypatch.setitem(
sys.modules,
"deerflow.sandbox.middleware",
_module("deerflow.sandbox.middleware", SandboxMiddleware=FakeMiddleware),
)
monkeypatch.setitem(
sys.modules,
"deerflow.agents.middlewares.dangling_tool_call_middleware",
_module("deerflow.agents.middlewares.dangling_tool_call_middleware", DanglingToolCallMiddleware=FakeMiddleware),
)
monkeypatch.setitem(
sys.modules,
"deerflow.agents.middlewares.sandbox_audit_middleware",
_module("deerflow.agents.middlewares.sandbox_audit_middleware", SandboxAuditMiddleware=FakeMiddleware),
)
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
assert captured["app_config"] is app_config
assert len(middlewares) == 6
assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
def test_wrap_tool_call_passthrough_on_success(): def test_wrap_tool_call_passthrough_on_success():
middleware = ToolErrorHandlingMiddleware() middleware = ToolErrorHandlingMiddleware()
req = _request() req = _request()