diff --git a/backend/packages/harness/deerflow/client.py b/backend/packages/harness/deerflow/client.py index 2ba9302c..e83eefbf 100644 --- a/backend/packages/harness/deerflow/client.py +++ b/backend/packages/harness/deerflow/client.py @@ -228,14 +228,21 @@ class DeerFlowClient: max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3) kwargs: dict[str, Any] = { - "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled), + "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=self._app_config), "tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled), - "middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares), + "middleware": _build_middlewares( + config, + model_name=model_name, + agent_name=self._agent_name, + custom_middlewares=self._middlewares, + app_config=self._app_config, + ), "system_prompt": apply_prompt_template( subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=self._agent_name, available_skills=self._available_skills, + app_config=self._app_config, ), "state_schema": ThreadState, } @@ -243,7 +250,7 @@ class DeerFlowClient: if checkpointer is None: from deerflow.runtime.checkpointer import get_checkpointer - checkpointer = get_checkpointer() + checkpointer = get_checkpointer(app_config=self._app_config) if checkpointer is not None: kwargs["checkpointer"] = checkpointer @@ -251,12 +258,15 @@ class DeerFlowClient: self._agent_config_key = key logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled) - @staticmethod - def _get_tools(*, model_name: str | None, subagent_enabled: bool): + def _get_tools(self, *, model_name: str | None, subagent_enabled: bool): """Lazy import to avoid circular dependency at module level.""" from deerflow.tools import get_available_tools - return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + return get_available_tools( + model_name=model_name, + subagent_enabled=subagent_enabled, + app_config=self._app_config, + ) @staticmethod def _serialize_tool_calls(tool_calls) -> list[dict]: @@ -377,7 +387,7 @@ class DeerFlowClient: if checkpointer is None: from deerflow.runtime.checkpointer.provider import get_checkpointer - checkpointer = get_checkpointer() + checkpointer = get_checkpointer(app_config=self._app_config) thread_info_map = {} @@ -432,7 +442,7 @@ class DeerFlowClient: if checkpointer is None: from deerflow.runtime.checkpointer.provider import get_checkpointer - checkpointer = get_checkpointer() + checkpointer = get_checkpointer(app_config=self._app_config) config = {"configurable": {"thread_id": thread_id}} checkpoints = [] diff --git a/backend/packages/harness/deerflow/runtime/checkpointer/provider.py b/backend/packages/harness/deerflow/runtime/checkpointer/provider.py index 5ee66be8..c500b1e9 100644 --- a/backend/packages/harness/deerflow/runtime/checkpointer/provider.py +++ b/backend/packages/harness/deerflow/runtime/checkpointer/provider.py @@ -25,7 +25,7 @@ from collections.abc import Iterator from langgraph.types import Checkpointer -from deerflow.config.app_config import get_app_config +from deerflow.config.app_config import AppConfig, get_app_config from deerflow.config.checkpointer_config import CheckpointerConfig from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str @@ -98,9 +98,78 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]: _checkpointer: Checkpointer | None = None _checkpointer_ctx = None # open context manager keeping the connection alive +_explicit_checkpointers: dict[int, Checkpointer] = {} +_explicit_checkpointer_contexts: dict[int, object] = {} -def get_checkpointer() -> Checkpointer: +def _default_in_memory_checkpointer() -> Checkpointer: + from langgraph.checkpoint.memory import InMemorySaver + + logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)") + return InMemorySaver() + + +def _persistent_database_backend(db_config) -> str | None: + backend = getattr(db_config, "backend", None) + if backend in {"sqlite", "postgres"}: + return backend + return None + + +@contextlib.contextmanager +def _sync_checkpointer_from_database_cm(db_config) -> Iterator[Checkpointer]: + """Context manager that creates a sync checkpointer from unified DatabaseConfig.""" + backend = _persistent_database_backend(db_config) + if backend is None: + yield _default_in_memory_checkpointer() + return + + if backend == "sqlite": + try: + from langgraph.checkpoint.sqlite import SqliteSaver + except ImportError as exc: + raise ImportError(SQLITE_INSTALL) from exc + + conn_str = db_config.checkpointer_sqlite_path + ensure_sqlite_parent_dir(conn_str) + with SqliteSaver.from_conn_string(conn_str) as saver: + saver.setup() + logger.info("Checkpointer: using SqliteSaver (%s)", conn_str) + yield saver + return + + if backend == "postgres": + try: + from langgraph.checkpoint.postgres import PostgresSaver + except ImportError as exc: + raise ImportError(POSTGRES_INSTALL) from exc + + if not db_config.postgres_url: + raise ValueError("database.postgres_url is required for the postgres backend") + + with PostgresSaver.from_conn_string(db_config.postgres_url) as saver: + saver.setup() + logger.info("Checkpointer: using PostgresSaver") + yield saver + return + + raise ValueError(f"Unknown database backend: {backend!r}") + + +def _build_checkpointer_from_app_config(app_config: AppConfig) -> tuple[Checkpointer, object | None]: + if app_config.checkpointer is not None: + ctx = _sync_checkpointer_cm(app_config.checkpointer) + return ctx.__enter__(), ctx + + db_config = getattr(app_config, "database", None) + if _persistent_database_backend(db_config) is not None: + ctx = _sync_checkpointer_from_database_cm(db_config) + return ctx.__enter__(), ctx + + return _default_in_memory_checkpointer(), None + + +def get_checkpointer(app_config: AppConfig | None = None) -> Checkpointer: """Return the global sync checkpointer singleton, creating it on first call. Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. @@ -111,6 +180,18 @@ def get_checkpointer() -> Checkpointer: """ global _checkpointer, _checkpointer_ctx + if app_config is not None: + cache_key = id(app_config) + cached = _explicit_checkpointers.get(cache_key) + if cached is not None: + return cached + + explicit_checkpointer, explicit_ctx = _build_checkpointer_from_app_config(app_config) + _explicit_checkpointers[cache_key] = explicit_checkpointer + if explicit_ctx is not None: + _explicit_checkpointer_contexts[cache_key] = explicit_ctx + return explicit_checkpointer + if _checkpointer is not None: return _checkpointer @@ -121,28 +202,30 @@ def get_checkpointer() -> Checkpointer: from deerflow.config.checkpointer_config import get_checkpointer_config config = get_checkpointer_config() + global_app_config = _app_config - if config is None and _app_config is None: + if config is None and global_app_config is None: # Only load app config lazily when neither the app config nor an explicit # checkpointer config has been initialized yet. This keeps tests that # intentionally set the global checkpointer config isolated from any # ambient config.yaml on disk. try: - get_app_config() + global_app_config = get_app_config() except FileNotFoundError: # In test environments without config.yaml, this is expected. pass config = get_checkpointer_config() - if config is None: - from langgraph.checkpoint.memory import InMemorySaver - logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)") - _checkpointer = InMemorySaver() + if config is not None: + _checkpointer_ctx = _sync_checkpointer_cm(config) + _checkpointer = _checkpointer_ctx.__enter__() return _checkpointer - _checkpointer_ctx = _sync_checkpointer_cm(config) - _checkpointer = _checkpointer_ctx.__enter__() + if global_app_config is not None: + _checkpointer, _checkpointer_ctx = _build_checkpointer_from_app_config(global_app_config) + return _checkpointer + _checkpointer = _default_in_memory_checkpointer() return _checkpointer @@ -161,6 +244,18 @@ def reset_checkpointer() -> None: _checkpointer_ctx = None _checkpointer = None + for cache_key, ctx in list(_explicit_checkpointer_contexts.items()): + try: + ctx.__exit__(None, None, None) + except Exception: + logger.warning("Error during explicit checkpointer cleanup", exc_info=True) + finally: + _explicit_checkpointer_contexts.pop(cache_key, None) + _explicit_checkpointers.pop(cache_key, None) + + _explicit_checkpointers.clear() + _explicit_checkpointer_contexts.clear() + # --------------------------------------------------------------------------- # Sync context manager @@ -168,7 +263,7 @@ def reset_checkpointer() -> None: @contextlib.contextmanager -def checkpointer_context() -> Iterator[Checkpointer]: +def checkpointer_context(app_config: AppConfig | None = None) -> Iterator[Checkpointer]: """Sync context manager that yields a checkpointer and cleans up on exit. Unlike :func:`get_checkpointer`, this does **not** cache the instance — @@ -181,12 +276,16 @@ def checkpointer_context() -> Iterator[Checkpointer]: Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. """ - config = get_app_config() - if config.checkpointer is None: - from langgraph.checkpoint.memory import InMemorySaver - - yield InMemorySaver() + resolved_app_config = app_config or get_app_config() + if resolved_app_config.checkpointer is not None: + with _sync_checkpointer_cm(resolved_app_config.checkpointer) as saver: + yield saver return - with _sync_checkpointer_cm(config.checkpointer) as saver: - yield saver + db_config = getattr(resolved_app_config, "database", None) + if _persistent_database_backend(db_config) is not None: + with _sync_checkpointer_from_database_cm(db_config) as saver: + yield saver + return + + yield _default_in_memory_checkpointer() diff --git a/backend/packages/harness/deerflow/runtime/store/provider.py b/backend/packages/harness/deerflow/runtime/store/provider.py index a9394fb9..d2835174 100644 --- a/backend/packages/harness/deerflow/runtime/store/provider.py +++ b/backend/packages/harness/deerflow/runtime/store/provider.py @@ -26,7 +26,7 @@ from collections.abc import Iterator from langgraph.store.base import BaseStore -from deerflow.config.app_config import get_app_config +from deerflow.config.app_config import AppConfig, get_app_config from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str logger = logging.getLogger(__name__) @@ -98,9 +98,26 @@ def _sync_store_cm(config) -> Iterator[BaseStore]: _store: BaseStore | None = None _store_ctx = None # open context manager keeping the connection alive +_explicit_stores: dict[int, BaseStore] = {} +_explicit_store_contexts: dict[int, object] = {} -def get_store() -> BaseStore: +def _default_in_memory_store() -> BaseStore: + from langgraph.store.memory import InMemoryStore + + logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.") + return InMemoryStore() + + +def _build_store_from_app_config(app_config: AppConfig) -> tuple[BaseStore, object | None]: + if app_config.checkpointer is not None: + ctx = _sync_store_cm(app_config.checkpointer) + return ctx.__enter__(), ctx + + return _default_in_memory_store(), None + + +def get_store(app_config: AppConfig | None = None) -> BaseStore: """Return the global sync Store singleton, creating it on first call. Returns an :class:`~langgraph.store.memory.InMemoryStore` when no @@ -112,6 +129,18 @@ def get_store() -> BaseStore: """ global _store, _store_ctx + if app_config is not None: + cache_key = id(app_config) + cached = _explicit_stores.get(cache_key) + if cached is not None: + return cached + + explicit_store, explicit_ctx = _build_store_from_app_config(app_config) + _explicit_stores[cache_key] = explicit_store + if explicit_ctx is not None: + _explicit_store_contexts[cache_key] = explicit_ctx + return explicit_store + if _store is not None: return _store @@ -130,10 +159,7 @@ def get_store() -> BaseStore: config = get_checkpointer_config() if config is None: - from langgraph.store.memory import InMemoryStore - - logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.") - _store = InMemoryStore() + _store = _default_in_memory_store() return _store _store_ctx = _sync_store_cm(config) @@ -156,6 +182,18 @@ def reset_store() -> None: _store_ctx = None _store = None + for cache_key, ctx in list(_explicit_store_contexts.items()): + try: + ctx.__exit__(None, None, None) + except Exception: + logger.warning("Error during explicit store cleanup", exc_info=True) + finally: + _explicit_store_contexts.pop(cache_key, None) + _explicit_stores.pop(cache_key, None) + + _explicit_stores.clear() + _explicit_store_contexts.clear() + # --------------------------------------------------------------------------- # Sync context manager @@ -163,7 +201,7 @@ def reset_store() -> None: @contextlib.contextmanager -def store_context() -> Iterator[BaseStore]: +def store_context(app_config: AppConfig | None = None) -> Iterator[BaseStore]: """Sync context manager that yields a Store and cleans up on exit. Unlike :func:`get_store`, this does **not** cache the instance — each @@ -176,13 +214,10 @@ def store_context() -> Iterator[BaseStore]: Yields an :class:`~langgraph.store.memory.InMemoryStore` when no checkpointer is configured in *config.yaml*. """ - config = get_app_config() - if config.checkpointer is None: - from langgraph.store.memory import InMemoryStore - - logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.") - yield InMemoryStore() + resolved_app_config = app_config or get_app_config() + if resolved_app_config.checkpointer is None: + yield _default_in_memory_store() return - with _sync_store_cm(config.checkpointer) as store: + with _sync_store_cm(resolved_app_config.checkpointer) as store: yield store diff --git a/backend/tests/test_checkpointer.py b/backend/tests/test_checkpointer.py index 5a31cfb7..60406443 100644 --- a/backend/tests/test_checkpointer.py +++ b/backend/tests/test_checkpointer.py @@ -1,6 +1,7 @@ """Unit tests for checkpointer config and singleton factory.""" import sys +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -103,6 +104,53 @@ class TestGetCheckpointer: cp2 = get_checkpointer() assert cp1 is not cp2 + def test_explicit_app_config_bypasses_global_config_lookup(self): + from langgraph.checkpoint.memory import InMemorySaver + + explicit_config = SimpleNamespace( + checkpointer=CheckpointerConfig(type="memory"), + database=SimpleNamespace(backend="memory"), + ) + + with patch( + "deerflow.runtime.checkpointer.provider.get_app_config", + side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"), + ): + cp = get_checkpointer(app_config=explicit_config) + + assert isinstance(cp, InMemorySaver) + + def test_explicit_app_config_uses_unified_database_sqlite_backend(self): + explicit_config = SimpleNamespace( + checkpointer=None, + database=SimpleNamespace(backend="sqlite", checkpointer_sqlite_path="/tmp/explicit/deerflow.db"), + ) + + mock_saver_instance = MagicMock() + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance) + mock_cm.__exit__ = MagicMock(return_value=False) + + mock_saver_cls = MagicMock() + mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm) + + mock_module = MagicMock() + mock_module.SqliteSaver = mock_saver_cls + + with ( + patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}), + patch( + "deerflow.runtime.checkpointer.provider.get_app_config", + side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"), + ), + patch("deerflow.runtime.checkpointer.provider.ensure_sqlite_parent_dir") as mock_ensure, + ): + cp = get_checkpointer(app_config=explicit_config) + + assert cp is mock_saver_instance + mock_ensure.assert_called_once_with("/tmp/explicit/deerflow.db") + mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/explicit/deerflow.db") + def test_sqlite_raises_when_package_missing(self): load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"}) with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}): diff --git a/backend/tests/test_client.py b/backend/tests/test_client.py index 8397af16..67f4562e 100644 --- a/backend/tests/test_client.py +++ b/backend/tests/test_client.py @@ -848,6 +848,28 @@ class TestEnsureAgent: assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent" assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"} + def test_threads_explicit_app_config_to_dependencies(self, client): + """Client-owned AppConfig must flow into model/tool/prompt/checkpointer composition.""" + mock_agent = MagicMock() + mock_checkpointer = MagicMock() + config = client._get_runnable_config("t1") + + with ( + patch("deerflow.client.create_chat_model", return_value=MagicMock()) as mock_create_chat_model, + patch("deerflow.client.create_agent", return_value=mock_agent), + patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares, + patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt, + patch("deerflow.tools.get_available_tools", return_value=[]) as mock_get_available_tools, + patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer) as mock_get_checkpointer, + ): + client._ensure_agent(config) + + assert mock_create_chat_model.call_args.kwargs["app_config"] is client._app_config + assert mock_build_middlewares.call_args.kwargs["app_config"] is client._app_config + assert mock_apply_prompt.call_args.kwargs["app_config"] is client._app_config + assert mock_get_available_tools.call_args.kwargs["app_config"] is client._app_config + assert mock_get_checkpointer.call_args.kwargs["app_config"] is client._app_config + def test_uses_default_checkpointer_when_available(self, client): mock_agent = MagicMock() mock_checkpointer = MagicMock()