refactor: thread app_config through middleware factories (#2652)
* refactor: thread app_config through middleware factories Continues the incremental config-refactor sequence (#2611 root, #2612 lead path) one layer deeper into the middleware factories. Two ambient lookups inside _build_runtime_middlewares are eliminated and the LLMErrorHandling band-aid removed: - _build_runtime_middlewares / build_lead_runtime_middlewares / build_subagent_runtime_middlewares now require app_config: AppConfig. - get_guardrails_config() inside the factory is replaced with app_config.guardrails (semantically identical — same default-factory GuardrailsConfig — verified by direct equality check). - LLMErrorHandlingMiddleware.__init__ now requires app_config and reads circuit_breaker fields directly. The class-level circuit_failure_threshold / circuit_recovery_timeout_sec defaults are removed along with the try/except (FileNotFoundError, RuntimeError): pass band-aid — the let-it-crash invariant the rest of the refactor enforces. Caller chain (already-resolved app_config sources): - _build_middlewares in lead_agent/agent.py: reorder so resolved_app_config = app_config or get_app_config() is computed BEFORE build_lead_runtime_middlewares is called, then passed as kwarg. - SubagentExecutor: optional app_config parameter (mirrors the lead-agent pattern); _create_agent does the same `or get_app_config()` fallback at agent-build time, so task_tool callers don't need to plumb app_config through yet (typed-context plumbing for tool runtimes is a separate refactor). Tests: - test_llm_error_handling_middleware: _make_app_config helper using AppConfig(sandbox=SandboxConfig(use="test")) — same minimal-config pattern conftest already uses. Three direct LLMErrorHandlingMiddleware() calls each followed by post-construction circuit_breaker mutation fold cleanly into _build_middleware(circuit_failure_threshold=..., circuit_recovery_timeout_sec=...). Verification: - tests/test_llm_error_handling_middleware.py — 14 passed - tests/test_subagent_executor.py — 28 passed - tests/test_tool_error_handling_middleware.py — 6 passed - tests/test_task_tool_core_logic.py — 18 passed (verifies task_tool unchanged behavior) - Full suite: 2697 passed, 3 skipped. The single intermittent failure in tests/test_client_e2e.py::test_tool_call_produces_events is pre-existing LLM flakiness (the test asserts the model decided to call a tool; reproduces 1/3 on unchanged main as well). * fix: address middleware app config review comments * fix: satisfy app config annotation lint * test: cover explicit app config middleware wiring --------- Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
This commit is contained in:
@@ -259,8 +259,8 @@ def _build_middlewares(
|
|||||||
Returns:
|
Returns:
|
||||||
List of middleware instances.
|
List of middleware instances.
|
||||||
"""
|
"""
|
||||||
middlewares = build_lead_runtime_middlewares(lazy_init=True)
|
|
||||||
resolved_app_config = app_config or get_app_config()
|
resolved_app_config = app_config or get_app_config()
|
||||||
|
middlewares = build_lead_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
|
||||||
|
|
||||||
# Add summarization middleware if enabled
|
# Add summarization middleware if enabled
|
||||||
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
|
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
|
||||||
|
|||||||
+4
-13
@@ -20,7 +20,7 @@ from langchain.agents.middleware.types import (
|
|||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langgraph.errors import GraphBubbleUp
|
from langgraph.errors import GraphBubbleUp
|
||||||
|
|
||||||
from deerflow.config import get_app_config
|
from deerflow.config.app_config import AppConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -70,20 +70,11 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
retry_base_delay_ms: int = 1000
|
retry_base_delay_ms: int = 1000
|
||||||
retry_cap_delay_ms: int = 8000
|
retry_cap_delay_ms: int = 8000
|
||||||
|
|
||||||
circuit_failure_threshold: int = 5
|
def __init__(self, *, app_config: AppConfig, **kwargs: Any) -> None:
|
||||||
circuit_recovery_timeout_sec: int = 60
|
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
# Load Circuit Breaker configs from app config if available, fall back to defaults
|
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
|
||||||
try:
|
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
|
||||||
app_config = get_app_config()
|
|
||||||
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
|
|
||||||
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
|
|
||||||
except (FileNotFoundError, RuntimeError):
|
|
||||||
# Gracefully fall back to class defaults in test environments
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Circuit Breaker state
|
# Circuit Breaker state
|
||||||
self._circuit_lock = threading.Lock()
|
self._circuit_lock = threading.Lock()
|
||||||
|
|||||||
+9
-6
@@ -11,6 +11,8 @@ from langgraph.errors import GraphBubbleUp
|
|||||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
from deerflow.config.app_config import AppConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
||||||
@@ -67,6 +69,7 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
|
|
||||||
def _build_runtime_middlewares(
|
def _build_runtime_middlewares(
|
||||||
*,
|
*,
|
||||||
|
app_config: AppConfig,
|
||||||
include_uploads: bool,
|
include_uploads: bool,
|
||||||
include_dangling_tool_call_patch: bool,
|
include_dangling_tool_call_patch: bool,
|
||||||
lazy_init: bool = True,
|
lazy_init: bool = True,
|
||||||
@@ -91,12 +94,10 @@ def _build_runtime_middlewares(
|
|||||||
|
|
||||||
middlewares.append(DanglingToolCallMiddleware())
|
middlewares.append(DanglingToolCallMiddleware())
|
||||||
|
|
||||||
middlewares.append(LLMErrorHandlingMiddleware())
|
middlewares.append(LLMErrorHandlingMiddleware(app_config=app_config))
|
||||||
|
|
||||||
# Guardrail middleware (if configured)
|
# Guardrail middleware (if configured)
|
||||||
from deerflow.config.guardrails_config import get_guardrails_config
|
guardrails_config = app_config.guardrails
|
||||||
|
|
||||||
guardrails_config = get_guardrails_config()
|
|
||||||
if guardrails_config.enabled and guardrails_config.provider:
|
if guardrails_config.enabled and guardrails_config.provider:
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
@@ -125,18 +126,20 @@ def _build_runtime_middlewares(
|
|||||||
return middlewares
|
return middlewares
|
||||||
|
|
||||||
|
|
||||||
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
def build_lead_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||||
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
|
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
|
||||||
return _build_runtime_middlewares(
|
return _build_runtime_middlewares(
|
||||||
|
app_config=app_config,
|
||||||
include_uploads=True,
|
include_uploads=True,
|
||||||
include_dangling_tool_call_patch=True,
|
include_dangling_tool_call_patch=True,
|
||||||
lazy_init=lazy_init,
|
lazy_init=lazy_init,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
def build_subagent_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||||
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
||||||
return _build_runtime_middlewares(
|
return _build_runtime_middlewares(
|
||||||
|
app_config=app_config,
|
||||||
include_uploads=False,
|
include_uploads=False,
|
||||||
include_dangling_tool_call_patch=True,
|
include_dangling_tool_call_patch=True,
|
||||||
lazy_init=lazy_init,
|
lazy_init=lazy_init,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
|
from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
|
||||||
|
from deerflow.config.app_config import AppConfig
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
from deerflow.subagents.config import SubagentConfig
|
from deerflow.subagents.config import SubagentConfig
|
||||||
|
|
||||||
@@ -132,6 +133,7 @@ class SubagentExecutor:
|
|||||||
self,
|
self,
|
||||||
config: SubagentConfig,
|
config: SubagentConfig,
|
||||||
tools: list[BaseTool],
|
tools: list[BaseTool],
|
||||||
|
app_config: AppConfig | None = None,
|
||||||
parent_model: str | None = None,
|
parent_model: str | None = None,
|
||||||
sandbox_state: SandboxState | None = None,
|
sandbox_state: SandboxState | None = None,
|
||||||
thread_data: ThreadDataState | None = None,
|
thread_data: ThreadDataState | None = None,
|
||||||
@@ -143,6 +145,9 @@ class SubagentExecutor:
|
|||||||
Args:
|
Args:
|
||||||
config: Subagent configuration.
|
config: Subagent configuration.
|
||||||
tools: List of all available tools (will be filtered).
|
tools: List of all available tools (will be filtered).
|
||||||
|
app_config: Resolved AppConfig; threaded into middleware factories
|
||||||
|
at agent-build time. When None, ``_create_agent`` falls back to
|
||||||
|
``get_app_config()`` (matches the lead-agent factory's pattern).
|
||||||
parent_model: The parent agent's model name for inheritance.
|
parent_model: The parent agent's model name for inheritance.
|
||||||
sandbox_state: Sandbox state from parent agent.
|
sandbox_state: Sandbox state from parent agent.
|
||||||
thread_data: Thread data from parent agent.
|
thread_data: Thread data from parent agent.
|
||||||
@@ -150,6 +155,7 @@ class SubagentExecutor:
|
|||||||
trace_id: Trace ID from parent for distributed tracing.
|
trace_id: Trace ID from parent for distributed tracing.
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.app_config = app_config
|
||||||
self.parent_model = parent_model
|
self.parent_model = parent_model
|
||||||
self.sandbox_state = sandbox_state
|
self.sandbox_state = sandbox_state
|
||||||
self.thread_data = thread_data
|
self.thread_data = thread_data
|
||||||
@@ -168,13 +174,17 @@ class SubagentExecutor:
|
|||||||
|
|
||||||
def _create_agent(self):
|
def _create_agent(self):
|
||||||
"""Create the agent instance."""
|
"""Create the agent instance."""
|
||||||
|
# Mirror lead-agent factory pattern: prefer explicit app_config,
|
||||||
|
# fall back to ambient lookup at agent-build time.
|
||||||
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
|
resolved_app_config = self.app_config or get_app_config()
|
||||||
model_name = _get_model_name(self.config, self.parent_model)
|
model_name = _get_model_name(self.config, self.parent_model)
|
||||||
model = create_chat_model(name=model_name, thinking_enabled=False)
|
model = create_chat_model(name=model_name, thinking_enabled=False, app_config=resolved_app_config)
|
||||||
|
|
||||||
from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
|
from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
|
||||||
|
|
||||||
# Reuse shared middleware composition with lead agent.
|
middlewares = build_subagent_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
|
||||||
middlewares = build_subagent_runtime_middlewares(lazy_init=True)
|
|
||||||
|
|
||||||
return create_agent(
|
return create_agent(
|
||||||
model=model,
|
model=model,
|
||||||
|
|||||||
@@ -217,6 +217,40 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
|||||||
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
|
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
|
||||||
|
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
def _raise_get_app_config():
|
||||||
|
raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")
|
||||||
|
|
||||||
|
def _fake_build_lead_runtime_middlewares(*, app_config, lazy_init):
|
||||||
|
captured["app_config"] = app_config
|
||||||
|
captured["lazy_init"] = lazy_init
|
||||||
|
return ["base-middleware"]
|
||||||
|
|
||||||
|
monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
lead_agent_module,
|
||||||
|
"build_lead_runtime_middlewares",
|
||||||
|
_fake_build_lead_runtime_middlewares,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda **kwargs: None)
|
||||||
|
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||||
|
|
||||||
|
middlewares = lead_agent_module._build_middlewares(
|
||||||
|
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
|
||||||
|
model_name="safe-model",
|
||||||
|
app_config=app_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured == {
|
||||||
|
"app_config": app_config,
|
||||||
|
"lazy_init": True,
|
||||||
|
}
|
||||||
|
assert middlewares[0] == "base-middleware"
|
||||||
|
|
||||||
|
|
||||||
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
|
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
lead_agent_module,
|
lead_agent_module,
|
||||||
|
|||||||
@@ -11,6 +11,13 @@ from langgraph.errors import GraphBubbleUp
|
|||||||
from deerflow.agents.middlewares.llm_error_handling_middleware import (
|
from deerflow.agents.middlewares.llm_error_handling_middleware import (
|
||||||
LLMErrorHandlingMiddleware,
|
LLMErrorHandlingMiddleware,
|
||||||
)
|
)
|
||||||
|
from deerflow.config.app_config import AppConfig
|
||||||
|
from deerflow.config.sandbox_config import SandboxConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app_config() -> AppConfig:
|
||||||
|
"""Minimal AppConfig for middleware tests; circuit_breaker uses defaults."""
|
||||||
|
return AppConfig(sandbox=SandboxConfig(use="test"))
|
||||||
|
|
||||||
|
|
||||||
class FakeError(Exception):
|
class FakeError(Exception):
|
||||||
@@ -31,7 +38,7 @@ class FakeError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware:
|
def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware:
|
||||||
middleware = LLMErrorHandlingMiddleware()
|
middleware = LLMErrorHandlingMiddleware(app_config=_make_app_config())
|
||||||
for key, value in attrs.items():
|
for key, value in attrs.items():
|
||||||
setattr(middleware, key, value)
|
setattr(middleware, key, value)
|
||||||
return middleware
|
return middleware
|
||||||
@@ -226,9 +233,7 @@ def test_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) ->
|
|||||||
current_time = 1000.0
|
current_time = 1000.0
|
||||||
monkeypatch.setattr("time.time", lambda: current_time)
|
monkeypatch.setattr("time.time", lambda: current_time)
|
||||||
|
|
||||||
middleware = LLMErrorHandlingMiddleware()
|
middleware = _build_middleware(circuit_failure_threshold=3, circuit_recovery_timeout_sec=10)
|
||||||
middleware.circuit_failure_threshold = 3
|
|
||||||
middleware.circuit_recovery_timeout_sec = 10
|
|
||||||
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
|
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
|
||||||
|
|
||||||
request: Any = {"messages": []}
|
request: Any = {"messages": []}
|
||||||
@@ -284,8 +289,7 @@ def test_circuit_breaker_does_not_trip_on_non_retriable_errors(monkeypatch: pyte
|
|||||||
waits: list[float] = []
|
waits: list[float] = []
|
||||||
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
|
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
|
||||||
|
|
||||||
middleware = LLMErrorHandlingMiddleware()
|
middleware = _build_middleware(circuit_failure_threshold=3)
|
||||||
middleware.circuit_failure_threshold = 3
|
|
||||||
monkeypatch.setattr(middleware, "_classify_error", mock_classify_non_retriable)
|
monkeypatch.setattr(middleware, "_classify_error", mock_classify_non_retriable)
|
||||||
|
|
||||||
request: Any = {"messages": []}
|
request: Any = {"messages": []}
|
||||||
@@ -386,9 +390,7 @@ async def test_async_circuit_breaker_trips_and_recovers(monkeypatch: pytest.Monk
|
|||||||
current_time = 1000.0
|
current_time = 1000.0
|
||||||
monkeypatch.setattr("time.time", lambda: current_time)
|
monkeypatch.setattr("time.time", lambda: current_time)
|
||||||
|
|
||||||
middleware = LLMErrorHandlingMiddleware()
|
middleware = _build_middleware(circuit_failure_threshold=3, circuit_recovery_timeout_sec=10)
|
||||||
middleware.circuit_failure_threshold = 3
|
|
||||||
middleware.circuit_recovery_timeout_sec = 10
|
|
||||||
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
|
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
|
||||||
|
|
||||||
async def async_failing_handler(request: Any) -> Any:
|
async def async_failing_handler(request: Any) -> Any:
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import asyncio
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from types import ModuleType
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -153,6 +154,13 @@ def mock_agent():
|
|||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
def _module(name: str, **attrs):
|
||||||
|
module = ModuleType(name)
|
||||||
|
for key, value in attrs.items():
|
||||||
|
setattr(module, key, value)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
# Helper to create real message objects
|
# Helper to create real message objects
|
||||||
class _MsgHelper:
|
class _MsgHelper:
|
||||||
"""Helper to create real message objects from fixture classes."""
|
"""Helper to create real message objects from fixture classes."""
|
||||||
@@ -176,6 +184,88 @@ def msg(classes):
|
|||||||
return _MsgHelper(classes)
|
return _MsgHelper(classes)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Agent Construction Tests
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentConstruction:
|
||||||
|
"""Test _create_agent() wiring before execution starts."""
|
||||||
|
|
||||||
|
def test_create_agent_threads_explicit_app_config_to_model_and_middlewares(
|
||||||
|
self,
|
||||||
|
classes,
|
||||||
|
base_config,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
):
|
||||||
|
"""Explicit app_config must flow into both model and middleware factories."""
|
||||||
|
import deerflow.config as config_module
|
||||||
|
from deerflow.subagents import executor as executor_module
|
||||||
|
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
|
||||||
|
app_config = object()
|
||||||
|
model = object()
|
||||||
|
middlewares = [object()]
|
||||||
|
agent = object()
|
||||||
|
captured: dict[str, dict] = {}
|
||||||
|
|
||||||
|
def fake_get_app_config():
|
||||||
|
raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")
|
||||||
|
|
||||||
|
def fake_create_chat_model(**kwargs):
|
||||||
|
captured["model"] = kwargs
|
||||||
|
return model
|
||||||
|
|
||||||
|
def fake_build_subagent_runtime_middlewares(**kwargs):
|
||||||
|
captured["middlewares"] = kwargs
|
||||||
|
return middlewares
|
||||||
|
|
||||||
|
def fake_create_agent(**kwargs):
|
||||||
|
captured["agent"] = kwargs
|
||||||
|
return agent
|
||||||
|
|
||||||
|
monkeypatch.setattr(config_module, "get_app_config", fake_get_app_config)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
executor_module,
|
||||||
|
"create_chat_model",
|
||||||
|
fake_create_chat_model,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(executor_module, "create_agent", fake_create_agent)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"deerflow.agents.middlewares.tool_error_handling_middleware",
|
||||||
|
_module(
|
||||||
|
"deerflow.agents.middlewares.tool_error_handling_middleware",
|
||||||
|
build_subagent_runtime_middlewares=fake_build_subagent_runtime_middlewares,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
app_config=app_config,
|
||||||
|
parent_model="parent-model",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = executor._create_agent()
|
||||||
|
|
||||||
|
assert result is agent
|
||||||
|
assert captured["model"] == {
|
||||||
|
"name": "parent-model",
|
||||||
|
"thinking_enabled": False,
|
||||||
|
"app_config": app_config,
|
||||||
|
}
|
||||||
|
assert captured["middlewares"] == {
|
||||||
|
"app_config": app_config,
|
||||||
|
"lazy_init": True,
|
||||||
|
}
|
||||||
|
assert captured["agent"]["model"] is model
|
||||||
|
assert captured["agent"]["middleware"] is middlewares
|
||||||
|
assert captured["agent"]["tools"] == []
|
||||||
|
assert captured["agent"]["system_prompt"] == base_config.system_prompt
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Async Execution Path Tests
|
# Async Execution Path Tests
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -1,10 +1,32 @@
|
|||||||
from types import SimpleNamespace
|
import sys
|
||||||
|
from types import ModuleType, SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
from deerflow.agents.middlewares.tool_error_handling_middleware import ToolErrorHandlingMiddleware
|
from deerflow.agents.middlewares.tool_error_handling_middleware import (
|
||||||
|
ToolErrorHandlingMiddleware,
|
||||||
|
build_subagent_runtime_middlewares,
|
||||||
|
)
|
||||||
|
from deerflow.config.app_config import AppConfig, CircuitBreakerConfig
|
||||||
|
from deerflow.config.guardrails_config import GuardrailsConfig
|
||||||
|
from deerflow.config.sandbox_config import SandboxConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _module(name: str, **attrs):
|
||||||
|
module = ModuleType(name)
|
||||||
|
for key, value in attrs.items():
|
||||||
|
setattr(module, key, value)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app_config() -> AppConfig:
|
||||||
|
return AppConfig(
|
||||||
|
sandbox=SandboxConfig(use="test"),
|
||||||
|
guardrails=GuardrailsConfig(enabled=False),
|
||||||
|
circuit_breaker=CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
|
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
|
||||||
@@ -14,6 +36,56 @@ def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
|
|||||||
return SimpleNamespace(tool_call=tool_call)
|
return SimpleNamespace(tool_call=tool_call)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
class FakeMiddleware:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
class FakeLLMErrorHandlingMiddleware:
|
||||||
|
def __init__(self, *, app_config):
|
||||||
|
captured["app_config"] = app_config
|
||||||
|
|
||||||
|
app_config = _make_app_config()
|
||||||
|
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"deerflow.agents.middlewares.llm_error_handling_middleware",
|
||||||
|
_module(
|
||||||
|
"deerflow.agents.middlewares.llm_error_handling_middleware",
|
||||||
|
LLMErrorHandlingMiddleware=FakeLLMErrorHandlingMiddleware,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"deerflow.agents.middlewares.thread_data_middleware",
|
||||||
|
_module("deerflow.agents.middlewares.thread_data_middleware", ThreadDataMiddleware=FakeMiddleware),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"deerflow.sandbox.middleware",
|
||||||
|
_module("deerflow.sandbox.middleware", SandboxMiddleware=FakeMiddleware),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"deerflow.agents.middlewares.dangling_tool_call_middleware",
|
||||||
|
_module("deerflow.agents.middlewares.dangling_tool_call_middleware", DanglingToolCallMiddleware=FakeMiddleware),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"deerflow.agents.middlewares.sandbox_audit_middleware",
|
||||||
|
_module("deerflow.agents.middlewares.sandbox_audit_middleware", SandboxAuditMiddleware=FakeMiddleware),
|
||||||
|
)
|
||||||
|
|
||||||
|
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
|
||||||
|
|
||||||
|
assert captured["app_config"] is app_config
|
||||||
|
assert len(middlewares) == 6
|
||||||
|
assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
|
||||||
|
|
||||||
|
|
||||||
def test_wrap_tool_call_passthrough_on_success():
|
def test_wrap_tool_call_passthrough_on_success():
|
||||||
middleware = ToolErrorHandlingMiddleware()
|
middleware = ToolErrorHandlingMiddleware()
|
||||||
req = _request()
|
req = _request()
|
||||||
|
|||||||
Reference in New Issue
Block a user