Files
deer-flow/backend/packages/harness/deerflow/runtime/runs/worker.py
T
greatmengqi 8ba01dfd83 refactor: thread app_config through lead and subagent task path (#2666)
* refactor: thread app config through lead prompt

* fix: honor explicit app config across runtime paths

* style: format subagent executor tests

* fix: thread resolved app config and guard subagents-only fallback

Address two PR review findings:

1. _create_summarization_middleware passed the original (possibly None)
   app_config into create_chat_model, forcing the model factory back to
   ambient get_app_config() and risking config drift between the
   middleware's resolved view and the model's view. Pass the resolved
   AppConfig instance through end-to-end.

2. get_available_subagent_names accepted Any-typed config and forwarded
   it to is_host_bash_allowed, which reads ``.sandbox``. A
   SubagentsAppConfig (also accepted upstream as a sum-type input) has
   no ``.sandbox`` attribute and would be silently treated as "no
   sandbox configured", incorrectly disabling the bash subagent. Guard
   on hasattr and fall back to ambient lookup otherwise.

Adds regression tests for both paths.

* chore: simplify hasattr guard and tighten regression tests

- Collapse if/else into ternary in get_available_subagent_names; hasattr(None, ...) is False so the explicit None check was redundant.
- Drop comments that narrate the change rather than explain non-obvious WHY (test names already convey intent).
- Replace stringly-typed sentinel "no-arg" in regression test with direct args tuple comparison.

---------

Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
2026-05-02 06:37:49 +08:00

557 lines
22 KiB
Python

"""Background agent execution.
Runs an agent graph inside an ``asyncio.Task``, publishing events to
a :class:`StreamBridge` as they are produced.
Uses ``graph.astream(stream_mode=[...])`` which gives correct full-state
snapshots for ``values`` mode, proper ``{node: writes}`` for ``updates``,
and ``(chunk, metadata)`` tuples for ``messages`` mode.
Note: ``events`` mode is not supported through the gateway — it requires
``graph.astream_events()`` which cannot simultaneously produce ``values``
snapshots. The JS open-source LangGraph API server works around this via
internal checkpoint callbacks that are not exposed in the Python public API.
"""
from __future__ import annotations
import asyncio
import copy
import inspect
import logging
from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
from deerflow.config.app_config import AppConfig
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge
from .manager import RunManager, RunRecord
from .schemas import RunStatus
logger = logging.getLogger(__name__)

# Valid stream_mode values for LangGraph's graph.astream().
# ``run_agent`` forwards a requested mode only if it appears here; it maps
# "messages-tuple" to "messages", skips "events", and drops any other
# unrecognised mode without raising.
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
def _build_runtime_context(
    thread_id: str,
    run_id: str,
    caller_context: Any | None,
    app_config: AppConfig | None = None,
) -> dict[str, Any]:
    """Assemble the mapping that becomes ``ToolRuntime.context`` for the run.

    ``thread_id`` and ``run_id`` are always present and always taken from the
    arguments. When *caller_context* (the caller's ``config['context']``) is a
    dict, its remaining keys — e.g. ``agent_name`` for the bootstrap flow, issue
    #2677 — are carried over, but a caller-supplied ``thread_id``/``run_id`` is
    ignored. A non-``None`` *app_config* is stored under ``"app_config"``
    (replacing any caller value) so tools can consume the resolved ``AppConfig``
    without ambient global lookups.

    langgraph 1.1+ surfaces this as ``runtime.context`` via the parent runtime stored
    under ``config['configurable']['__pregel_runtime']`` — see
    ``langgraph.pregel.main`` where ``parent_runtime.merge(...)`` is invoked.

    Args:
        thread_id: Identifier of the thread the run belongs to.
        run_id: Identifier of the run.
        caller_context: Caller-provided context; anything that is not a dict is
            ignored.
        app_config: Resolved application config, or ``None`` to leave the
            ``"app_config"`` key untouched.

    Returns:
        A new dict; *caller_context* itself is never mutated.
    """
    identity: dict[str, Any] = {"thread_id": thread_id, "run_id": run_id}
    extras = caller_context if isinstance(caller_context, dict) else {}
    # Identity keys come first and win over same-named caller keys.
    merged: dict[str, Any] = {
        **identity,
        **{name: extra for name, extra in extras.items() if name not in identity},
    }
    if app_config is not None:
        merged["app_config"] = app_config
    return merged
@dataclass(frozen=True)
class RunContext:
    """Infrastructure dependencies for a single agent run.

    Groups checkpointer, store, and persistence-related singletons so that
    ``run_agent`` (and any future callers) receive one object instead of a
    growing list of keyword arguments.

    The dataclass is frozen, so its attributes cannot be reassigned after
    construction. ``checkpointer`` is the only field without a default.
    """

    # Attached to the agent and used by ``run_agent`` for the pre-run snapshot
    # and rollback; ``run_agent`` skips those steps when this is ``None``.
    checkpointer: Any
    # Passed to the LangGraph ``Runtime``; also attached to the agent when not ``None``.
    store: Any | None = field(default=None)
    # When set, ``run_agent`` creates a ``RunJournal`` backed by this store.
    event_store: Any | None = field(default=None)
    # ``run_agent`` reads its ``track_token_usage`` attribute (``True`` when absent).
    run_events_config: Any | None = field(default=None)
    # When set, ``run_agent`` syncs the thread title and status to it after the run.
    thread_store: Any | None = field(default=None)
    # Resolved application config; added to the runtime context and handed to
    # agent factories that accept an ``app_config`` parameter.
    app_config: AppConfig | None = field(default=None)
def _install_runtime_context(config: dict, runtime_context: dict[str, Any]) -> None:
    """Make sure ``config["context"]`` carries the run's runtime context.

    If ``config["context"]`` is not already a dict it is replaced with a shallow
    copy of *runtime_context*. Otherwise the existing dict is updated in place:
    ``thread_id`` and ``run_id`` are filled in only when missing, while
    ``app_config`` (when present in *runtime_context*) always overwrites.

    Args:
        config: Run configuration; mutated in place.
        runtime_context: Mapping that must contain ``thread_id`` and ``run_id``
            whenever ``config["context"]`` is already a dict.
    """
    current = config.get("context")
    if not isinstance(current, dict):
        config["context"] = dict(runtime_context)
        return

    for identity_key in ("thread_id", "run_id"):
        # Values the caller already put in the context are left untouched.
        current.setdefault(identity_key, runtime_context[identity_key])
    if "app_config" in runtime_context:
        current["app_config"] = runtime_context["app_config"]
def _compute_agent_factory_supports_app_config(agent_factory: Any) -> bool:
    """Report whether *agent_factory* declares a parameter named ``app_config``.

    Only explicitly named parameters count; a ``**kwargs`` catch-all does not.
    Objects whose signature cannot be introspected yield ``False``.
    """
    try:
        declared = inspect.signature(agent_factory).parameters
    except (TypeError, ValueError):
        return False
    return "app_config" in declared
@lru_cache(maxsize=128)
def _cached_agent_factory_supports_app_config(agent_factory: Any) -> bool:
    """Memoised ``app_config`` support check, keyed on the factory object.

    The cache holds at most 128 entries. Because it is keyed on *agent_factory*,
    calling this with an unhashable factory raises ``TypeError``.
    """
    supported = _compute_agent_factory_supports_app_config(agent_factory)
    return supported
def _agent_factory_supports_app_config(agent_factory: Any) -> bool:
    """Return ``True`` when *agent_factory* accepts an ``app_config`` parameter.

    The memoised check is tried first; if it raises ``TypeError`` the signature
    is inspected directly without caching.
    """
    try:
        supported = _cached_agent_factory_supports_app_config(agent_factory)
    except TypeError:
        # Some callable instances are unhashable and cannot be used as a cache key.
        supported = _compute_agent_factory_supports_app_config(agent_factory)
    return supported
# Strong references to fire-and-forget bridge cleanup tasks. The event loop only
# keeps weak references to tasks, so an otherwise unreferenced task may be
# garbage-collected before it finishes. Tasks remove themselves when done.
_background_tasks: set[asyncio.Task[Any]] = set()


async def run_agent(
    bridge: StreamBridge,
    run_manager: RunManager,
    record: RunRecord,
    *,
    ctx: RunContext,
    agent_factory: Any,
    graph_input: dict,
    config: dict,
    stream_modes: list[str] | None = None,
    stream_subgraphs: bool = False,
    interrupt_before: list[str] | Literal["*"] | None = None,
    interrupt_after: list[str] | Literal["*"] | None = None,
) -> None:
    """Execute an agent in the background, publishing events to *bridge*.

    The run proceeds in numbered steps (see inline comments): mark the run as
    running, snapshot the latest checkpoint, publish a ``metadata`` event, build
    the agent, stream its output to the bridge, and record the final status.
    Whatever happens, the ``finally`` block flushes the journal, syncs thread
    metadata, publishes the end-of-stream marker and schedules bridge cleanup.

    Args:
        bridge: Stream bridge that receives ``metadata``, per-mode and ``error``
            events, followed by the end marker.
        run_manager: Used to update the run status and persist completion data.
        record: The run being executed; supplies ``run_id``, ``thread_id``,
            ``abort_event``, ``abort_action`` and ``status``.
        ctx: Infrastructure dependencies (checkpointer, stores, app config).
        agent_factory: Callable invoked with ``config=`` (and ``app_config=``
            when it declares that parameter and ``ctx.app_config`` is set) that
            returns the agent graph.
        graph_input: Input passed to ``agent.astream``.
        config: Run configuration. Mutated in place: ``"context"``,
            ``"configurable"["__pregel_runtime"]`` and (when a journal exists)
            ``"callbacks"`` are written.
        stream_modes: Requested stream modes; defaults to ``["values"]``.
            ``"messages-tuple"`` is mapped to ``"messages"``, ``"events"`` is
            skipped, and unknown modes are dropped.
        stream_subgraphs: Forwarded to ``astream`` as ``subgraphs``.
        interrupt_before: Assigned to ``agent.interrupt_before_nodes`` when truthy.
        interrupt_after: Assigned to ``agent.interrupt_after_nodes`` when truthy.

    Exceptions raised while running the agent are logged, recorded on the run and
    published as an ``error`` event instead of propagating.
    ``asyncio.CancelledError`` is handled (status update and optional rollback)
    and is not re-raised.
    """
    # Unpack infrastructure dependencies from RunContext.
    checkpointer = ctx.checkpointer
    store = ctx.store
    event_store = ctx.event_store
    run_events_config = ctx.run_events_config
    thread_store = ctx.thread_store

    run_id = record.run_id
    thread_id = record.thread_id
    requested_modes: set[str] = set(stream_modes or ["values"])
    pre_run_checkpoint_id: str | None = None
    pre_run_snapshot: dict[str, Any] | None = None
    snapshot_capture_failed = False

    journal = None

    # "events" is only logged here; it is filtered out when the LangGraph mode
    # list is built in step 6.
    if "events" in requested_modes:
        logger.info(
            "Run %s: 'events' stream_mode not supported in gateway (requires astream_events + checkpoint callbacks). Skipping.",
            run_id,
        )

    try:
        # Create the RunJournal inside the try block so that any exception it
        # raises flows through the except/finally path that publishes an "end"
        # event to the SSE bridge — otherwise a failure here would leave the
        # stream hanging with no terminator.
        if event_store is not None:
            from deerflow.runtime.journal import RunJournal

            journal = RunJournal(
                run_id=run_id,
                thread_id=thread_id,
                event_store=event_store,
                track_token_usage=getattr(run_events_config, "track_token_usage", True),
            )

        # 1. Mark running
        await run_manager.set_status(run_id, RunStatus.running)

        # Snapshot the latest pre-run checkpoint so rollback can restore it.
        if checkpointer is not None:
            try:
                config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
                ckpt_tuple = await checkpointer.aget_tuple(config_for_check)
                if ckpt_tuple is not None:
                    ckpt_config = getattr(ckpt_tuple, "config", {}).get("configurable", {})
                    pre_run_checkpoint_id = ckpt_config.get("checkpoint_id")
                    # Deep copies so later mutation by the run cannot alter the snapshot.
                    pre_run_snapshot = {
                        "checkpoint_ns": ckpt_config.get("checkpoint_ns", ""),
                        "checkpoint": copy.deepcopy(getattr(ckpt_tuple, "checkpoint", {})),
                        "metadata": copy.deepcopy(getattr(ckpt_tuple, "metadata", {})),
                        "pending_writes": copy.deepcopy(getattr(ckpt_tuple, "pending_writes", []) or []),
                    }
            except Exception:
                # A failed snapshot must not be mistaken for "no prior checkpoint":
                # the flag makes a later rollback skip instead of wiping the thread.
                snapshot_capture_failed = True
                logger.warning("Could not capture pre-run checkpoint snapshot for run %s", run_id, exc_info=True)

        # 2. Publish metadata — useStream needs both run_id AND thread_id
        await bridge.publish(
            run_id,
            "metadata",
            {
                "run_id": run_id,
                "thread_id": thread_id,
            },
        )

        # 3. Build the agent
        from langchain_core.runnables import RunnableConfig
        from langgraph.runtime import Runtime

        # Inject runtime context so middlewares and tools (via ToolRuntime.context) can
        # access thread-level data. langgraph-cli does this automatically; we must do it
        # manually here because we drive the graph through ``agent.astream(config=...)``
        # without passing the official ``context=`` parameter.
        runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config)
        _install_runtime_context(config, runtime_ctx)
        runtime = Runtime(context=cast(Any, runtime_ctx), store=store)
        config.setdefault("configurable", {})["__pregel_runtime"] = runtime

        # Inject RunJournal as a LangChain callback handler.
        # on_llm_end captures token usage; on_chain_start/end captures lifecycle.
        if journal is not None:
            config.setdefault("callbacks", []).append(journal)

        runnable_config = RunnableConfig(**config)
        if ctx.app_config is not None and _agent_factory_supports_app_config(agent_factory):
            agent = agent_factory(config=runnable_config, app_config=ctx.app_config)
        else:
            agent = agent_factory(config=runnable_config)

        # 4. Attach checkpointer and store
        if checkpointer is not None:
            agent.checkpointer = checkpointer
        if store is not None:
            agent.store = store

        # 5. Set interrupt nodes
        if interrupt_before:
            agent.interrupt_before_nodes = interrupt_before
        if interrupt_after:
            agent.interrupt_after_nodes = interrupt_after

        # 6. Build LangGraph stream_mode list
        #    "events" is NOT a valid astream mode — skip it
        #    "messages-tuple" maps to LangGraph's "messages" mode
        lg_modes: list[str] = []
        for m in requested_modes:
            if m == "messages-tuple":
                lg_modes.append("messages")
            elif m == "events":
                # Skipped — see log above
                continue
            elif m in _VALID_LG_MODES:
                lg_modes.append(m)
        if not lg_modes:
            lg_modes = ["values"]

        # Deduplicate while preserving order
        seen: set[str] = set()
        deduped: list[str] = []
        for m in lg_modes:
            if m not in seen:
                seen.add(m)
                deduped.append(m)
        lg_modes = deduped

        logger.info("Run %s: streaming with modes %s (requested: %s)", run_id, lg_modes, requested_modes)

        # 7. Stream using graph.astream
        if len(lg_modes) == 1 and not stream_subgraphs:
            # Single mode, no subgraphs: astream yields raw chunks
            single_mode = lg_modes[0]
            async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode):
                if record.abort_event.is_set():
                    logger.info("Run %s abort requested — stopping", run_id)
                    break
                sse_event = _lg_mode_to_sse_event(single_mode)
                await bridge.publish(run_id, sse_event, serialize(chunk, mode=single_mode))
        else:
            # Multiple modes or subgraphs: astream yields tuples
            async for item in agent.astream(
                graph_input,
                config=runnable_config,
                stream_mode=lg_modes,
                subgraphs=stream_subgraphs,
            ):
                if record.abort_event.is_set():
                    logger.info("Run %s abort requested — stopping", run_id)
                    break

                mode, chunk = _unpack_stream_item(item, lg_modes, stream_subgraphs)
                if mode is None:
                    continue

                sse_event = _lg_mode_to_sse_event(mode)
                await bridge.publish(run_id, sse_event, serialize(chunk, mode=mode))

        # 8. Final status
        if record.abort_event.is_set():
            action = record.abort_action
            if action == "rollback":
                await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
                try:
                    await _rollback_to_pre_run_checkpoint(
                        checkpointer=checkpointer,
                        thread_id=thread_id,
                        run_id=run_id,
                        pre_run_checkpoint_id=pre_run_checkpoint_id,
                        pre_run_snapshot=pre_run_snapshot,
                        snapshot_capture_failed=snapshot_capture_failed,
                    )
                    logger.info("Run %s rolled back to pre-run checkpoint %s", run_id, pre_run_checkpoint_id)
                except Exception:
                    logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True)
            else:
                await run_manager.set_status(run_id, RunStatus.interrupted)
        else:
            await run_manager.set_status(run_id, RunStatus.success)

    except asyncio.CancelledError:
        # Task cancellation is treated like an abort: same status/rollback handling
        # as step 8. The cancellation is handled here and not re-raised.
        action = record.abort_action
        if action == "rollback":
            await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
            try:
                await _rollback_to_pre_run_checkpoint(
                    checkpointer=checkpointer,
                    thread_id=thread_id,
                    run_id=run_id,
                    pre_run_checkpoint_id=pre_run_checkpoint_id,
                    pre_run_snapshot=pre_run_snapshot,
                    snapshot_capture_failed=snapshot_capture_failed,
                )
                logger.info("Run %s was cancelled and rolled back", run_id)
            except Exception:
                logger.warning("Run %s cancellation rollback failed", run_id, exc_info=True)
        else:
            await run_manager.set_status(run_id, RunStatus.interrupted)
            logger.info("Run %s was cancelled", run_id)

    except Exception as exc:
        error_msg = f"{exc}"
        logger.exception("Run %s failed: %s", run_id, error_msg)
        await run_manager.set_status(run_id, RunStatus.error, error=error_msg)
        await bridge.publish(
            run_id,
            "error",
            {
                "message": error_msg,
                "name": type(exc).__name__,
            },
        )

    finally:
        # Flush any buffered journal events and persist completion data
        if journal is not None:
            try:
                await journal.flush()
            except Exception:
                logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)

            try:
                # Persist token usage + convenience fields to RunStore
                completion = journal.get_completion_data()
                await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
            except Exception:
                logger.warning("Failed to persist run completion for %s (non-fatal)", run_id, exc_info=True)

        # Sync title from checkpoint to threads_meta.display_name
        if checkpointer is not None and thread_store is not None:
            try:
                ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
                ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
                if ckpt_tuple is not None:
                    ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
                    title = ckpt.get("channel_values", {}).get("title")
                    if title:
                        await thread_store.update_display_name(thread_id, title)
            except Exception:
                logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)

        # Update threads_meta status based on run outcome
        if thread_store is not None:
            try:
                final_status = "idle" if record.status == RunStatus.success else record.status.value
                await thread_store.update_status(thread_id, final_status)
            except Exception:
                logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)

        await bridge.publish_end(run_id)
        # Schedule delayed bridge cleanup without awaiting it, keeping a strong
        # reference until the task completes.
        cleanup_task = asyncio.create_task(bridge.cleanup(run_id, delay=60))
        _background_tasks.add(cleanup_task)
        cleanup_task.add_done_callback(_background_tasks.discard)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _call_checkpointer_method(checkpointer: Any, async_name: str, sync_name: str, *args: Any, **kwargs: Any) -> Any:
    """Invoke a checkpointer operation through its async or sync variant.

    The attribute named *async_name* is preferred; *sync_name* is used when the
    async attribute is missing or falsy. The chosen method is called with
    ``*args``/``**kwargs`` and its result is awaited if it is awaitable.

    Raises:
        AttributeError: If neither variant is available on *checkpointer*.
    """
    target = getattr(checkpointer, async_name, None)
    if not target:
        target = getattr(checkpointer, sync_name, None)
    if target is None:
        raise AttributeError(f"Missing checkpointer method: {async_name}/{sync_name}")

    outcome = target(*args, **kwargs)
    if not inspect.isawaitable(outcome):
        return outcome
    return await outcome
async def _rollback_to_pre_run_checkpoint(
    *,
    checkpointer: Any,
    thread_id: str,
    run_id: str,
    pre_run_checkpoint_id: str | None,
    pre_run_snapshot: dict[str, Any] | None,
    snapshot_capture_failed: bool,
) -> None:
    """Restore thread state to the checkpoint snapshot captured before run start.

    Behaviour by case:

    * no checkpointer, or the snapshot capture failed: log and do nothing;
    * no snapshot (``pre_run_snapshot is None``): delete the thread via
      ``adelete_thread``/``delete_thread``;
    * snapshot whose checkpoint is not a dict or has no id (even after falling
      back to *pre_run_checkpoint_id*): log a warning and do nothing;
    * otherwise: write the checkpoint back with ``aput``/``put`` and then replay
      the snapshot's pending writes, grouped per task, with
      ``aput_writes``/``put_writes``.

    Args:
        checkpointer: Checkpoint saver exposing async and/or sync methods.
        thread_id: Thread whose state is restored.
        run_id: Run identifier, used in log and error messages.
        pre_run_checkpoint_id: Fallback id when the snapshot's checkpoint has none.
        pre_run_snapshot: Mapping with ``checkpoint``, ``metadata``,
            ``checkpoint_ns`` and ``pending_writes`` keys, or ``None``.
        snapshot_capture_failed: ``True`` when the snapshot could not be taken.

    Raises:
        RuntimeError: If the checkpointer's put result is not a dict carrying a
            ``configurable.checkpoint_id``, or a pending write is malformed.
    """
    if checkpointer is None:
        logger.info("Run %s rollback requested but no checkpointer is configured", run_id)
        return
    if snapshot_capture_failed:
        logger.warning("Run %s rollback skipped: pre-run checkpoint snapshot capture failed", run_id)
        return
    if pre_run_snapshot is None:
        await _call_checkpointer_method(checkpointer, "adelete_thread", "delete_thread", thread_id)
        logger.info("Run %s rollback reset thread %s to empty state", run_id, thread_id)
        return

    # --- Validate the snapshot and work out what to write back -------------
    snapshot_checkpoint = pre_run_snapshot.get("checkpoint")
    if not isinstance(snapshot_checkpoint, dict):
        logger.warning("Run %s rollback skipped: invalid pre-run checkpoint snapshot", run_id)
        return
    if snapshot_checkpoint.get("id") is None and pre_run_checkpoint_id is not None:
        # Work on a copy so the caller's snapshot is not modified.
        snapshot_checkpoint = {**snapshot_checkpoint, "id": pre_run_checkpoint_id}
    if snapshot_checkpoint.get("id") is None:
        logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
        return

    snapshot_metadata = pre_run_snapshot.get("metadata", {})
    if not isinstance(snapshot_metadata, dict):
        snapshot_metadata = {}

    namespace = pre_run_snapshot.get("checkpoint_ns")
    if not isinstance(namespace, str):
        namespace = ""

    versions = snapshot_checkpoint.get("channel_versions")
    versions = dict(versions) if isinstance(versions, dict) else {}

    # --- Write the checkpoint back -----------------------------------------
    target_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": namespace}}
    restored_config = await _call_checkpointer_method(
        checkpointer,
        "aput",
        "put",
        target_config,
        snapshot_checkpoint,
        snapshot_metadata,
        versions,
    )
    if not isinstance(restored_config, dict):
        raise RuntimeError(f"Run {run_id} rollback restore returned invalid config: expected dict")
    configurable = restored_config.get("configurable", {})
    if not isinstance(configurable, dict):
        raise RuntimeError(f"Run {run_id} rollback restore returned invalid config payload")
    if not configurable.get("checkpoint_id"):
        raise RuntimeError(f"Run {run_id} rollback restore did not return checkpoint_id")

    # --- Replay pending writes, one put_writes call per task ----------------
    pending = pre_run_snapshot.get("pending_writes", [])
    if not pending:
        return

    grouped: dict[str, list[tuple[str, Any]]] = {}
    for entry in pending:
        if not (isinstance(entry, (tuple, list)) and len(entry) == 3):
            raise RuntimeError(f"Run {run_id} rollback failed: pending_write is not a 3-tuple: {entry!r}")
        task_id, channel, value = entry
        if not isinstance(channel, str):
            raise RuntimeError(f"Run {run_id} rollback failed: pending_write has non-string channel: task_id={task_id!r}, channel={channel!r}")
        grouped.setdefault(str(task_id), []).append((channel, value))

    for task_key, task_writes in grouped.items():
        await _call_checkpointer_method(
            checkpointer,
            "aput_writes",
            "put_writes",
            restored_config,
            task_writes,
            task_id=task_key,
        )
def _lg_mode_to_sse_event(mode: str) -> str:
    """Translate a LangGraph stream mode into the SSE event name to publish.

    Every mode currently maps to itself. In particular LangGraph's
    ``astream(stream_mode="messages")`` yields message tuples, which clients
    request as ``messages-tuple``, yet the event name LangGraph Platform uses
    by default is plain ``"messages"`` — so ``"messages"`` is returned unchanged.
    """
    sse_event_name = mode  # 1:1 mapping for all modes
    return sse_event_name
def _extract_human_message(graph_input: dict) -> HumanMessage | None:
    """Derive a ``HumanMessage`` from ``graph_input["messages"]`` for event recording.

    When ``messages`` is a list its last element is examined; any other truthy
    value is examined as-is. The candidate is handled by type:

    * ``HumanMessage``: returned unchanged;
    * ``str``: wrapped, or ``None`` when empty;
    * object with a ``content`` attribute: wrapped, whatever the content is;
    * ``dict``: its ``"content"`` value is wrapped, or ``None`` when falsy.

    Anything else, or a missing/empty ``messages`` entry, yields ``None``.
    Returning a LangChain ``HumanMessage`` lets callers use ``.model_dump()`` to
    get the checkpoint-aligned serialization format.
    """
    from langchain_core.messages import HumanMessage

    messages = graph_input.get("messages")
    if not messages:
        return None
    candidate = messages[-1] if isinstance(messages, list) else messages

    if isinstance(candidate, HumanMessage):
        return candidate

    if isinstance(candidate, str):
        if not candidate:
            return None
        return HumanMessage(content=candidate)

    if hasattr(candidate, "content"):
        return HumanMessage(content=candidate.content)

    if not isinstance(candidate, dict):
        return None
    text = candidate.get("content", "")
    if not text:
        return None
    return HumanMessage(content=text)
def _unpack_stream_item(
    item: Any,
    lg_modes: list[str],
    stream_subgraphs: bool,
) -> tuple[str | None, Any]:
    """Split a multi-mode or subgraph stream item into ``(mode, chunk)``.

    Accepted shapes:

    * ``(mode, chunk)``: accepted with or without subgraphs;
    * ``(namespace, mode, chunk)``: accepted only when *stream_subgraphs* is set
      (the namespace is discarded).

    With subgraphs enabled any other shape gives ``(None, None)``. Without
    subgraphs the item is treated as a bare chunk of the first entry in
    *lg_modes* (``None`` when that list is empty). The mode is always converted
    with ``str()``.
    """
    arity = len(item) if isinstance(item, tuple) else 0

    if arity == 2:
        mode, chunk = item
        return str(mode), chunk

    if stream_subgraphs:
        if arity == 3:
            _namespace, mode, chunk = item
            return str(mode), chunk
        return None, None

    # Fallback: single-element output from first mode
    default_mode = lg_modes[0] if lg_modes else None
    return default_mode, item