Compare commits

...

1 Commits

Author SHA1 Message Date
greatmengqi 2eb45e9bb5 fix: thread app config through client and sync providers 2026-05-02 12:07:26 +08:00
5 changed files with 254 additions and 40 deletions
+18 -8
View File
@@ -228,14 +228,21 @@ class DeerFlowClient:
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3) max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled), "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=self._app_config),
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled), "tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares), "middleware": _build_middlewares(
config,
model_name=model_name,
agent_name=self._agent_name,
custom_middlewares=self._middlewares,
app_config=self._app_config,
),
"system_prompt": apply_prompt_template( "system_prompt": apply_prompt_template(
subagent_enabled=subagent_enabled, subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents, max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name, agent_name=self._agent_name,
available_skills=self._available_skills, available_skills=self._available_skills,
app_config=self._app_config,
), ),
"state_schema": ThreadState, "state_schema": ThreadState,
} }
@@ -243,7 +250,7 @@ class DeerFlowClient:
if checkpointer is None: if checkpointer is None:
from deerflow.runtime.checkpointer import get_checkpointer from deerflow.runtime.checkpointer import get_checkpointer
checkpointer = get_checkpointer() checkpointer = get_checkpointer(app_config=self._app_config)
if checkpointer is not None: if checkpointer is not None:
kwargs["checkpointer"] = checkpointer kwargs["checkpointer"] = checkpointer
@@ -251,12 +258,15 @@ class DeerFlowClient:
self._agent_config_key = key self._agent_config_key = key
logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled) logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled)
@staticmethod def _get_tools(self, *, model_name: str | None, subagent_enabled: bool):
def _get_tools(*, model_name: str | None, subagent_enabled: bool):
"""Lazy import to avoid circular dependency at module level.""" """Lazy import to avoid circular dependency at module level."""
from deerflow.tools import get_available_tools from deerflow.tools import get_available_tools
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) return get_available_tools(
model_name=model_name,
subagent_enabled=subagent_enabled,
app_config=self._app_config,
)
@staticmethod @staticmethod
def _serialize_tool_calls(tool_calls) -> list[dict]: def _serialize_tool_calls(tool_calls) -> list[dict]:
@@ -377,7 +387,7 @@ class DeerFlowClient:
if checkpointer is None: if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer() checkpointer = get_checkpointer(app_config=self._app_config)
thread_info_map = {} thread_info_map = {}
@@ -432,7 +442,7 @@ class DeerFlowClient:
if checkpointer is None: if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer() checkpointer = get_checkpointer(app_config=self._app_config)
config = {"configurable": {"thread_id": thread_id}} config = {"configurable": {"thread_id": thread_id}}
checkpoints = [] checkpoints = []
@@ -25,7 +25,7 @@ from collections.abc import Iterator
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
@@ -98,9 +98,78 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
_checkpointer: Checkpointer | None = None _checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive _checkpointer_ctx = None # open context manager keeping the connection alive
_explicit_checkpointers: dict[int, Checkpointer] = {}
_explicit_checkpointer_contexts: dict[int, object] = {}
def get_checkpointer() -> Checkpointer: def _default_in_memory_checkpointer() -> Checkpointer:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
return InMemorySaver()
def _persistent_database_backend(db_config) -> str | None:
backend = getattr(db_config, "backend", None)
if backend in {"sqlite", "postgres"}:
return backend
return None
@contextlib.contextmanager
def _sync_checkpointer_from_database_cm(db_config) -> Iterator[Checkpointer]:
"""Context manager that creates a sync checkpointer from unified DatabaseConfig."""
backend = _persistent_database_backend(db_config)
if backend is None:
yield _default_in_memory_checkpointer()
return
if backend == "sqlite":
try:
from langgraph.checkpoint.sqlite import SqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = db_config.checkpointer_sqlite_path
ensure_sqlite_parent_dir(conn_str)
with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
yield saver
return
if backend == "postgres":
try:
from langgraph.checkpoint.postgres import PostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend")
with PostgresSaver.from_conn_string(db_config.postgres_url) as saver:
saver.setup()
logger.info("Checkpointer: using PostgresSaver")
yield saver
return
raise ValueError(f"Unknown database backend: {backend!r}")
def _build_checkpointer_from_app_config(app_config: AppConfig) -> tuple[Checkpointer, object | None]:
if app_config.checkpointer is not None:
ctx = _sync_checkpointer_cm(app_config.checkpointer)
return ctx.__enter__(), ctx
db_config = getattr(app_config, "database", None)
if _persistent_database_backend(db_config) is not None:
ctx = _sync_checkpointer_from_database_cm(db_config)
return ctx.__enter__(), ctx
return _default_in_memory_checkpointer(), None
def get_checkpointer(app_config: AppConfig | None = None) -> Checkpointer:
"""Return the global sync checkpointer singleton, creating it on first call. """Return the global sync checkpointer singleton, creating it on first call.
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
@@ -111,6 +180,18 @@ def get_checkpointer() -> Checkpointer:
""" """
global _checkpointer, _checkpointer_ctx global _checkpointer, _checkpointer_ctx
if app_config is not None:
cache_key = id(app_config)
cached = _explicit_checkpointers.get(cache_key)
if cached is not None:
return cached
explicit_checkpointer, explicit_ctx = _build_checkpointer_from_app_config(app_config)
_explicit_checkpointers[cache_key] = explicit_checkpointer
if explicit_ctx is not None:
_explicit_checkpointer_contexts[cache_key] = explicit_ctx
return explicit_checkpointer
if _checkpointer is not None: if _checkpointer is not None:
return _checkpointer return _checkpointer
@@ -121,28 +202,30 @@ def get_checkpointer() -> Checkpointer:
from deerflow.config.checkpointer_config import get_checkpointer_config from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config() config = get_checkpointer_config()
global_app_config = _app_config
if config is None and _app_config is None: if config is None and global_app_config is None:
# Only load app config lazily when neither the app config nor an explicit # Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that # checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any # intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk. # ambient config.yaml on disk.
try: try:
get_app_config() global_app_config = get_app_config()
except FileNotFoundError: except FileNotFoundError:
# In test environments without config.yaml, this is expected. # In test environments without config.yaml, this is expected.
pass pass
config = get_checkpointer_config() config = get_checkpointer_config()
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)") if config is not None:
_checkpointer = InMemorySaver() _checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
return _checkpointer return _checkpointer
_checkpointer_ctx = _sync_checkpointer_cm(config) if global_app_config is not None:
_checkpointer = _checkpointer_ctx.__enter__() _checkpointer, _checkpointer_ctx = _build_checkpointer_from_app_config(global_app_config)
return _checkpointer
_checkpointer = _default_in_memory_checkpointer()
return _checkpointer return _checkpointer
@@ -161,6 +244,18 @@ def reset_checkpointer() -> None:
_checkpointer_ctx = None _checkpointer_ctx = None
_checkpointer = None _checkpointer = None
for cache_key, ctx in list(_explicit_checkpointer_contexts.items()):
try:
ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during explicit checkpointer cleanup", exc_info=True)
finally:
_explicit_checkpointer_contexts.pop(cache_key, None)
_explicit_checkpointers.pop(cache_key, None)
_explicit_checkpointers.clear()
_explicit_checkpointer_contexts.clear()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Sync context manager # Sync context manager
@@ -168,7 +263,7 @@ def reset_checkpointer() -> None:
@contextlib.contextmanager @contextlib.contextmanager
def checkpointer_context() -> Iterator[Checkpointer]: def checkpointer_context(app_config: AppConfig | None = None) -> Iterator[Checkpointer]:
"""Sync context manager that yields a checkpointer and cleans up on exit. """Sync context manager that yields a checkpointer and cleans up on exit.
Unlike :func:`get_checkpointer`, this does **not** cache the instance — Unlike :func:`get_checkpointer`, this does **not** cache the instance —
@@ -181,12 +276,16 @@ def checkpointer_context() -> Iterator[Checkpointer]:
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
""" """
config = get_app_config() resolved_app_config = app_config or get_app_config()
if config.checkpointer is None: if resolved_app_config.checkpointer is not None:
from langgraph.checkpoint.memory import InMemorySaver with _sync_checkpointer_cm(resolved_app_config.checkpointer) as saver:
yield saver
yield InMemorySaver()
return return
with _sync_checkpointer_cm(config.checkpointer) as saver: db_config = getattr(resolved_app_config, "database", None)
yield saver if _persistent_database_backend(db_config) is not None:
with _sync_checkpointer_from_database_cm(db_config) as saver:
yield saver
return
yield _default_in_memory_checkpointer()
@@ -26,7 +26,7 @@ from collections.abc import Iterator
from langgraph.store.base import BaseStore from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -98,9 +98,26 @@ def _sync_store_cm(config) -> Iterator[BaseStore]:
_store: BaseStore | None = None _store: BaseStore | None = None
_store_ctx = None # open context manager keeping the connection alive _store_ctx = None # open context manager keeping the connection alive
_explicit_stores: dict[int, BaseStore] = {}
_explicit_store_contexts: dict[int, object] = {}
def get_store() -> BaseStore: def _default_in_memory_store() -> BaseStore:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
return InMemoryStore()
def _build_store_from_app_config(app_config: AppConfig) -> tuple[BaseStore, object | None]:
if app_config.checkpointer is not None:
ctx = _sync_store_cm(app_config.checkpointer)
return ctx.__enter__(), ctx
return _default_in_memory_store(), None
def get_store(app_config: AppConfig | None = None) -> BaseStore:
"""Return the global sync Store singleton, creating it on first call. """Return the global sync Store singleton, creating it on first call.
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
@@ -112,6 +129,18 @@ def get_store() -> BaseStore:
""" """
global _store, _store_ctx global _store, _store_ctx
if app_config is not None:
cache_key = id(app_config)
cached = _explicit_stores.get(cache_key)
if cached is not None:
return cached
explicit_store, explicit_ctx = _build_store_from_app_config(app_config)
_explicit_stores[cache_key] = explicit_store
if explicit_ctx is not None:
_explicit_store_contexts[cache_key] = explicit_ctx
return explicit_store
if _store is not None: if _store is not None:
return _store return _store
@@ -130,10 +159,7 @@ def get_store() -> BaseStore:
config = get_checkpointer_config() config = get_checkpointer_config()
if config is None: if config is None:
from langgraph.store.memory import InMemoryStore _store = _default_in_memory_store()
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
_store = InMemoryStore()
return _store return _store
_store_ctx = _sync_store_cm(config) _store_ctx = _sync_store_cm(config)
@@ -156,6 +182,18 @@ def reset_store() -> None:
_store_ctx = None _store_ctx = None
_store = None _store = None
for cache_key, ctx in list(_explicit_store_contexts.items()):
try:
ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during explicit store cleanup", exc_info=True)
finally:
_explicit_store_contexts.pop(cache_key, None)
_explicit_stores.pop(cache_key, None)
_explicit_stores.clear()
_explicit_store_contexts.clear()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Sync context manager # Sync context manager
@@ -163,7 +201,7 @@ def reset_store() -> None:
@contextlib.contextmanager @contextlib.contextmanager
def store_context() -> Iterator[BaseStore]: def store_context(app_config: AppConfig | None = None) -> Iterator[BaseStore]:
"""Sync context manager that yields a Store and cleans up on exit. """Sync context manager that yields a Store and cleans up on exit.
Unlike :func:`get_store`, this does **not** cache the instance — each Unlike :func:`get_store`, this does **not** cache the instance — each
@@ -176,13 +214,10 @@ def store_context() -> Iterator[BaseStore]:
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*. checkpointer is configured in *config.yaml*.
""" """
config = get_app_config() resolved_app_config = app_config or get_app_config()
if config.checkpointer is None: if resolved_app_config.checkpointer is None:
from langgraph.store.memory import InMemoryStore yield _default_in_memory_store()
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
return return
with _sync_store_cm(config.checkpointer) as store: with _sync_store_cm(resolved_app_config.checkpointer) as store:
yield store yield store
+48
View File
@@ -1,6 +1,7 @@
"""Unit tests for checkpointer config and singleton factory.""" """Unit tests for checkpointer config and singleton factory."""
import sys import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@@ -103,6 +104,53 @@ class TestGetCheckpointer:
cp2 = get_checkpointer() cp2 = get_checkpointer()
assert cp1 is not cp2 assert cp1 is not cp2
def test_explicit_app_config_bypasses_global_config_lookup(self):
from langgraph.checkpoint.memory import InMemorySaver
explicit_config = SimpleNamespace(
checkpointer=CheckpointerConfig(type="memory"),
database=SimpleNamespace(backend="memory"),
)
with patch(
"deerflow.runtime.checkpointer.provider.get_app_config",
side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"),
):
cp = get_checkpointer(app_config=explicit_config)
assert isinstance(cp, InMemorySaver)
def test_explicit_app_config_uses_unified_database_sqlite_backend(self):
explicit_config = SimpleNamespace(
checkpointer=None,
database=SimpleNamespace(backend="sqlite", checkpointer_sqlite_path="/tmp/explicit/deerflow.db"),
)
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
with (
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
patch(
"deerflow.runtime.checkpointer.provider.get_app_config",
side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"),
),
patch("deerflow.runtime.checkpointer.provider.ensure_sqlite_parent_dir") as mock_ensure,
):
cp = get_checkpointer(app_config=explicit_config)
assert cp is mock_saver_instance
mock_ensure.assert_called_once_with("/tmp/explicit/deerflow.db")
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/explicit/deerflow.db")
def test_sqlite_raises_when_package_missing(self): def test_sqlite_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"}) load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}): with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
+22
View File
@@ -848,6 +848,28 @@ class TestEnsureAgent:
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent" assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"} assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}
def test_threads_explicit_app_config_to_dependencies(self, client):
"""Client-owned AppConfig must flow into model/tool/prompt/checkpointer composition."""
mock_agent = MagicMock()
mock_checkpointer = MagicMock()
config = client._get_runnable_config("t1")
with (
patch("deerflow.client.create_chat_model", return_value=MagicMock()) as mock_create_chat_model,
patch("deerflow.client.create_agent", return_value=mock_agent),
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
patch("deerflow.tools.get_available_tools", return_value=[]) as mock_get_available_tools,
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer) as mock_get_checkpointer,
):
client._ensure_agent(config)
assert mock_create_chat_model.call_args.kwargs["app_config"] is client._app_config
assert mock_build_middlewares.call_args.kwargs["app_config"] is client._app_config
assert mock_apply_prompt.call_args.kwargs["app_config"] is client._app_config
assert mock_get_available_tools.call_args.kwargs["app_config"] is client._app_config
assert mock_get_checkpointer.call_args.kwargs["app_config"] is client._app_config
def test_uses_default_checkpointer_when_available(self, client): def test_uses_default_checkpointer_when_available(self, client):
mock_agent = MagicMock() mock_agent = MagicMock()
mock_checkpointer = MagicMock() mock_checkpointer = MagicMock()