refactor: thread release config through lead path (#2612)
Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
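The pattern threaded through below: HTTP-layer code resolves the AppConfig once per request via a FastAPI dependency, and library code accepts it as an optional keyword argument that falls back to the process-global singleton. A minimal sketch of that shape (assuming get_config in app/gateway/deps.py reads the config the gateway stores on app.state.config, as the router tests in this diff set up; the endpoint and load_example helper are illustrative only, not code from this commit):

    from fastapi import APIRouter, Depends, Request

    from deerflow.config import get_app_config
    from deerflow.config.app_config import AppConfig

    def get_config(request: Request) -> AppConfig:
        # Request-scoped: use the config owned by this app instance rather
        # than the module-level singleton.
        return request.app.state.config

    router = APIRouter()

    @router.get("/api/example")  # hypothetical endpoint, for illustration
    async def example(config: AppConfig = Depends(get_config)) -> dict:
        return {"models": [m.name for m in config.models]}

    def load_example(*, app_config: AppConfig | None = None) -> AppConfig:
        # Library-side half of the pattern: an optional keyword with a global
        # fallback keeps existing call sites working unchanged.
        return app_config or get_app_config()

Every helper touched below follows this two-sided contract, which is what lets make_lead_agent and run_agent pass a request's config all the way down without breaking callers that still rely on the global.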

@@ -148,6 +148,7 @@ def get_run_context(request: Request) -> RunContext:
         event_store=get_run_event_store(request),
         run_events_config=getattr(config, "run_events", None),
         thread_store=get_thread_store(request),
+        app_config=config,
     )

@@ -1,7 +1,8 @@
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
 from pydantic import BaseModel, Field

-from deerflow.config import get_app_config
+from app.gateway.deps import get_config
+from deerflow.config.app_config import AppConfig

 router = APIRouter(prefix="/api", tags=["models"])

@@ -36,7 +37,7 @@ class ModelsListResponse(BaseModel):
     summary="List All Models",
     description="Retrieve a list of all available AI models configured in the system.",
 )
-async def list_models() -> ModelsListResponse:
+async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResponse:
     """List all available models from configuration.

     Returns model information suitable for frontend display,
@@ -72,7 +73,6 @@ async def list_models() -> ModelsListResponse:
     }
     ```
     """
-    config = get_app_config()
     models = [
         ModelResponse(
             name=model.name,
@@ -96,7 +96,7 @@ async def list_models() -> ModelsListResponse:
     summary="Get Model Details",
     description="Retrieve detailed information about a specific AI model by its name.",
 )
-async def get_model(model_name: str) -> ModelResponse:
+async def get_model(model_name: str, config: AppConfig = Depends(get_config)) -> ModelResponse:
     """Get a specific model by name.

     Args:
@@ -118,7 +118,6 @@ async def get_model(model_name: str) -> ModelResponse:
     }
     ```
     """
-    config = get_app_config()
     model = config.get_model_config(model_name)
     if model is None:
         raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

@@ -4,11 +4,13 @@ import logging
 import shutil
 from pathlib import Path

-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
 from pydantic import BaseModel, Field

+from app.gateway.deps import get_config
 from app.gateway.path_utils import resolve_thread_virtual_path
 from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
+from deerflow.config.app_config import AppConfig
 from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
 from deerflow.skills import Skill, load_skills
 from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
@@ -101,9 +103,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
     summary="List All Skills",
     description="Retrieve a list of all available skills from both public and custom directories.",
 )
-async def list_skills() -> SkillsListResponse:
+async def list_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
     try:
-        skills = load_skills(enabled_only=False)
+        skills = load_skills(enabled_only=False, app_config=config)
         return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
     except Exception as e:
         logger.error(f"Failed to load skills: {e}", exc_info=True)
@@ -136,9 +138,9 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:


 @router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
-async def list_custom_skills() -> SkillsListResponse:
+async def list_custom_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
     try:
-        skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
+        skills = [skill for skill in load_skills(enabled_only=False, app_config=config) if skill.category == "custom"]
         return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
     except Exception as e:
         logger.error("Failed to list custom skills: %s", e, exc_info=True)
@@ -146,13 +148,13 @@ async def list_custom_skills() -> SkillsListResponse:


 @router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
-async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
+async def get_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
     try:
-        skills = load_skills(enabled_only=False)
+        skills = load_skills(enabled_only=False, app_config=config)
         skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
         if skill is None:
             raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
-        return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
+        return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name, app_config=config))
     except HTTPException:
         raise
     except Exception as e:
@@ -161,14 +163,14 @@ async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:


 @router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
-async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
+async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
     try:
-        ensure_custom_skill_is_editable(skill_name)
+        ensure_custom_skill_is_editable(skill_name, app_config=config)
         validate_skill_markdown_content(skill_name, request.content)
-        scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
+        scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md", app_config=config)
         if scan.decision == "block":
             raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
-        skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
+        skill_file = get_custom_skill_dir(skill_name, app_config=config) / "SKILL.md"
         prev_content = skill_file.read_text(encoding="utf-8")
         atomic_write(skill_file, request.content)
         append_history(
@@ -182,9 +184,10 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
                 "new_content": request.content,
                 "scanner": {"decision": scan.decision, "reason": scan.reason},
             },
+            app_config=config,
         )
         await refresh_skills_system_prompt_cache_async()
-        return await get_custom_skill(skill_name)
+        return await get_custom_skill(skill_name, config)
     except HTTPException:
         raise
     except FileNotFoundError as e:
@@ -197,11 +200,11 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest


 @router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
-async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
+async def delete_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> dict[str, bool]:
     try:
-        ensure_custom_skill_is_editable(skill_name)
-        skill_dir = get_custom_skill_dir(skill_name)
-        prev_content = read_custom_skill_content(skill_name)
+        ensure_custom_skill_is_editable(skill_name, app_config=config)
+        skill_dir = get_custom_skill_dir(skill_name, app_config=config)
+        prev_content = read_custom_skill_content(skill_name, app_config=config)
         try:
             append_history(
                 skill_name,
@@ -214,6 +217,7 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
                     "new_content": None,
                     "scanner": {"decision": "allow", "reason": "Deletion requested."},
                 },
+                app_config=config,
             )
         except OSError as e:
             if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
@@ -232,11 +236,11 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:


 @router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
-async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
+async def get_custom_skill_history(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
     try:
-        if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
+        if not custom_skill_exists(skill_name, app_config=config) and not get_skill_history_file(skill_name, app_config=config).exists():
             raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
-        return CustomSkillHistoryResponse(history=read_history(skill_name))
+        return CustomSkillHistoryResponse(history=read_history(skill_name, app_config=config))
     except HTTPException:
         raise
     except Exception as e:
@@ -245,11 +249,11 @@ async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryRespons


 @router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
-async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
+async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
     try:
-        if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
+        if not custom_skill_exists(skill_name, app_config=config) and not get_skill_history_file(skill_name, app_config=config).exists():
             raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
-        history = read_history(skill_name)
+        history = read_history(skill_name, app_config=config)
         if not history:
             raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
         record = history[request.history_index]
@@ -257,8 +261,8 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
         if target_content is None:
             raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
         validate_skill_markdown_content(skill_name, target_content)
-        scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
-        skill_file = get_custom_skill_file(skill_name)
+        scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md", app_config=config)
+        skill_file = get_custom_skill_file(skill_name, app_config=config)
         current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
         history_entry = {
             "action": "rollback",
@@ -271,12 +275,12 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
             "scanner": {"decision": scan.decision, "reason": scan.reason},
         }
         if scan.decision == "block":
-            append_history(skill_name, history_entry)
+            append_history(skill_name, history_entry, app_config=config)
             raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
         atomic_write(skill_file, target_content)
-        append_history(skill_name, history_entry)
+        append_history(skill_name, history_entry, app_config=config)
         await refresh_skills_system_prompt_cache_async()
-        return await get_custom_skill(skill_name)
+        return await get_custom_skill(skill_name, config)
     except HTTPException:
         raise
     except IndexError:
@@ -296,9 +300,9 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
     summary="Get Skill Details",
     description="Retrieve detailed information about a specific skill by its name.",
 )
-async def get_skill(skill_name: str) -> SkillResponse:
+async def get_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> SkillResponse:
     try:
-        skills = load_skills(enabled_only=False)
+        skills = load_skills(enabled_only=False, app_config=config)
         skill = next((s for s in skills if s.name == skill_name), None)

         if skill is None:
@@ -318,9 +322,9 @@ async def get_skill(skill_name: str) -> SkillResponse:
     summary="Update Skill",
     description="Update a skill's enabled status by modifying the extensions_config.json file.",
 )
-async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
+async def update_skill(skill_name: str, request: SkillUpdateRequest, config: AppConfig = Depends(get_config)) -> SkillResponse:
     try:
-        skills = load_skills(enabled_only=False)
+        skills = load_skills(enabled_only=False, app_config=config)
         skill = next((s for s in skills if s.name == skill_name), None)

         if skill is None:
@@ -346,7 +350,7 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
         reload_extensions_config()
         await refresh_skills_system_prompt_cache_async()

-        skills = load_skills(enabled_only=False)
+        skills = load_skills(enabled_only=False, app_config=config)
         updated_skill = next((s for s in skills if s.name == skill_name), None)

         if updated_skill is None:

@@ -1,11 +1,13 @@
 import json
 import logging

-from fastapi import APIRouter, Request
+from fastapi import APIRouter, Depends, Request
 from langchain_core.messages import HumanMessage, SystemMessage
 from pydantic import BaseModel, Field

 from app.gateway.authz import require_permission
+from app.gateway.deps import get_config
+from deerflow.config.app_config import AppConfig
 from deerflow.models import create_chat_model

 logger = logging.getLogger(__name__)
@@ -100,7 +102,12 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
     description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
 )
 @require_permission("threads", "read", owner_check=True)
-async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
+async def generate_suggestions(
+    thread_id: str,
+    body: SuggestionsRequest,
+    request: Request,
+    config: AppConfig = Depends(get_config),
+) -> SuggestionsResponse:
     if not body.messages:
         return SuggestionsResponse(suggestions=[])

@@ -122,7 +129,7 @@ async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request
     user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"

     try:
-        model = create_chat_model(name=body.model_name, thinking_enabled=False)
+        model = create_chat_model(name=body.model_name, thinking_enabled=False, app_config=config)
         response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
         raw = _extract_response_text(response.content)
         suggestions = _parse_json_string_list(raw) or []

@@ -4,11 +4,12 @@ import logging
 import os
 import stat

-from fastapi import APIRouter, File, HTTPException, Request, UploadFile
+from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
 from pydantic import BaseModel

 from app.gateway.authz import require_permission
-from deerflow.config.app_config import get_app_config
+from app.gateway.deps import get_config
+from deerflow.config.app_config import AppConfig
 from deerflow.config.paths import get_paths
 from deerflow.runtime.user_context import get_effective_user_id
 from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
@@ -60,23 +61,22 @@ def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
     return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))


-def _get_uploads_config_value(key: str, default: object) -> object:
+def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
     """Read a value from the uploads config, supporting dict and attribute access."""
-    cfg = get_app_config()
-    uploads_cfg = getattr(cfg, "uploads", None)
+    uploads_cfg = getattr(app_config, "uploads", None)
     if isinstance(uploads_cfg, dict):
         return uploads_cfg.get(key, default)
     return getattr(uploads_cfg, key, default)


-def _auto_convert_documents_enabled() -> bool:
+def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
     """Return whether automatic host-side document conversion is enabled.

     The secure default is disabled unless an operator explicitly opts in via
     uploads.auto_convert_documents in config.yaml.
     """
     try:
-        raw = _get_uploads_config_value("auto_convert_documents", False)
+        raw = _get_uploads_config_value(app_config, "auto_convert_documents", False)
         if isinstance(raw, str):
             return raw.strip().lower() in {"1", "true", "yes", "on"}
         return bool(raw)
@@ -90,6 +90,7 @@ async def upload_files(
     thread_id: str,
     request: Request,
     files: list[UploadFile] = File(...),
+    config: AppConfig = Depends(get_config),
 ) -> UploadResponse:
     """Upload multiple files to a thread's uploads directory."""
     if not files:
@@ -108,7 +109,7 @@ async def upload_files(
     if sync_to_sandbox:
         sandbox_id = sandbox_provider.acquire(thread_id)
         sandbox = sandbox_provider.get(sandbox_id)
-    auto_convert_documents = _auto_convert_documents_enabled()
+    auto_convert_documents = _auto_convert_documents_enabled(config)

     for file in files:
         if not file.filename:

@@ -18,7 +18,7 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import build_lea
 from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
 from deerflow.agents.thread_state import ThreadState
 from deerflow.config.agents_config import load_agent_config, validate_agent_name
-from deerflow.config.app_config import get_app_config
+from deerflow.config.app_config import AppConfig, get_app_config
 from deerflow.config.memory_config import get_memory_config
 from deerflow.config.summarization_config import get_summarization_config
 from deerflow.models import create_chat_model
@@ -35,9 +35,9 @@ def _get_runtime_config(config: RunnableConfig) -> dict:
     return cfg


-def _resolve_model_name(requested_model_name: str | None = None) -> str:
+def _resolve_model_name(requested_model_name: str | None = None, *, app_config: AppConfig | None = None) -> str:
     """Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
-    app_config = get_app_config()
+    app_config = app_config or get_app_config()
     default_model_name = app_config.models[0].name if app_config.models else None
     if default_model_name is None:
         raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
@@ -50,7 +50,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
     return default_model_name


-def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
+def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> DeerFlowSummarizationMiddleware | None:
     """Create and configure the summarization middleware from config."""
     config = get_summarization_config()

@@ -73,9 +73,9 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
     # as middleware rather than lead_agent (SummarizationMiddleware is a
     # LangChain built-in, so we tag the model at creation time).
     if config.model_name:
-        model = create_chat_model(name=config.model_name, thinking_enabled=False)
+        model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
     else:
-        model = create_chat_model(thinking_enabled=False)
+        model = create_chat_model(thinking_enabled=False, app_config=app_config)
     model = model.with_config(tags=["middleware:summarize"])

     # Prepare kwargs
@@ -99,7 +99,8 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
     # the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
     # config is not expected to change after startup.
     try:
-        skills_container_path = get_app_config().skills.container_path or "/mnt/skills"
+        resolved_app_config = app_config or get_app_config()
+        skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills"
     except Exception:
         logger.exception("Failed to resolve skills container path; falling back to default")
         skills_container_path = "/mnt/skills"
@@ -240,7 +241,14 @@ Being proactive with task management demonstrates thoroughness and ensures all r
 # ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
 # ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
 # ClarificationMiddleware should be last to intercept clarification requests after model calls
-def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None):
+def _build_middlewares(
+    config: RunnableConfig,
+    model_name: str | None,
+    agent_name: str | None = None,
+    custom_middlewares: list[AgentMiddleware] | None = None,
+    *,
+    app_config: AppConfig | None = None,
+):
     """Build middleware chain based on runtime configuration.

     Args:
@@ -252,9 +260,10 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
         List of middleware instances.
     """
     middlewares = build_lead_runtime_middlewares(lazy_init=True)
+    resolved_app_config = app_config or get_app_config()

     # Add summarization middleware if enabled
-    summarization_middleware = _create_summarization_middleware()
+    summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
     if summarization_middleware is not None:
         middlewares.append(summarization_middleware)

@@ -266,7 +275,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
         middlewares.append(todo_list_middleware)

     # Add TokenUsageMiddleware when token_usage tracking is enabled
-    if get_app_config().token_usage.enabled:
+    if resolved_app_config.token_usage.enabled:
         middlewares.append(TokenUsageMiddleware())

     # Add TitleMiddleware
@@ -277,13 +286,12 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam

     # Add ViewImageMiddleware only if the current model supports vision.
     # Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
-    app_config = get_app_config()
-    model_config = app_config.get_model_config(model_name) if model_name else None
+    model_config = resolved_app_config.get_model_config(model_name) if model_name else None
     if model_config is not None and model_config.supports_vision:
         middlewares.append(ViewImageMiddleware())

     # Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
-    if app_config.tool_search.enabled:
+    if resolved_app_config.tool_search.enabled:
         from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware

         middlewares.append(DeferredToolFilterMiddleware())
@@ -306,12 +314,13 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
     return middlewares


-def make_lead_agent(config: RunnableConfig):
+def make_lead_agent(config: RunnableConfig, app_config: AppConfig | None = None):
     # Lazy import to avoid circular dependency
     from deerflow.tools import get_available_tools
     from deerflow.tools.builtins import setup_agent

     cfg = _get_runtime_config(config)
+    resolved_app_config = app_config or get_app_config()

     thinking_enabled = cfg.get("thinking_enabled", True)
     reasoning_effort = cfg.get("reasoning_effort", None)
@@ -327,10 +336,9 @@ def make_lead_agent(config: RunnableConfig):
     agent_model_name = agent_config.model if agent_config and agent_config.model else None

     # Final model name resolution: request → agent config → global default, with fallback for unknown names
-    model_name = _resolve_model_name(requested_model_name or agent_model_name)
+    model_name = _resolve_model_name(requested_model_name or agent_model_name, app_config=resolved_app_config)

-    app_config = get_app_config()
-    model_config = app_config.get_model_config(model_name)
+    model_config = resolved_app_config.get_model_config(model_name)

     if model_config is None:
         raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
@@ -369,20 +377,34 @@ def make_lead_agent(config: RunnableConfig):
     if is_bootstrap:
         # Special bootstrap agent with minimal prompt for initial custom agent creation flow
         return create_agent(
-            model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
-            tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
-            middleware=_build_middlewares(config, model_name=model_name),
-            system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
+            model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
+            tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent],
+            middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
+            system_prompt=apply_prompt_template(
+                subagent_enabled=subagent_enabled,
+                max_concurrent_subagents=max_concurrent_subagents,
+                available_skills=set(["bootstrap"]),
+                app_config=resolved_app_config,
+            ),
             state_schema=ThreadState,
         )

     # Default lead agent (unchanged behavior)
     return create_agent(
-        model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
-        tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
-        middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
+        model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
+        tools=get_available_tools(
+            model_name=model_name,
+            groups=agent_config.tool_groups if agent_config else None,
+            subagent_enabled=subagent_enabled,
+            app_config=resolved_app_config,
+        ),
+        middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
         system_prompt=apply_prompt_template(
-            subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
+            subagent_enabled=subagent_enabled,
+            max_concurrent_subagents=max_concurrent_subagents,
+            agent_name=agent_name,
+            available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None,
+            app_config=resolved_app_config,
         ),
         state_schema=ThreadState,
     )

@@ -1,14 +1,20 @@
+from __future__ import annotations
+
 import asyncio
 import logging
 import threading
 from datetime import datetime
 from functools import lru_cache
+from typing import TYPE_CHECKING

 from deerflow.config.agents_config import load_agent_soul
+from deerflow.skills import load_skills
+from deerflow.skills.types import Skill
 from deerflow.subagents import get_available_subagent_names

+if TYPE_CHECKING:
+    from deerflow.config.app_config import AppConfig
+
 logger = logging.getLogger(__name__)

 _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
@@ -111,6 +117,19 @@ def _get_enabled_skills():
         return []


+def _get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
+    """Return enabled skills using the caller's config source.
+
+    When a concrete ``app_config`` is supplied, bypass the global enabled-skills
+    cache so the skill list and skill paths are resolved from the same config
+    object. This keeps request-scoped config injection consistent even while the
+    release branch still supports global fallback paths.
+    """
+    if app_config is None:
+        return _get_enabled_skills()
+    return list(load_skills(enabled_only=True, app_config=app_config))
+
+
 def _skill_mutability_label(category: str) -> str:
     return "[custom, editable]" if category == "custom" else "[built-in]"
@@ -576,14 +595,14 @@ You have access to skills that provide optimized workflows for specific tasks. E
 </skill_system>"""


-def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
+def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
     """Generate the skills prompt section with available skills list."""
-    skills = _get_enabled_skills()
+    skills = _get_enabled_skills_for_config(app_config)

     try:
         from deerflow.config import get_app_config

-        config = get_app_config()
+        config = app_config or get_app_config()
         container_base_path = config.skills.container_path
         skill_evolution_enabled = config.skill_evolution.enabled
     except Exception:
@@ -612,7 +631,7 @@ def get_agent_soul(agent_name: str | None) -> str:
     return ""


-def get_deferred_tools_prompt_section() -> str:
+def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
     """Generate <available-deferred-tools> block for the system prompt.

     Lists only deferred tool names so the agent knows what exists
@@ -624,7 +643,8 @@ def get_deferred_tools_prompt_section() -> str:
     try:
         from deerflow.config import get_app_config

-        if not get_app_config().tool_search.enabled:
+        config = app_config or get_app_config()
+        if not config.tool_search.enabled:
             return ""
     except Exception:
         return ""
@@ -657,12 +677,13 @@ def _build_acp_section() -> str:
     )


-def _build_custom_mounts_section() -> str:
+def _build_custom_mounts_section(*, app_config: AppConfig | None = None) -> str:
     """Build a prompt section for explicitly configured sandbox mounts."""
     try:
         from deerflow.config import get_app_config

-        mounts = get_app_config().sandbox.mounts or []
+        config = app_config or get_app_config()
+        mounts = config.sandbox.mounts or []
     except Exception:
         logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
         return ""
@@ -679,7 +700,14 @@ def _build_custom_mounts_section() -> str:
     return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"


-def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
+def apply_prompt_template(
+    subagent_enabled: bool = False,
+    max_concurrent_subagents: int = 3,
+    *,
+    agent_name: str | None = None,
+    available_skills: set[str] | None = None,
+    app_config: AppConfig | None = None,
+) -> str:
     # Get memory context
     memory_context = _get_memory_context(agent_name)

@@ -706,14 +734,14 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
     )

     # Get skills section
-    skills_section = get_skills_prompt_section(available_skills)
+    skills_section = get_skills_prompt_section(available_skills, app_config=app_config)

     # Get deferred tools section (tool_search)
-    deferred_tools_section = get_deferred_tools_prompt_section()
+    deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config)

     # Build ACP agent section only if ACP agents are configured
     acp_section = _build_acp_section()
-    custom_mounts_section = _build_custom_mounts_section()
+    custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
     acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)

     # Format the prompt with dynamic skills and memory

@@ -3,6 +3,7 @@ import logging
 from langchain.chat_models import BaseChatModel

 from deerflow.config import get_app_config
+from deerflow.config.app_config import AppConfig
 from deerflow.reflection import resolve_class
 from deerflow.tracing import build_tracing_callbacks

@@ -46,7 +47,7 @@ def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_con
         model_settings_from_config["stream_usage"] = True


-def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
+def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *, app_config: AppConfig | None = None, **kwargs) -> BaseChatModel:
     """Create a chat model instance from the config.

     Args:
@@ -55,7 +56,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
     Returns:
         A chat model instance.
     """
-    config = get_app_config()
+    config = app_config or get_app_config()
     if name is None:
         name = config.models[0].name
     model_config = config.get_model_config(name)

@@ -20,11 +20,13 @@ import copy
 import inspect
 import logging
 from dataclasses import dataclass, field
+from functools import lru_cache
 from typing import TYPE_CHECKING, Any, Literal

 if TYPE_CHECKING:
     from langchain_core.messages import HumanMessage

+from deerflow.config.app_config import AppConfig
 from deerflow.runtime.serialization import serialize
 from deerflow.runtime.stream_bridge import StreamBridge

@@ -51,6 +53,27 @@ class RunContext:
     event_store: Any | None = field(default=None)
     run_events_config: Any | None = field(default=None)
     thread_store: Any | None = field(default=None)
+    app_config: AppConfig | None = field(default=None)
+
+
+def _compute_agent_factory_supports_app_config(agent_factory: Any) -> bool:
+    try:
+        return "app_config" in inspect.signature(agent_factory).parameters
+    except (TypeError, ValueError):
+        return False
+
+
+@lru_cache(maxsize=128)
+def _cached_agent_factory_supports_app_config(agent_factory: Any) -> bool:
+    return _compute_agent_factory_supports_app_config(agent_factory)
+
+
+def _agent_factory_supports_app_config(agent_factory: Any) -> bool:
+    try:
+        return _cached_agent_factory_supports_app_config(agent_factory)
+    except TypeError:
+        # Some callable instances are unhashable; fall back to a direct check.
+        return _compute_agent_factory_supports_app_config(agent_factory)


 async def run_agent(
@@ -163,7 +186,10 @@ async def run_agent(
         config.setdefault("callbacks", []).append(journal)

     runnable_config = RunnableConfig(**config)
-    agent = agent_factory(config=runnable_config)
+    if ctx.app_config is not None and _agent_factory_supports_app_config(agent_factory):
+        agent = agent_factory(config=runnable_config, app_config=ctx.app_config)
+    else:
+        agent = agent_factory(config=runnable_config)

     # 4. Attach checkpointer and store
     if checkpointer is not None:

@@ -2,6 +2,8 @@ import logging
 import os
 from pathlib import Path

+from deerflow.config.app_config import AppConfig
+
 from .parser import parse_skill_file
 from .types import Skill

@@ -22,7 +24,7 @@ def get_skills_root_path() -> Path:
     return skills_dir


-def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]:
+def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False, *, app_config: AppConfig | None = None) -> list[Skill]:
     """
     Load all skills from the skills directory.

@@ -44,7 +46,7 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable
     try:
         from deerflow.config import get_app_config

-        config = get_app_config()
+        config = app_config or get_app_config()
         skills_path = config.skills.get_skills_path()
     except Exception:
         # Fallback to default if config fails

@@ -10,6 +10,7 @@ from pathlib import Path
 from typing import Any

 from deerflow.config import get_app_config
+from deerflow.config.app_config import AppConfig
 from deerflow.skills.loader import load_skills
 from deerflow.skills.validation import _validate_skill_frontmatter

@@ -20,16 +21,17 @@ ALLOWED_SUPPORT_SUBDIRS = {"references", "templates", "scripts", "assets"}
 _SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")


-def get_skills_root_dir() -> Path:
-    return get_app_config().skills.get_skills_path()
+def get_skills_root_dir(*, app_config: AppConfig | None = None) -> Path:
+    config = app_config or get_app_config()
+    return config.skills.get_skills_path()


-def get_public_skills_dir() -> Path:
-    return get_skills_root_dir() / "public"
+def get_public_skills_dir(*, app_config: AppConfig | None = None) -> Path:
+    return get_skills_root_dir(app_config=app_config) / "public"


-def get_custom_skills_dir() -> Path:
-    path = get_skills_root_dir() / "custom"
+def get_custom_skills_dir(*, app_config: AppConfig | None = None) -> Path:
+    path = get_skills_root_dir(app_config=app_config) / "custom"
     path.mkdir(parents=True, exist_ok=True)
     return path

@@ -43,46 +45,46 @@ def validate_skill_name(name: str) -> str:
     return normalized


-def get_custom_skill_dir(name: str) -> Path:
-    return get_custom_skills_dir() / validate_skill_name(name)
+def get_custom_skill_dir(name: str, *, app_config: AppConfig | None = None) -> Path:
+    return get_custom_skills_dir(app_config=app_config) / validate_skill_name(name)


-def get_custom_skill_file(name: str) -> Path:
-    return get_custom_skill_dir(name) / SKILL_FILE_NAME
+def get_custom_skill_file(name: str, *, app_config: AppConfig | None = None) -> Path:
+    return get_custom_skill_dir(name, app_config=app_config) / SKILL_FILE_NAME


-def get_custom_skill_history_dir() -> Path:
-    path = get_custom_skills_dir() / HISTORY_DIR_NAME
+def get_custom_skill_history_dir(*, app_config: AppConfig | None = None) -> Path:
+    path = get_custom_skills_dir(app_config=app_config) / HISTORY_DIR_NAME
     path.mkdir(parents=True, exist_ok=True)
     return path


-def get_skill_history_file(name: str) -> Path:
-    return get_custom_skill_history_dir() / f"{validate_skill_name(name)}.jsonl"
+def get_skill_history_file(name: str, *, app_config: AppConfig | None = None) -> Path:
+    return get_custom_skill_history_dir(app_config=app_config) / f"{validate_skill_name(name)}.jsonl"


-def get_public_skill_dir(name: str) -> Path:
-    return get_public_skills_dir() / validate_skill_name(name)
+def get_public_skill_dir(name: str, *, app_config: AppConfig | None = None) -> Path:
+    return get_public_skills_dir(app_config=app_config) / validate_skill_name(name)


-def custom_skill_exists(name: str) -> bool:
-    return get_custom_skill_file(name).exists()
+def custom_skill_exists(name: str, *, app_config: AppConfig | None = None) -> bool:
+    return get_custom_skill_file(name, app_config=app_config).exists()


-def public_skill_exists(name: str) -> bool:
-    return (get_public_skill_dir(name) / SKILL_FILE_NAME).exists()
+def public_skill_exists(name: str, *, app_config: AppConfig | None = None) -> bool:
+    return (get_public_skill_dir(name, app_config=app_config) / SKILL_FILE_NAME).exists()


-def ensure_custom_skill_is_editable(name: str) -> None:
-    if custom_skill_exists(name):
+def ensure_custom_skill_is_editable(name: str, *, app_config: AppConfig | None = None) -> None:
+    if custom_skill_exists(name, app_config=app_config):
         return
-    if public_skill_exists(name):
+    if public_skill_exists(name, app_config=app_config):
         raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
     raise FileNotFoundError(f"Custom skill '{name}' not found.")


-def ensure_safe_support_path(name: str, relative_path: str) -> Path:
-    skill_dir = get_custom_skill_dir(name).resolve()
+def ensure_safe_support_path(name: str, relative_path: str, *, app_config: AppConfig | None = None) -> Path:
+    skill_dir = get_custom_skill_dir(name, app_config=app_config).resolve()
     if not relative_path or relative_path.endswith("/"):
         raise ValueError("Supporting file path must include a filename.")
     relative = Path(relative_path)
@@ -124,8 +126,8 @@ def atomic_write(path: Path, content: str) -> None:
     tmp_path.replace(path)


-def append_history(name: str, record: dict[str, Any]) -> None:
-    history_path = get_skill_history_file(name)
+def append_history(name: str, record: dict[str, Any], *, app_config: AppConfig | None = None) -> None:
+    history_path = get_skill_history_file(name, app_config=app_config)
     history_path.parent.mkdir(parents=True, exist_ok=True)
     payload = {
         "ts": datetime.now(UTC).isoformat(),
@@ -136,8 +138,8 @@ def append_history(name: str, record: dict[str, Any]) -> None:
         f.write("\n")


-def read_history(name: str) -> list[dict[str, Any]]:
-    history_path = get_skill_history_file(name)
+def read_history(name: str, *, app_config: AppConfig | None = None) -> list[dict[str, Any]]:
+    history_path = get_skill_history_file(name, app_config=app_config)
     if not history_path.exists():
         return []
     records: list[dict[str, Any]] = []
@@ -148,12 +150,12 @@ def read_history(name: str) -> list[dict[str, Any]]:
     return records


-def list_custom_skills() -> list:
-    return [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
+def list_custom_skills(*, app_config: AppConfig | None = None) -> list:
+    return [skill for skill in load_skills(enabled_only=False, app_config=app_config) if skill.category == "custom"]


-def read_custom_skill_content(name: str) -> str:
-    skill_file = get_custom_skill_file(name)
+def read_custom_skill_content(name: str, *, app_config: AppConfig | None = None) -> str:
+    skill_file = get_custom_skill_file(name, app_config=app_config)
     if not skill_file.exists():
         raise FileNotFoundError(f"Custom skill '{name}' not found.")
     return skill_file.read_text(encoding="utf-8")

@@ -8,6 +8,7 @@ import re
 from dataclasses import dataclass

 from deerflow.config import get_app_config
+from deerflow.config.app_config import AppConfig
 from deerflow.models import create_chat_model

 logger = logging.getLogger(__name__)
@@ -35,7 +36,7 @@ def _extract_json_object(raw: str) -> dict | None:
     return None


-async def scan_skill_content(content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult:
+async def scan_skill_content(content: str, *, executable: bool = False, location: str = "SKILL.md", app_config: AppConfig | None = None) -> ScanResult:
     """Screen skill content before it is written to disk."""
     rubric = (
         "You are a security reviewer for AI agent skills. "
@@ -47,9 +48,9 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
     prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----"

     try:
-        config = get_app_config()
+        config = app_config or get_app_config()
         model_name = config.skill_evolution.moderation_model_name
-        model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False)
+        model = create_chat_model(name=model_name, thinking_enabled=False, app_config=config) if model_name else create_chat_model(thinking_enabled=False, app_config=config)
         response = await model.ainvoke(
             [
                 {"role": "system", "content": rubric},

@@ -3,6 +3,7 @@ import logging
 from langchain.tools import BaseTool

 from deerflow.config import get_app_config
+from deerflow.config.app_config import AppConfig
 from deerflow.reflection import resolve_variable
 from deerflow.sandbox.security import is_host_bash_allowed
 from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
@@ -37,6 +38,8 @@ def get_available_tools(
     include_mcp: bool = True,
     model_name: str | None = None,
     subagent_enabled: bool = False,
+    *,
+    app_config: AppConfig | None = None,
 ) -> list[BaseTool]:
     """Get all available tools from config.

@@ -52,7 +55,7 @@ def get_available_tools(
     Returns:
         List of available tools.
     """
-    config = get_app_config()
+    config = app_config or get_app_config()
     tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups]

     # Do not expose host bash by default when LocalSandboxProvider is active.
@@ -84,14 +84,15 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
|
||||
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
|
||||
captured["name"] = name
|
||||
captured["thinking_enabled"] = thinking_enabled
|
||||
captured["reasoning_effort"] = reasoning_effort
|
||||
captured["app_config"] = app_config
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||
@@ -110,6 +111,7 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
|
||||
|
||||
assert captured["name"] == "safe-model"
|
||||
assert captured["thinking_enabled"] is False
|
||||
assert captured["app_config"] is app_config
|
||||
assert result["model"] is not None
|
||||
|
||||
|
||||
@@ -126,14 +128,15 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
|
||||
get_available_tools = MagicMock(return_value=[])
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(tools_module, "get_available_tools", get_available_tools)
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
|
||||
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
|
||||
captured["name"] = name
|
||||
captured["thinking_enabled"] = thinking_enabled
|
||||
captured["reasoning_effort"] = reasoning_effort
|
||||
captured["app_config"] = app_config
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||
@@ -156,8 +159,9 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
|
||||
"name": "context-model",
|
||||
"thinking_enabled": False,
|
||||
"reasoning_effort": "high",
|
||||
"app_config": app_config,
|
||||
}
|
||||
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True)
|
||||
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True, app_config=app_config)
|
||||
assert result["model"] is not None
|
||||
|
||||
|
||||
@@ -198,10 +202,15 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda **kwargs: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||
|
||||
middlewares = lead_agent_module._build_middlewares({"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
|
||||
middlewares = lead_agent_module._build_middlewares(
|
||||
{"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}},
|
||||
model_name="vision-model",
|
||||
custom_middlewares=[MagicMock()],
|
||||
app_config=app_config,
|
||||
)
|
||||
|
||||
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||
# verify the custom middleware is injected correctly
|
||||
@@ -222,18 +231,20 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
||||
fake_model = MagicMock()
|
||||
fake_model.with_config.return_value = fake_model
|
||||
|
||||
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
|
||||
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
|
||||
captured["name"] = name
|
||||
captured["thinking_enabled"] = thinking_enabled
|
||||
captured["reasoning_effort"] = reasoning_effort
|
||||
captured["app_config"] = app_config
|
||||
return fake_model
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
|
||||
|
||||
middleware = lead_agent_module._create_summarization_middleware()
|
||||
middleware = lead_agent_module._create_summarization_middleware(app_config=_make_app_config([_make_model("model-masswork", supports_thinking=False)]))
|
||||
|
||||
assert captured["name"] == "model-masswork"
|
||||
assert captured["thinking_enabled"] is False
|
||||
assert captured["app_config"] is not None
|
||||
assert middleware["model"] is fake_model
|
||||
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
|
||||
|
||||
@@ -48,7 +48,7 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
@@ -66,7 +66,7 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")

@@ -100,6 +100,24 @@ def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeyp
assert "Skill Self-Evolution" not in disabled_result


def test_get_skills_prompt_section_uses_explicit_config_for_enabled_skills(monkeypatch):
explicit_config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/alt-skills"),
skill_evolution=SimpleNamespace(enabled=False),
)

monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [_make_skill("global-skill")])
monkeypatch.setattr(
"deerflow.agents.lead_agent.prompt.load_skills",
lambda enabled_only=True, app_config=None: [_make_skill("explicit-skill")] if app_config is explicit_config else [],
)

result = get_skills_prompt_section(app_config=explicit_config)

assert "explicit-skill" in result
assert "global-skill" not in result

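The new test fixes the dispatch rule: an explicit app_config must bypass the cached _get_enabled_skills() path and scope skill discovery to the supplied config. A sketch of that branch, with _render_skills standing in as a hypothetical formatter for the prompt body:

def get_skills_prompt_section(app_config=None):
    if app_config is not None:
        # Explicit config: skip the global cache, load skills scoped to
        # the caller-supplied configuration.
        skills = load_skills(enabled_only=True, app_config=app_config)
    else:
        # Default path: cached skill list from the process-global config.
        skills = _get_enabled_skills()
    return _render_skills(skills)  # hypothetical formatter
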
def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
from unittest.mock import MagicMock

@@ -107,7 +125,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):

# Mock dependencies
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock())
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model")
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])

@@ -2,7 +2,7 @@ from unittest.mock import AsyncMock, call

import pytest

from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint
from deerflow.runtime.runs.worker import _agent_factory_supports_app_config, _rollback_to_pre_run_checkpoint


class FakeCheckpointer:
@@ -212,3 +212,20 @@ async def test_rollback_propagates_aput_writes_failure():
# aput succeeded, aput_writes was called but failed
checkpointer.aput.assert_awaited_once()
checkpointer.aput_writes.assert_awaited_once()


def test_agent_factory_supports_app_config_detects_supported_signature():
def factory(*, config, app_config=None):
return (config, app_config)

assert _agent_factory_supports_app_config(factory) is True


def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_fails(monkeypatch):
class BrokenCallable:
def __call__(self, **kwargs):
return kwargs

monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", lambda _obj: (_ for _ in ()).throw(ValueError("boom")))

assert _agent_factory_supports_app_config(BrokenCallable()) is False

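Together these two tests constrain _agent_factory_supports_app_config to a signature probe that fails closed when introspection is impossible. A minimal sketch satisfying both (the real body may differ):

import inspect

def _agent_factory_supports_app_config(factory) -> bool:
    # Builtins, C-implemented callables, and odd __call__ objects can make
    # signature() raise TypeError or ValueError; treat those as unsupported.
    try:
        signature = inspect.signature(factory)
    except (TypeError, ValueError):
        return False
    return "app_config" in signature.parameters

With this probe, the worker passes app_config only to factories that opt in, leaving older factory signatures untouched.
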
@@ -35,6 +35,13 @@ def _make_skill(name: str, *, enabled: bool) -> Skill:
)


def _make_test_app(config) -> FastAPI:
app = FastAPI()
app.state.config = config
app.include_router(skills_router.router)
return app

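The helper mirrors the production wiring: routers resolve configuration through the get_config dependency rather than a module-level global, so a test only has to park its config object on app.state. A plausible shape for that dependency, assuming it reads the same attribute the helper sets:

from fastapi import Request

def get_config(request: Request):
    # Resolved per request from application state; tests can substitute any
    # object (even a SimpleNamespace) by assigning app.state.config.
    return request.app.state.config
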
def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
skills_root = tmp_path / "skills"
custom_dir = skills_root / "custom" / "demo-skill"
@@ -54,8 +61,7 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):

monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)

app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)

with TestClient(app) as client:
response = client.get("/api/skills/custom")
@@ -96,7 +102,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
get_skill_history_file("demo-skill").write_text(
get_skill_history_file("demo-skill", app_config=config).write_text(
'{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n",
encoding="utf-8",
)
@@ -113,8 +119,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):

monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", _scan)

app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)

with TestClient(app) as client:
rollback_response = client.post("/api/skills/custom/demo-skill/rollback", json={"history_index": -1})
@@ -146,8 +151,7 @@ def test_custom_skill_delete_preserves_history_and_allows_restore(monkeypatch, t

monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)

app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)

with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -187,8 +191,7 @@ def test_custom_skill_delete_continues_when_history_write_is_readonly(monkeypatc
monkeypatch.setattr("app.gateway.routers.skills.append_history", _readonly_history)
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)

app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)

with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -221,8 +224,7 @@ def test_custom_skill_delete_fails_when_skill_dir_removal_fails(monkeypatch, tmp
monkeypatch.setattr("app.gateway.routers.skills.shutil.rmtree", _fail_rmtree)
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)

app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)

with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -238,7 +240,7 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
enabled_state = {"value": True}
refresh_calls = []

def _load_skills(*, enabled_only: bool):
def _load_skills(*, enabled_only: bool, app_config=None):
skill = _make_skill("demo-skill", enabled=enabled_state["value"])
if enabled_only and not skill.enabled:
return []
@@ -254,8 +256,7 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path))
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)

app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(SimpleNamespace())

with TestClient(app) as client:
response = client.put("/api/skills/demo-skill", json={"enabled": False})

@@ -1,4 +1,5 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

from app.gateway.routers import suggestions
@@ -48,7 +49,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):

# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))

assert result.suggestions == ["Q1", "Q2", "Q3"]
fake_model.ainvoke.assert_awaited_once()
@@ -70,7 +71,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):

# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))

assert result.suggestions == ["Q1", "Q2"]
fake_model.ainvoke.assert_awaited_once()
@@ -92,7 +93,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):

# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))

assert result.suggestions == ["Q1", "Q2"]
fake_model.ainvoke.assert_awaited_once()
@@ -111,6 +112,6 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):

# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))

assert result.suggestions == []

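These suggestion tests dodge the decorator by calling the handler's __wrapped__ attribute, which functools.wraps stores on every wrapper it builds; call_unwrapped in the uploads tests below presumably performs the same hop. A self-contained illustration of the pattern (names are illustrative, not the real gateway code):

import asyncio
import functools

def require_permission(func):
    @functools.wraps(func)  # sets wrapper.__wrapped__ = func
    async def wrapper(*args, **kwargs):
        # Stand-in for the real checks that need a request and thread_store.
        raise PermissionError("needs a real request and thread_store")
    return wrapper

@require_permission
async def handler(thread_id, req, *, request, config):
    return (thread_id, config)

# Tests bypass the permission check entirely:
print(asyncio.run(handler.__wrapped__("t1", None, request=None, config={})))
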
@@ -2,6 +2,7 @@ import asyncio
import stat
from io import BytesIO
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

from _router_auth_helpers import call_unwrapped
@@ -26,7 +27,7 @@ def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_pat
patch.object(uploads, "get_sandbox_provider", return_value=provider),
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))

assert result.success is True
assert len(result.files) == 1
@@ -49,7 +50,7 @@ def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path):
patch.object(uploads, "get_sandbox_provider", return_value=provider),
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file], config=SimpleNamespace()))

assert result.success is True
assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads"
@@ -75,7 +76,7 @@ def test_upload_files_does_not_auto_convert_documents_by_default(tmp_path):
patch.object(uploads, "convert_file_to_markdown", AsyncMock()) as convert_mock,
):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))

assert result.success is True
assert len(result.files) == 1
@@ -108,7 +109,7 @@ def test_upload_files_syncs_non_local_sandbox_and_marks_markdown_file(tmp_path):
patch.object(uploads, "convert_file_to_markdown", AsyncMock(side_effect=fake_convert)),
):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))

assert result.success is True
assert len(result.files) == 1
@@ -147,7 +148,7 @@ def test_upload_files_makes_non_local_files_sandbox_writable(tmp_path):
patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))

assert result.success is True
make_writable.assert_any_call(thread_uploads_dir / "report.pdf")
@@ -171,7 +172,7 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))

assert result.success is True
make_writable.assert_not_called()
@@ -222,13 +223,13 @@ def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path):
# These filenames must be rejected outright
for bad_name in ["..", "."]:
file = UploadFile(filename=bad_name, file=BytesIO(b"data"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert result.files == [], f"Expected no files for unsafe filename {bad_name!r}"

# Path-traversal prefixes are stripped to the basename and accepted safely
file = UploadFile(filename="../etc/passwd", file=BytesIO(b"data"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert len(result.files) == 1
assert result.files[0]["filename"] == "passwd"
@@ -252,16 +253,20 @@ def test_delete_uploaded_file_removes_generated_markdown_companion(tmp_path):


def test_auto_convert_documents_enabled_defaults_to_false_on_config_errors():
with patch.object(uploads, "get_app_config", side_effect=RuntimeError("boom")):
assert uploads._auto_convert_documents_enabled() is False
class BrokenConfig:
def __getattribute__(self, name):
if name == "uploads":
raise RuntimeError("boom")
return super().__getattribute__(name)

assert uploads._auto_convert_documents_enabled(BrokenConfig()) is False


def test_auto_convert_documents_enabled_reads_dict_backed_uploads_config():
cfg = MagicMock()
cfg.uploads = {"auto_convert_documents": True}

with patch.object(uploads, "get_app_config", return_value=cfg):
assert uploads._auto_convert_documents_enabled() is True
assert uploads._auto_convert_documents_enabled(cfg) is True


def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values():
@@ -277,11 +282,7 @@ def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values
string_false_cfg = MagicMock()
string_false_cfg.uploads = MagicMock(auto_convert_documents="false")

with patch.object(uploads, "get_app_config", return_value=false_cfg):
assert uploads._auto_convert_documents_enabled() is False
with patch.object(uploads, "get_app_config", return_value=true_cfg):
assert uploads._auto_convert_documents_enabled() is True
with patch.object(uploads, "get_app_config", return_value=string_true_cfg):
assert uploads._auto_convert_documents_enabled() is True
with patch.object(uploads, "get_app_config", return_value=string_false_cfg):
assert uploads._auto_convert_documents_enabled() is False
assert uploads._auto_convert_documents_enabled(false_cfg) is False
assert uploads._auto_convert_documents_enabled(true_cfg) is True
assert uploads._auto_convert_documents_enabled(string_true_cfg) is True
assert uploads._auto_convert_documents_enabled(string_false_cfg) is False

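These tests spell out the full decision table for _auto_convert_documents_enabled(config): fail closed on any error while reading the uploads section, accept dict-backed and attribute-backed config alike, and coerce string values. A sketch that satisfies all of them; the exact set of truthy strings is an assumption:

def _auto_convert_documents_enabled(config) -> bool:
    try:
        uploads_cfg = config.uploads  # may raise on broken configs
        if isinstance(uploads_cfg, dict):
            value = uploads_cfg.get("auto_convert_documents", False)
        else:
            value = getattr(uploads_cfg, "auto_convert_documents", False)
    except Exception:
        # Fail closed: any config error disables auto-conversion.
        return False
    if isinstance(value, str):
        return value.strip().lower() in {"true", "1", "yes", "on"}
    return bool(value)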