refactor: thread release config through lead path (#2612)

Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
This commit is contained in:
greatmengqi
2026-04-28 14:53:18 +08:00
committed by GitHub
parent 69649d8aae
commit e82940c03d
20 changed files with 325 additions and 179 deletions
+1
View File
@@ -148,6 +148,7 @@ def get_run_context(request: Request) -> RunContext:
event_store=get_run_event_store(request),
run_events_config=getattr(config, "run_events", None),
thread_store=get_thread_store(request),
app_config=config,
)
+5 -6
View File
@@ -1,7 +1,8 @@
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from deerflow.config import get_app_config
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
router = APIRouter(prefix="/api", tags=["models"])
@@ -36,7 +37,7 @@ class ModelsListResponse(BaseModel):
summary="List All Models",
description="Retrieve a list of all available AI models configured in the system.",
)
async def list_models() -> ModelsListResponse:
async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResponse:
"""List all available models from configuration.
Returns model information suitable for frontend display,
@@ -72,7 +73,6 @@ async def list_models() -> ModelsListResponse:
}
```
"""
config = get_app_config()
models = [
ModelResponse(
name=model.name,
@@ -96,7 +96,7 @@ async def list_models() -> ModelsListResponse:
summary="Get Model Details",
description="Retrieve detailed information about a specific AI model by its name.",
)
async def get_model(model_name: str) -> ModelResponse:
async def get_model(model_name: str, config: AppConfig = Depends(get_config)) -> ModelResponse:
"""Get a specific model by name.
Args:
@@ -118,7 +118,6 @@ async def get_model(model_name: str) -> ModelResponse:
}
```
"""
config = get_app_config()
model = config.get_model_config(model_name)
if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+37 -33
View File
@@ -4,11 +4,13 @@ import logging
import shutil
from pathlib import Path
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from app.gateway.deps import get_config
from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
@@ -101,9 +103,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
summary="List All Skills",
description="Retrieve a list of all available skills from both public and custom directories.",
)
async def list_skills() -> SkillsListResponse:
async def list_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(enabled_only=False, app_config=config)
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error(f"Failed to load skills: {e}", exc_info=True)
@@ -136,9 +138,9 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
async def list_custom_skills() -> SkillsListResponse:
async def list_custom_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
skills = [skill for skill in load_skills(enabled_only=False, app_config=config) if skill.category == "custom"]
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error("Failed to list custom skills: %s", e, exc_info=True)
@@ -146,13 +148,13 @@ async def list_custom_skills() -> SkillsListResponse:
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
async def get_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(enabled_only=False, app_config=config)
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name, app_config=config))
except HTTPException:
raise
except Exception as e:
@@ -161,14 +163,14 @@ async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
ensure_custom_skill_is_editable(skill_name)
ensure_custom_skill_is_editable(skill_name, app_config=config)
validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md", app_config=config)
if scan.decision == "block":
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
skill_file = get_custom_skill_dir(skill_name, app_config=config) / "SKILL.md"
prev_content = skill_file.read_text(encoding="utf-8")
atomic_write(skill_file, request.content)
append_history(
@@ -182,9 +184,10 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
"new_content": request.content,
"scanner": {"decision": scan.decision, "reason": scan.reason},
},
app_config=config,
)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
return await get_custom_skill(skill_name, config)
except HTTPException:
raise
except FileNotFoundError as e:
@@ -197,11 +200,11 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
async def delete_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> dict[str, bool]:
try:
ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name)
ensure_custom_skill_is_editable(skill_name, app_config=config)
skill_dir = get_custom_skill_dir(skill_name, app_config=config)
prev_content = read_custom_skill_content(skill_name, app_config=config)
try:
append_history(
skill_name,
@@ -214,6 +217,7 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
app_config=config,
)
except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
@@ -232,11 +236,11 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
async def get_custom_skill_history(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
if not custom_skill_exists(skill_name, app_config=config) and not get_skill_history_file(skill_name, app_config=config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillHistoryResponse(history=read_history(skill_name))
return CustomSkillHistoryResponse(history=read_history(skill_name, app_config=config))
except HTTPException:
raise
except Exception as e:
@@ -245,11 +249,11 @@ async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryRespons
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
if not custom_skill_exists(skill_name, app_config=config) and not get_skill_history_file(skill_name, app_config=config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
history = read_history(skill_name)
history = read_history(skill_name, app_config=config)
if not history:
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
record = history[request.history_index]
@@ -257,8 +261,8 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
if target_content is None:
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md", app_config=config)
skill_file = get_custom_skill_file(skill_name, app_config=config)
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
history_entry = {
"action": "rollback",
@@ -271,12 +275,12 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
"scanner": {"decision": scan.decision, "reason": scan.reason},
}
if scan.decision == "block":
append_history(skill_name, history_entry)
append_history(skill_name, history_entry, app_config=config)
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
atomic_write(skill_file, target_content)
append_history(skill_name, history_entry)
append_history(skill_name, history_entry, app_config=config)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
return await get_custom_skill(skill_name, config)
except HTTPException:
raise
except IndexError:
@@ -296,9 +300,9 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
summary="Get Skill Details",
description="Retrieve detailed information about a specific skill by its name.",
)
async def get_skill(skill_name: str) -> SkillResponse:
async def get_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(enabled_only=False, app_config=config)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@@ -318,9 +322,9 @@ async def get_skill(skill_name: str) -> SkillResponse:
summary="Update Skill",
description="Update a skill's enabled status by modifying the extensions_config.json file.",
)
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
async def update_skill(skill_name: str, request: SkillUpdateRequest, config: AppConfig = Depends(get_config)) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(enabled_only=False, app_config=config)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@@ -346,7 +350,7 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
reload_extensions_config()
await refresh_skills_system_prompt_cache_async()
skills = load_skills(enabled_only=False)
skills = load_skills(enabled_only=False, app_config=config)
updated_skill = next((s for s in skills if s.name == skill_name), None)
if updated_skill is None:
+10 -3
View File
@@ -1,11 +1,13 @@
import json
import logging
from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, Request
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -100,7 +102,12 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
@require_permission("threads", "read", owner_check=True)
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
async def generate_suggestions(
thread_id: str,
body: SuggestionsRequest,
request: Request,
config: AppConfig = Depends(get_config),
) -> SuggestionsResponse:
if not body.messages:
return SuggestionsResponse(suggestions=[])
@@ -122,7 +129,7 @@ async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=body.model_name, thinking_enabled=False)
model = create_chat_model(name=body.model_name, thinking_enabled=False, app_config=config)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
+9 -8
View File
@@ -4,11 +4,12 @@ import logging
import os
import stat
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel
from app.gateway.authz import require_permission
from deerflow.config.app_config import get_app_config
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
@@ -60,23 +61,22 @@ def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
def _get_uploads_config_value(key: str, default: object) -> object:
def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
uploads_cfg = getattr(app_config, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _auto_convert_documents_enabled() -> bool:
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
"""Return whether automatic host-side document conversion is enabled.
The secure default is disabled unless an operator explicitly opts in via
uploads.auto_convert_documents in config.yaml.
"""
try:
raw = _get_uploads_config_value("auto_convert_documents", False)
raw = _get_uploads_config_value(app_config, "auto_convert_documents", False)
if isinstance(raw, str):
return raw.strip().lower() in {"1", "true", "yes", "on"}
return bool(raw)
@@ -90,6 +90,7 @@ async def upload_files(
thread_id: str,
request: Request,
files: list[UploadFile] = File(...),
config: AppConfig = Depends(get_config),
) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory."""
if not files:
@@ -108,7 +109,7 @@ async def upload_files(
if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
auto_convert_documents = _auto_convert_documents_enabled()
auto_convert_documents = _auto_convert_documents_enabled(config)
for file in files:
if not file.filename:
@@ -18,7 +18,7 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import build_lea
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.config.memory_config import get_memory_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.models import create_chat_model
@@ -35,9 +35,9 @@ def _get_runtime_config(config: RunnableConfig) -> dict:
return cfg
def _resolve_model_name(requested_model_name: str | None = None) -> str:
def _resolve_model_name(requested_model_name: str | None = None, *, app_config: AppConfig | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config()
app_config = app_config or get_app_config()
default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
@@ -50,7 +50,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
return default_model_name
def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> DeerFlowSummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
config = get_summarization_config()
@@ -73,9 +73,9 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
# as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time).
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
else:
model = create_chat_model(thinking_enabled=False)
model = create_chat_model(thinking_enabled=False, app_config=app_config)
model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs
@@ -99,7 +99,8 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
# config is not expected to change after startup.
try:
skills_container_path = get_app_config().skills.container_path or "/mnt/skills"
resolved_app_config = app_config or get_app_config()
skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills"
except Exception:
logger.exception("Failed to resolve skills container path; falling back to default")
skills_container_path = "/mnt/skills"
@@ -240,7 +241,14 @@ Being proactive with task management demonstrates thoroughness and ensures all r
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
# ClarificationMiddleware should be last to intercept clarification requests after model calls
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None):
def _build_middlewares(
config: RunnableConfig,
model_name: str | None,
agent_name: str | None = None,
custom_middlewares: list[AgentMiddleware] | None = None,
*,
app_config: AppConfig | None = None,
):
"""Build middleware chain based on runtime configuration.
Args:
@@ -252,9 +260,10 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
List of middleware instances.
"""
middlewares = build_lead_runtime_middlewares(lazy_init=True)
resolved_app_config = app_config or get_app_config()
# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware()
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
if summarization_middleware is not None:
middlewares.append(summarization_middleware)
@@ -266,7 +275,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(todo_list_middleware)
# Add TokenUsageMiddleware when token_usage tracking is enabled
if get_app_config().token_usage.enabled:
if resolved_app_config.token_usage.enabled:
middlewares.append(TokenUsageMiddleware())
# Add TitleMiddleware
@@ -277,13 +286,12 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
# Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config()
model_config = app_config.get_model_config(model_name) if model_name else None
model_config = resolved_app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware())
# Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
if app_config.tool_search.enabled:
if resolved_app_config.tool_search.enabled:
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
middlewares.append(DeferredToolFilterMiddleware())
@@ -306,12 +314,13 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
return middlewares
def make_lead_agent(config: RunnableConfig):
def make_lead_agent(config: RunnableConfig, app_config: AppConfig | None = None):
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent
cfg = _get_runtime_config(config)
resolved_app_config = app_config or get_app_config()
thinking_enabled = cfg.get("thinking_enabled", True)
reasoning_effort = cfg.get("reasoning_effort", None)
@@ -327,10 +336,9 @@ def make_lead_agent(config: RunnableConfig):
agent_model_name = agent_config.model if agent_config and agent_config.model else None
# Final model name resolution: request → agent config → global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name)
model_name = _resolve_model_name(requested_model_name or agent_model_name, app_config=resolved_app_config)
app_config = get_app_config()
model_config = app_config.get_model_config(model_name)
model_config = resolved_app_config.get_model_config(model_name)
if model_config is None:
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
@@ -369,20 +377,34 @@ def make_lead_agent(config: RunnableConfig):
if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
available_skills=set(["bootstrap"]),
app_config=resolved_app_config,
),
state_schema=ThreadState,
)
# Default lead agent (unchanged behavior)
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
tools=get_available_tools(
model_name=model_name,
groups=agent_config.tool_groups if agent_config else None,
subagent_enabled=subagent_enabled,
app_config=resolved_app_config,
),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=agent_name,
available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None,
app_config=resolved_app_config,
),
state_schema=ThreadState,
)
@@ -1,14 +1,20 @@
from __future__ import annotations
import asyncio
import logging
import threading
from datetime import datetime
from functools import lru_cache
from typing import TYPE_CHECKING
from deerflow.config.agents_config import load_agent_soul
from deerflow.skills import load_skills
from deerflow.skills.types import Skill
from deerflow.subagents import get_available_subagent_names
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
@@ -111,6 +117,19 @@ def _get_enabled_skills():
return []
def _get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
"""Return enabled skills using the caller's config source.
When a concrete ``app_config`` is supplied, bypass the global enabled-skills
cache so the skill list and skill paths are resolved from the same config
object. This keeps request-scoped config injection consistent even while the
release branch still supports global fallback paths.
"""
if app_config is None:
return _get_enabled_skills()
return list(load_skills(enabled_only=True, app_config=app_config))
def _skill_mutability_label(category: str) -> str:
return "[custom, editable]" if category == "custom" else "[built-in]"
@@ -576,14 +595,14 @@ You have access to skills that provide optimized workflows for specific tasks. E
</skill_system>"""
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
"""Generate the skills prompt section with available skills list."""
skills = _get_enabled_skills()
skills = _get_enabled_skills_for_config(app_config)
try:
from deerflow.config import get_app_config
config = get_app_config()
config = app_config or get_app_config()
container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled
except Exception:
@@ -612,7 +631,7 @@ def get_agent_soul(agent_name: str | None) -> str:
return ""
def get_deferred_tools_prompt_section() -> str:
def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
"""Generate <available-deferred-tools> block for the system prompt.
Lists only deferred tool names so the agent knows what exists
@@ -624,7 +643,8 @@ def get_deferred_tools_prompt_section() -> str:
try:
from deerflow.config import get_app_config
if not get_app_config().tool_search.enabled:
config = app_config or get_app_config()
if not config.tool_search.enabled:
return ""
except Exception:
return ""
@@ -657,12 +677,13 @@ def _build_acp_section() -> str:
)
def _build_custom_mounts_section() -> str:
def _build_custom_mounts_section(*, app_config: AppConfig | None = None) -> str:
"""Build a prompt section for explicitly configured sandbox mounts."""
try:
from deerflow.config import get_app_config
mounts = get_app_config().sandbox.mounts or []
config = app_config or get_app_config()
mounts = config.sandbox.mounts or []
except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return ""
@@ -679,7 +700,14 @@ def _build_custom_mounts_section() -> str:
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
def apply_prompt_template(
subagent_enabled: bool = False,
max_concurrent_subagents: int = 3,
*,
agent_name: str | None = None,
available_skills: set[str] | None = None,
app_config: AppConfig | None = None,
) -> str:
# Get memory context
memory_context = _get_memory_context(agent_name)
@@ -706,14 +734,14 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
)
# Get skills section
skills_section = get_skills_prompt_section(available_skills)
skills_section = get_skills_prompt_section(available_skills, app_config=app_config)
# Get deferred tools section (tool_search)
deferred_tools_section = get_deferred_tools_prompt_section()
deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config)
# Build ACP agent section only if ACP agents are configured
acp_section = _build_acp_section()
custom_mounts_section = _build_custom_mounts_section()
custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
# Format the prompt with dynamic skills and memory
@@ -3,6 +3,7 @@ import logging
from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks
@@ -46,7 +47,7 @@ def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_con
model_settings_from_config["stream_usage"] = True
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *, app_config: AppConfig | None = None, **kwargs) -> BaseChatModel:
"""Create a chat model instance from the config.
Args:
@@ -55,7 +56,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
Returns:
A chat model instance.
"""
config = get_app_config()
config = app_config or get_app_config()
if name is None:
name = config.models[0].name
model_config = config.get_model_config(name)
@@ -20,11 +20,13 @@ import copy
import inspect
import logging
from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
from deerflow.config.app_config import AppConfig
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge
@@ -51,6 +53,27 @@ class RunContext:
event_store: Any | None = field(default=None)
run_events_config: Any | None = field(default=None)
thread_store: Any | None = field(default=None)
app_config: AppConfig | None = field(default=None)
def _compute_agent_factory_supports_app_config(agent_factory: Any) -> bool:
try:
return "app_config" in inspect.signature(agent_factory).parameters
except (TypeError, ValueError):
return False
@lru_cache(maxsize=128)
def _cached_agent_factory_supports_app_config(agent_factory: Any) -> bool:
return _compute_agent_factory_supports_app_config(agent_factory)
def _agent_factory_supports_app_config(agent_factory: Any) -> bool:
try:
return _cached_agent_factory_supports_app_config(agent_factory)
except TypeError:
# Some callable instances are unhashable; fall back to a direct check.
return _compute_agent_factory_supports_app_config(agent_factory)
async def run_agent(
@@ -163,7 +186,10 @@ async def run_agent(
config.setdefault("callbacks", []).append(journal)
runnable_config = RunnableConfig(**config)
agent = agent_factory(config=runnable_config)
if ctx.app_config is not None and _agent_factory_supports_app_config(agent_factory):
agent = agent_factory(config=runnable_config, app_config=ctx.app_config)
else:
agent = agent_factory(config=runnable_config)
# 4. Attach checkpointer and store
if checkpointer is not None:
@@ -2,6 +2,8 @@ import logging
import os
from pathlib import Path
from deerflow.config.app_config import AppConfig
from .parser import parse_skill_file
from .types import Skill
@@ -22,7 +24,7 @@ def get_skills_root_path() -> Path:
return skills_dir
def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]:
def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False, *, app_config: AppConfig | None = None) -> list[Skill]:
"""
Load all skills from the skills directory.
@@ -44,7 +46,7 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable
try:
from deerflow.config import get_app_config
config = get_app_config()
config = app_config or get_app_config()
skills_path = config.skills.get_skills_path()
except Exception:
# Fallback to default if config fails
@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.skills.loader import load_skills
from deerflow.skills.validation import _validate_skill_frontmatter
@@ -20,16 +21,17 @@ ALLOWED_SUPPORT_SUBDIRS = {"references", "templates", "scripts", "assets"}
_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
def get_skills_root_dir() -> Path:
return get_app_config().skills.get_skills_path()
def get_skills_root_dir(*, app_config: AppConfig | None = None) -> Path:
config = app_config or get_app_config()
return config.skills.get_skills_path()
def get_public_skills_dir() -> Path:
return get_skills_root_dir() / "public"
def get_public_skills_dir(*, app_config: AppConfig | None = None) -> Path:
return get_skills_root_dir(app_config=app_config) / "public"
def get_custom_skills_dir() -> Path:
path = get_skills_root_dir() / "custom"
def get_custom_skills_dir(*, app_config: AppConfig | None = None) -> Path:
path = get_skills_root_dir(app_config=app_config) / "custom"
path.mkdir(parents=True, exist_ok=True)
return path
@@ -43,46 +45,46 @@ def validate_skill_name(name: str) -> str:
return normalized
def get_custom_skill_dir(name: str) -> Path:
return get_custom_skills_dir() / validate_skill_name(name)
def get_custom_skill_dir(name: str, *, app_config: AppConfig | None = None) -> Path:
return get_custom_skills_dir(app_config=app_config) / validate_skill_name(name)
def get_custom_skill_file(name: str) -> Path:
return get_custom_skill_dir(name) / SKILL_FILE_NAME
def get_custom_skill_file(name: str, *, app_config: AppConfig | None = None) -> Path:
return get_custom_skill_dir(name, app_config=app_config) / SKILL_FILE_NAME
def get_custom_skill_history_dir() -> Path:
path = get_custom_skills_dir() / HISTORY_DIR_NAME
def get_custom_skill_history_dir(*, app_config: AppConfig | None = None) -> Path:
path = get_custom_skills_dir(app_config=app_config) / HISTORY_DIR_NAME
path.mkdir(parents=True, exist_ok=True)
return path
def get_skill_history_file(name: str) -> Path:
return get_custom_skill_history_dir() / f"{validate_skill_name(name)}.jsonl"
def get_skill_history_file(name: str, *, app_config: AppConfig | None = None) -> Path:
return get_custom_skill_history_dir(app_config=app_config) / f"{validate_skill_name(name)}.jsonl"
def get_public_skill_dir(name: str) -> Path:
return get_public_skills_dir() / validate_skill_name(name)
def get_public_skill_dir(name: str, *, app_config: AppConfig | None = None) -> Path:
return get_public_skills_dir(app_config=app_config) / validate_skill_name(name)
def custom_skill_exists(name: str) -> bool:
return get_custom_skill_file(name).exists()
def custom_skill_exists(name: str, *, app_config: AppConfig | None = None) -> bool:
return get_custom_skill_file(name, app_config=app_config).exists()
def public_skill_exists(name: str) -> bool:
return (get_public_skill_dir(name) / SKILL_FILE_NAME).exists()
def public_skill_exists(name: str, *, app_config: AppConfig | None = None) -> bool:
return (get_public_skill_dir(name, app_config=app_config) / SKILL_FILE_NAME).exists()
def ensure_custom_skill_is_editable(name: str) -> None:
if custom_skill_exists(name):
def ensure_custom_skill_is_editable(name: str, *, app_config: AppConfig | None = None) -> None:
if custom_skill_exists(name, app_config=app_config):
return
if public_skill_exists(name):
if public_skill_exists(name, app_config=app_config):
raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
raise FileNotFoundError(f"Custom skill '{name}' not found.")
def ensure_safe_support_path(name: str, relative_path: str) -> Path:
skill_dir = get_custom_skill_dir(name).resolve()
def ensure_safe_support_path(name: str, relative_path: str, *, app_config: AppConfig | None = None) -> Path:
skill_dir = get_custom_skill_dir(name, app_config=app_config).resolve()
if not relative_path or relative_path.endswith("/"):
raise ValueError("Supporting file path must include a filename.")
relative = Path(relative_path)
@@ -124,8 +126,8 @@ def atomic_write(path: Path, content: str) -> None:
tmp_path.replace(path)
def append_history(name: str, record: dict[str, Any]) -> None:
history_path = get_skill_history_file(name)
def append_history(name: str, record: dict[str, Any], *, app_config: AppConfig | None = None) -> None:
history_path = get_skill_history_file(name, app_config=app_config)
history_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"ts": datetime.now(UTC).isoformat(),
@@ -136,8 +138,8 @@ def append_history(name: str, record: dict[str, Any]) -> None:
f.write("\n")
def read_history(name: str) -> list[dict[str, Any]]:
history_path = get_skill_history_file(name)
def read_history(name: str, *, app_config: AppConfig | None = None) -> list[dict[str, Any]]:
history_path = get_skill_history_file(name, app_config=app_config)
if not history_path.exists():
return []
records: list[dict[str, Any]] = []
@@ -148,12 +150,12 @@ def read_history(name: str) -> list[dict[str, Any]]:
return records
def list_custom_skills() -> list:
return [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
def list_custom_skills(*, app_config: AppConfig | None = None) -> list:
return [skill for skill in load_skills(enabled_only=False, app_config=app_config) if skill.category == "custom"]
def read_custom_skill_content(name: str) -> str:
skill_file = get_custom_skill_file(name)
def read_custom_skill_content(name: str, *, app_config: AppConfig | None = None) -> str:
skill_file = get_custom_skill_file(name, app_config=app_config)
if not skill_file.exists():
raise FileNotFoundError(f"Custom skill '{name}' not found.")
return skill_file.read_text(encoding="utf-8")
@@ -8,6 +8,7 @@ import re
from dataclasses import dataclass
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -35,7 +36,7 @@ def _extract_json_object(raw: str) -> dict | None:
return None
async def scan_skill_content(content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult:
async def scan_skill_content(content: str, *, executable: bool = False, location: str = "SKILL.md", app_config: AppConfig | None = None) -> ScanResult:
"""Screen skill content before it is written to disk."""
rubric = (
"You are a security reviewer for AI agent skills. "
@@ -47,9 +48,9 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----"
try:
config = get_app_config()
config = app_config or get_app_config()
model_name = config.skill_evolution.moderation_model_name
model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False)
model = create_chat_model(name=model_name, thinking_enabled=False, app_config=config) if model_name else create_chat_model(thinking_enabled=False, app_config=config)
response = await model.ainvoke(
[
{"role": "system", "content": rubric},
@@ -3,6 +3,7 @@ import logging
from langchain.tools import BaseTool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
@@ -37,6 +38,8 @@ def get_available_tools(
include_mcp: bool = True,
model_name: str | None = None,
subagent_enabled: bool = False,
*,
app_config: AppConfig | None = None,
) -> list[BaseTool]:
"""Get all available tools from config.
@@ -52,7 +55,7 @@ def get_available_tools(
Returns:
List of available tools.
"""
config = get_app_config()
config = app_config or get_app_config()
tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups]
# Do not expose host bash by default when LocalSandboxProvider is active.
@@ -84,14 +84,15 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
@@ -110,6 +111,7 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
assert captured["name"] == "safe-model"
assert captured["thinking_enabled"] is False
assert captured["app_config"] is app_config
assert result["model"] is not None
@@ -126,14 +128,15 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
get_available_tools = MagicMock(return_value=[])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", get_available_tools)
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
@@ -156,8 +159,9 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
"name": "context-model",
"thinking_enabled": False,
"reasoning_effort": "high",
"app_config": app_config,
}
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True)
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True, app_config=app_config)
assert result["model"] is not None
@@ -198,10 +202,15 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda **kwargs: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares({"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}},
model_name="vision-model",
custom_middlewares=[MagicMock()],
app_config=app_config,
)
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
# verify the custom middleware is injected correctly
@@ -222,18 +231,20 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return fake_model
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
middleware = lead_agent_module._create_summarization_middleware()
middleware = lead_agent_module._create_summarization_middleware(app_config=_make_app_config([_make_model("model-masswork", supports_thinking=False)]))
assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False
assert captured["app_config"] is not None
assert middleware["model"] is fake_model
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
+2 -2
View File
@@ -48,7 +48,7 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
@@ -66,7 +66,7 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
+19 -1
View File
@@ -100,6 +100,24 @@ def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeyp
assert "Skill Self-Evolution" not in disabled_result
def test_get_skills_prompt_section_uses_explicit_config_for_enabled_skills(monkeypatch):
explicit_config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/alt-skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [_make_skill("global-skill")])
monkeypatch.setattr(
"deerflow.agents.lead_agent.prompt.load_skills",
lambda enabled_only=True, app_config=None: [_make_skill("explicit-skill")] if app_config is explicit_config else [],
)
result = get_skills_prompt_section(app_config=explicit_config)
assert "explicit-skill" in result
assert "global-skill" not in result
def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
from unittest.mock import MagicMock
@@ -107,7 +125,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
# Mock dependencies
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock())
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model")
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
+18 -1
View File
@@ -2,7 +2,7 @@ from unittest.mock import AsyncMock, call
import pytest
from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint
from deerflow.runtime.runs.worker import _agent_factory_supports_app_config, _rollback_to_pre_run_checkpoint
class FakeCheckpointer:
@@ -212,3 +212,20 @@ async def test_rollback_propagates_aput_writes_failure():
# aput succeeded, aput_writes was called but failed
checkpointer.aput.assert_awaited_once()
checkpointer.aput_writes.assert_awaited_once()
def test_agent_factory_supports_app_config_detects_supported_signature():
def factory(*, config, app_config=None):
return (config, app_config)
assert _agent_factory_supports_app_config(factory) is True
def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_fails(monkeypatch):
class BrokenCallable:
def __call__(self, **kwargs):
return kwargs
monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", lambda _obj: (_ for _ in ()).throw(ValueError("boom")))
assert _agent_factory_supports_app_config(BrokenCallable()) is False
+15 -14
View File
@@ -35,6 +35,13 @@ def _make_skill(name: str, *, enabled: bool) -> Skill:
)
def _make_test_app(config) -> FastAPI:
app = FastAPI()
app.state.config = config
app.include_router(skills_router.router)
return app
def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
skills_root = tmp_path / "skills"
custom_dir = skills_root / "custom" / "demo-skill"
@@ -54,8 +61,7 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)
with TestClient(app) as client:
response = client.get("/api/skills/custom")
@@ -96,7 +102,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
get_skill_history_file("demo-skill").write_text(
get_skill_history_file("demo-skill", app_config=config).write_text(
'{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n",
encoding="utf-8",
)
@@ -113,8 +119,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", _scan)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)
with TestClient(app) as client:
rollback_response = client.post("/api/skills/custom/demo-skill/rollback", json={"history_index": -1})
@@ -146,8 +151,7 @@ def test_custom_skill_delete_preserves_history_and_allows_restore(monkeypatch, t
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)
with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -187,8 +191,7 @@ def test_custom_skill_delete_continues_when_history_write_is_readonly(monkeypatc
monkeypatch.setattr("app.gateway.routers.skills.append_history", _readonly_history)
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)
with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -221,8 +224,7 @@ def test_custom_skill_delete_fails_when_skill_dir_removal_fails(monkeypatch, tmp
monkeypatch.setattr("app.gateway.routers.skills.shutil.rmtree", _fail_rmtree)
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)
with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -238,7 +240,7 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
enabled_state = {"value": True}
refresh_calls = []
def _load_skills(*, enabled_only: bool):
def _load_skills(*, enabled_only: bool, app_config=None):
skill = _make_skill("demo-skill", enabled=enabled_state["value"])
if enabled_only and not skill.enabled:
return []
@@ -254,8 +256,7 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path))
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(SimpleNamespace())
with TestClient(app) as client:
response = client.put("/api/skills/demo-skill", json={"enabled": False})
+5 -4
View File
@@ -1,4 +1,5 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from app.gateway.routers import suggestions
@@ -48,7 +49,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == ["Q1", "Q2", "Q3"]
fake_model.ainvoke.assert_awaited_once()
@@ -70,7 +71,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == ["Q1", "Q2"]
fake_model.ainvoke.assert_awaited_once()
@@ -92,7 +93,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == ["Q1", "Q2"]
fake_model.ainvoke.assert_awaited_once()
@@ -111,6 +112,6 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == []
+21 -20
View File
@@ -2,6 +2,7 @@ import asyncio
import stat
from io import BytesIO
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
from _router_auth_helpers import call_unwrapped
@@ -26,7 +27,7 @@ def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_pat
patch.object(uploads, "get_sandbox_provider", return_value=provider),
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert len(result.files) == 1
@@ -49,7 +50,7 @@ def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path):
patch.object(uploads, "get_sandbox_provider", return_value=provider),
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads"
@@ -75,7 +76,7 @@ def test_upload_files_does_not_auto_convert_documents_by_default(tmp_path):
patch.object(uploads, "convert_file_to_markdown", AsyncMock()) as convert_mock,
):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert len(result.files) == 1
@@ -108,7 +109,7 @@ def test_upload_files_syncs_non_local_sandbox_and_marks_markdown_file(tmp_path):
patch.object(uploads, "convert_file_to_markdown", AsyncMock(side_effect=fake_convert)),
):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert len(result.files) == 1
@@ -147,7 +148,7 @@ def test_upload_files_makes_non_local_files_sandbox_writable(tmp_path):
patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
make_writable.assert_any_call(thread_uploads_dir / "report.pdf")
@@ -171,7 +172,7 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
make_writable.assert_not_called()
@@ -222,13 +223,13 @@ def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path):
# These filenames must be rejected outright
for bad_name in ["..", "."]:
file = UploadFile(filename=bad_name, file=BytesIO(b"data"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert result.files == [], f"Expected no files for unsafe filename {bad_name!r}"
# Path-traversal prefixes are stripped to the basename and accepted safely
file = UploadFile(filename="../etc/passwd", file=BytesIO(b"data"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert len(result.files) == 1
assert result.files[0]["filename"] == "passwd"
@@ -252,16 +253,20 @@ def test_delete_uploaded_file_removes_generated_markdown_companion(tmp_path):
def test_auto_convert_documents_enabled_defaults_to_false_on_config_errors():
with patch.object(uploads, "get_app_config", side_effect=RuntimeError("boom")):
assert uploads._auto_convert_documents_enabled() is False
class BrokenConfig:
def __getattribute__(self, name):
if name == "uploads":
raise RuntimeError("boom")
return super().__getattribute__(name)
assert uploads._auto_convert_documents_enabled(BrokenConfig()) is False
def test_auto_convert_documents_enabled_reads_dict_backed_uploads_config():
cfg = MagicMock()
cfg.uploads = {"auto_convert_documents": True}
with patch.object(uploads, "get_app_config", return_value=cfg):
assert uploads._auto_convert_documents_enabled() is True
assert uploads._auto_convert_documents_enabled(cfg) is True
def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values():
@@ -277,11 +282,7 @@ def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values
string_false_cfg = MagicMock()
string_false_cfg.uploads = MagicMock(auto_convert_documents="false")
with patch.object(uploads, "get_app_config", return_value=false_cfg):
assert uploads._auto_convert_documents_enabled() is False
with patch.object(uploads, "get_app_config", return_value=true_cfg):
assert uploads._auto_convert_documents_enabled() is True
with patch.object(uploads, "get_app_config", return_value=string_true_cfg):
assert uploads._auto_convert_documents_enabled() is True
with patch.object(uploads, "get_app_config", return_value=string_false_cfg):
assert uploads._auto_convert_documents_enabled() is False
assert uploads._auto_convert_documents_enabled(false_cfg) is False
assert uploads._auto_convert_documents_enabled(true_cfg) is True
assert uploads._auto_convert_documents_enabled(string_true_cfg) is True
assert uploads._auto_convert_documents_enabled(string_false_cfg) is False