Compare commits

...

3 Commits

Author SHA1 Message Date
Willem Jiang 7052978a43 fix the lint errors 2026-04-26 11:16:22 +08:00
Willem Jiang d9f7f658be Apply suggestions from code review
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-26 11:12:42 +08:00
Willem Jiang a55de566b9 refactor(backend): consolidate thread_id resolution into shared get_thread_id() utility (#2522)
Extract duplicated thread_id fallback logic from 11 files into a single
  deerflow.utils.runtime.get_thread_id() function with a documented 3-level
  cascade (runtime.context → runtime.config → get_config()).

  The module docstring also clarifies the __pregel_runtime injection pattern used in
  gateway mode.
2026-04-26 10:52:37 +08:00
14 changed files with 191 additions and 87 deletions
@@ -25,6 +25,8 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
@@ -183,10 +185,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
return thread_id
return "default"
return get_thread_id(runtime) or "default"
def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit.
@@ -5,12 +5,12 @@ from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -57,13 +57,10 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
if not config.enabled:
return None
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
# Resolve thread ID from the runtime or configured fallback sources
thread_id = get_thread_id(runtime)
if not thread_id:
logger.debug("No thread_id in context, skipping memory update")
logger.debug("No thread_id could be resolved from runtime/config, skipping memory update")
return None
# Get messages from state
@@ -14,6 +14,7 @@ from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from deerflow.agents.thread_state import ThreadState
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -218,15 +219,7 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
# ------------------------------------------------------------------
def _get_thread_id(self, request: ToolCallRequest) -> str | None:
runtime = request.runtime # ToolRuntime; may be None-like in tests
if runtime is None:
return None
ctx = getattr(runtime, "context", None) or {}
thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
if thread_id is None:
cfg = getattr(runtime, "config", None) or {}
thread_id = cfg.get("configurable", {}).get("thread_id")
return thread_id
return get_thread_id(request.runtime)
_AUDIT_COMMAND_LIMIT = 200
@@ -14,6 +14,8 @@ from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -35,18 +37,6 @@ class BeforeSummarizationHook(Protocol):
def __call__(self, event: SummarizationEvent) -> None: ...
def _resolve_thread_id(runtime: Runtime) -> str | None:
"""Resolve the current thread ID from runtime context or LangGraph config."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
try:
config_data = get_config()
except RuntimeError:
return None
thread_id = config_data.get("configurable", {}).get("thread_id")
return thread_id
def _resolve_agent_name(runtime: Runtime) -> str | None:
"""Resolve the current agent name from runtime context or LangGraph config."""
agent_name = runtime.context.get("agent_name") if runtime.context else None
@@ -334,7 +324,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
event = SummarizationEvent(
messages_to_summarize=tuple(messages_to_summarize),
preserved_messages=tuple(preserved_messages),
thread_id=_resolve_thread_id(runtime),
thread_id=get_thread_id(runtime),
agent_name=_resolve_agent_name(runtime),
runtime=runtime,
)
@@ -3,11 +3,11 @@ from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.paths import Paths, get_paths
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -75,11 +75,7 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
@override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
context = runtime.context or {}
thread_id = context.get("thread_id")
if thread_id is None:
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
thread_id = get_thread_id(runtime)
if thread_id is None:
raise ValueError("Thread ID is required in runtime context or config.configurable")
@@ -11,6 +11,7 @@ from langgraph.runtime import Runtime
from deerflow.config.paths import Paths, get_paths
from deerflow.utils.file_conversion import extract_outline
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -213,14 +214,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return None
# Resolve uploads directory for existence checks
thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
thread_id = get_thread_id(runtime)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files
@@ -7,6 +7,7 @@ from langgraph.runtime import Runtime
from deerflow.agents.thread_state import SandboxState, ThreadDataState
from deerflow.sandbox import get_sandbox_provider
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -56,7 +57,7 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
# Eager initialization (original behavior)
if "sandbox" not in state or state["sandbox"] is None:
thread_id = (runtime.context or {}).get("thread_id")
thread_id = get_thread_id(runtime)
if thread_id is None:
return super().before_agent(state, runtime)
sandbox_id = self._acquire_sandbox(thread_id)
@@ -19,6 +19,7 @@ from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.sandbox.search import GrepMatch
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.utils.runtime import get_thread_id
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
@@ -851,11 +852,9 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
# Sandbox was released, fall through to acquire new one
# Lazy acquisition: get thread_id and acquire sandbox
thread_id = runtime.context.get("thread_id") if runtime.context else None
thread_id = get_thread_id(runtime)
if thread_id is None:
thread_id = runtime.config.get("configurable", {}).get("thread_id") if runtime.config else None
if thread_id is None:
raise SandboxRuntimeError("Thread ID not available in runtime context")
raise SandboxRuntimeError("Thread ID not available in runtime context, runtime config, or LangGraph config")
provider = get_sandbox_provider()
sandbox_id = provider.acquire(thread_id)
@@ -3,33 +3,16 @@ from typing import Annotated
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain_core.messages import ToolMessage
from langgraph.config import get_config
from langgraph.types import Command
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.utils.runtime import get_thread_id
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
"""Resolve the current thread id from runtime context or RunnableConfig."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
return thread_id
runtime_config = getattr(runtime, "config", None) or {}
thread_id = runtime_config.get("configurable", {}).get("thread_id")
if thread_id:
return thread_id
try:
return get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
return None
def _normalize_presented_filepath(
runtime: ToolRuntime[ContextT, ThreadState],
filepath: str,
@@ -51,9 +34,9 @@ def _normalize_presented_filepath(
if runtime.state is None:
raise ValueError("Thread runtime state is not available")
thread_id = _get_thread_id(runtime)
thread_id = get_thread_id(runtime)
if not thread_id:
raise ValueError("Thread ID is not available in runtime context or runtime config")
raise ValueError("Thread ID is not available in runtime context, runtime config, or LangGraph thread-local config")
thread_data = runtime.state.get("thread_data") or {}
outputs_path = thread_data.get("outputs_path")
@@ -14,6 +14,7 @@ from deerflow.agents.thread_state import ThreadState
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -105,9 +106,7 @@ async def task_tool(
if runtime is not None:
sandbox_state = runtime.state.get("sandbox")
thread_data = runtime.state.get("thread_data")
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
thread_id = runtime.config.get("configurable", {}).get("thread_id")
thread_id = get_thread_id(runtime)
# Try to get parent model from configurable
metadata = runtime.config.get("metadata", {})
@@ -28,6 +28,7 @@ from deerflow.skills.manager import (
validate_skill_name,
)
from deerflow.skills.security_scanner import scan_skill_content
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -42,14 +43,6 @@ def _get_lock(name: str) -> asyncio.Lock:
return lock
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
if runtime is None:
return None
if runtime.context and runtime.context.get("thread_id"):
return runtime.context.get("thread_id")
return runtime.config.get("configurable", {}).get("thread_id")
def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]:
return {
"action": action,
@@ -98,7 +91,7 @@ async def _skill_manage_impl(
"""
name = validate_skill_name(name)
lock = _get_lock(name)
thread_id = _get_thread_id(runtime)
thread_id = get_thread_id(runtime)
async with lock:
if action == "create":
@@ -0,0 +1,90 @@
"""Runtime utilities for thread_id resolution and context access.
Thread ID Resolution Strategy
=============================
DeerFlow resolves the current ``thread_id`` from a three-level cascade:
1. **runtime.context["thread_id"]** -- Set by ``worker.py`` (gateway mode)
or by LangGraph Server (standard mode) when constructing the Runtime.
2. **runtime.config["configurable"]["thread_id"]** -- Available on
``ToolRuntime`` instances passed to tools via the ``@tool`` decorator.
Not available on ``Runtime`` instances received by middlewares.
3. **get_config()["configurable"]["thread_id"]** -- LangGraph's thread-local
config, available when executing inside a graph's runnable context.
About ``__pregel_runtime``
===========================
In gateway mode (``run_agent()`` in ``worker.py``), the agent graph does not
run inside the LangGraph Server. The server normally injects a ``Runtime``
object automatically. Since we run the graph ourselves, we must inject the
Runtime manually via ``config["configurable"]["__pregel_runtime"]``. This is
the standard mechanism provided by LangGraph's Pregel engine for injecting
runtime context into graph nodes. It is not a private/internal hack -- it is
the documented way to pass Runtime when running a graph outside the server.
Duck Typing
===========
Both ``langgraph.runtime.Runtime`` (middlewares) and
``langchain.tools.ToolRuntime`` (tools) expose a ``.context`` attribute (a
dict or None). ``ToolRuntime`` additionally exposes ``.config``. The
function below uses ``getattr`` with safe defaults so it works with either
type, with ``SimpleNamespace`` in tests, or with ``None``.
"""
from __future__ import annotations
from typing import Any
def get_thread_id(runtime: Any | None) -> str | None:
"""Resolve the current thread_id from a runtime object.
Follows a three-level fallback chain:
1. ``runtime.context.get("thread_id")`` -- if context is a non-empty dict.
2. ``runtime.config.get("configurable", {}).get("thread_id")`` -- if
the runtime has a config dict (ToolRuntime).
3. ``get_config().get("configurable", {}).get("thread_id")`` -- LangGraph's
thread-local config. Wrapped in ``try/except RuntimeError`` because it
raises outside a runnable context (e.g., unit tests).
Args:
runtime: A Runtime, ToolRuntime, SimpleNamespace, or None.
Returns:
The thread_id string, or None if it cannot be resolved.
"""
if runtime is None:
return None
# Level 1: runtime.context["thread_id"]
context = getattr(runtime, "context", None)
if context and isinstance(context, dict):
thread_id = context.get("thread_id")
if thread_id:
return thread_id
# Level 2: runtime.config["configurable"]["thread_id"]
config = getattr(runtime, "config", None)
if config and isinstance(config, dict):
thread_id = config.get("configurable", {}).get("thread_id")
if thread_id:
return thread_id
# Level 3: langgraph.config.get_config() -- only works inside runnable context
try:
from langgraph.config import get_config
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
if thread_id:
return thread_id
except RuntimeError:
# Expected when not running inside a LangGraph runnable context (e.g., unit tests).
# In that case, thread_id cannot be resolved from thread-local config, so fall through.
pass
return None
+70
View File
@@ -0,0 +1,70 @@
"""Tests for deerflow.utils.runtime.get_thread_id."""
from types import SimpleNamespace
from unittest.mock import patch
from deerflow.utils.runtime import get_thread_id
class TestGetThreadId:
"""Tests for get_thread_id() with various runtime shapes."""
def test_returns_none_when_runtime_is_none(self):
assert get_thread_id(None) is None
def test_returns_thread_id_from_context(self):
runtime = SimpleNamespace(context={"thread_id": "t-1"}, config={})
assert get_thread_id(runtime) == "t-1"
def test_returns_none_from_empty_context(self):
runtime = SimpleNamespace(context={}, config={})
assert get_thread_id(runtime) is None
def test_returns_none_from_none_context(self):
runtime = SimpleNamespace(context=None, config={})
assert get_thread_id(runtime) is None
def test_falls_back_to_runtime_config(self):
runtime = SimpleNamespace(
context=None,
config={"configurable": {"thread_id": "t-from-config"}},
)
assert get_thread_id(runtime) == "t-from-config"
def test_context_takes_precedence_over_config(self):
runtime = SimpleNamespace(
context={"thread_id": "t-from-context"},
config={"configurable": {"thread_id": "t-from-config"}},
)
assert get_thread_id(runtime) == "t-from-context"
def test_falls_back_to_get_config(self):
runtime = SimpleNamespace(context=None, config={})
with patch("langgraph.config.get_config", return_value={"configurable": {"thread_id": "t-from-lg"}}):
assert get_thread_id(runtime) == "t-from-lg"
def test_returns_none_when_get_config_raises_runtime_error(self):
runtime = SimpleNamespace(context=None, config={})
with patch("langgraph.config.get_config", side_effect=RuntimeError):
assert get_thread_id(runtime) is None
def test_handles_object_without_context_or_config(self):
runtime = SimpleNamespace()
assert get_thread_id(runtime) is None
def test_handles_context_not_dict(self):
runtime = SimpleNamespace(context="not-a-dict", config={})
assert get_thread_id(runtime) is None
def test_config_without_configurable(self):
runtime = SimpleNamespace(context=None, config={"other_key": "value"})
assert get_thread_id(runtime) is None
def test_empty_string_thread_id_treated_as_missing(self):
runtime = SimpleNamespace(context={"thread_id": ""}, config={})
assert get_thread_id(runtime) is None
def test_full_cascade_with_all_levels_failing(self):
runtime = SimpleNamespace(context=None, config={})
with patch("langgraph.config.get_config", return_value={"configurable": {}}):
assert get_thread_id(runtime) is None
+3 -3
View File
@@ -23,7 +23,7 @@ class TestThreadDataMiddleware:
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
runtime = Runtime(context=None)
monkeypatch.setattr(
"deerflow.agents.middlewares.thread_data_middleware.get_config",
"langgraph.config.get_config",
lambda: {"configurable": {"thread_id": "thread-from-config"}},
)
@@ -37,7 +37,7 @@ class TestThreadDataMiddleware:
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
runtime = Runtime(context={})
monkeypatch.setattr(
"deerflow.agents.middlewares.thread_data_middleware.get_config",
"langgraph.config.get_config",
lambda: {"configurable": {"thread_id": "thread-from-config"}},
)
@@ -50,7 +50,7 @@ class TestThreadDataMiddleware:
def test_before_agent_raises_clear_error_when_thread_id_missing_everywhere(self, tmp_path, monkeypatch):
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
monkeypatch.setattr(
"deerflow.agents.middlewares.thread_data_middleware.get_config",
"langgraph.config.get_config",
lambda: {"configurable": {}},
)