Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 979a461af5 | |||
| ac04f2704f | |||
| c4d273a68a | |||
| dc50a7fdfb | |||
| 5b633449f8 | |||
| 02569136df | |||
| 024ac0e464 | |||
| 19030928e0 |
@@ -24,6 +24,7 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
||||
# SLACK_BOT_TOKEN=your-slack-bot-token
|
||||
# SLACK_APP_TOKEN=your-slack-app-token
|
||||
# TELEGRAM_BOT_TOKEN=your-telegram-bot-token
|
||||
# DISCORD_BOT_TOKEN=your-discord-bot-token
|
||||
|
||||
# Enable LangSmith to monitor and debug your LLM calls, agent runs, and tool executions.
|
||||
# LANGSMITH_TRACING=true
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
"""Discord channel integration using discord.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DISCORD_MAX_MESSAGE_LEN = 2000
|
||||
|
||||
|
||||
class DiscordChannel(Channel):
    """Discord bot channel.

    The discord.py client runs on its own event loop inside a daemon thread
    (see ``_run_client``). Work is handed between that loop and the loop that
    called ``start()`` with ``asyncio.run_coroutine_threadsafe``.

    Configuration keys (in ``config.yaml`` under ``channels.discord``):
        - ``bot_token``: Discord Bot token.
        - ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
    """

    def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
        """Read the channel configuration; no network activity happens here.

        Args:
            bus: Message bus used for inbound publishing and outbound subscription.
            config: The ``channels.discord`` configuration mapping.
        """
        super().__init__(name="discord", bus=bus, config=config)
        # ``or ""`` keeps a null token from becoming the truthy string "None".
        self._bot_token = str(config.get("bot_token") or "").strip()
        self._allowed_guilds: set[int] = set()
        # ``or []`` covers a key that is present in the config but null.
        for guild_id in config.get("allowed_guilds") or []:
            try:
                self._allowed_guilds.add(int(guild_id))
            except (TypeError, ValueError):
                # An empty allow-list means "allow all", so a dropped entry is worth surfacing.
                logger.warning("[Discord] ignoring invalid allowed_guilds entry: %r", guild_id)

        self._client = None
        self._thread: threading.Thread | None = None
        # Loop owned by the client thread; reset to None once that thread is done with it.
        self._discord_loop: asyncio.AbstractEventLoop | None = None
        # Loop that called start(); inbound messages are published onto it.
        self._main_loop: asyncio.AbstractEventLoop | None = None
        self._discord_module = None

    async def start(self) -> None:
        """Create the Discord client and run it on a background daemon thread.

        Returns without starting (after logging an error) when discord.py is
        not importable or no bot token is configured.
        """
        if self._running:
            return

        try:
            import discord
        except ImportError:
            logger.error("discord.py is not installed. Install it with: uv add discord.py")
            return

        if not self._bot_token:
            logger.error("Discord channel requires bot_token")
            return

        intents = discord.Intents.default()
        intents.messages = True
        intents.guilds = True
        intents.message_content = True

        client = discord.Client(
            intents=intents,
            allowed_mentions=discord.AllowedMentions.none(),
        )
        self._client = client
        self._discord_module = discord
        # We are inside a coroutine, so the running loop is the one to publish onto.
        self._main_loop = asyncio.get_running_loop()

        @client.event
        async def on_message(message) -> None:
            await self._on_message(message)

        self._running = True
        self.bus.subscribe_outbound(self._on_outbound)

        self._thread = threading.Thread(target=self._run_client, daemon=True)
        self._thread.start()
        logger.info("Discord channel started")

    async def stop(self) -> None:
        """Close the client, wait for its thread, and clear all client state."""
        self._running = False
        self.bus.unsubscribe_outbound(self._on_outbound)

        # Snapshot: the client thread may reset ``_discord_loop`` concurrently.
        client = self._client
        loop = self._discord_loop
        if client is not None and loop is not None and loop.is_running():
            try:
                close_future = asyncio.run_coroutine_threadsafe(client.close(), loop)
                await asyncio.wait_for(asyncio.wrap_future(close_future), timeout=10)
            except TimeoutError:
                logger.warning("[Discord] client close timed out after 10s")
            except Exception:
                logger.exception("[Discord] error while closing client")

        thread = self._thread
        if thread is not None:
            # Join off-loop so the caller's event loop is not blocked for up to 10s.
            await asyncio.to_thread(thread.join, timeout=10)
            if thread.is_alive():
                logger.warning("[Discord] client thread did not exit within 10s")
            self._thread = None

        self._client = None
        self._discord_loop = None
        self._discord_module = None
        logger.info("Discord channel stopped")

    async def send(self, msg: OutboundMessage) -> None:
        """Send ``msg.text`` to the resolved thread/channel, split into Discord-sized chunks.

        Blank chunks are skipped because Discord rejects empty messages.
        """
        target = await self._resolve_target(msg)
        if target is None:
            logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
            return

        loop = self._discord_loop
        if loop is None or loop.is_closed():
            logger.error("[Discord] client loop is not available for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
            return

        for chunk in self._split_text(msg.text or ""):
            if not chunk.strip():
                continue
            send_future = asyncio.run_coroutine_threadsafe(target.send(chunk), loop)
            await asyncio.wrap_future(send_future)

    async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
        """Upload ``attachment`` to the resolved thread/channel.

        Returns:
            True when the upload completed, False when the target or client is
            unavailable or the upload raised.
        """
        target = await self._resolve_target(msg)
        if target is None:
            logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
            return False

        discord = self._discord_module
        loop = self._discord_loop
        if discord is None or loop is None or loop.is_closed():
            return False

        try:
            # discord.File does not close a file object it was handed, so the
            # handle is scoped here and kept open until the send has finished.
            with open(str(attachment.actual_path), "rb") as fp:
                file = discord.File(fp, filename=attachment.filename)
                send_future = asyncio.run_coroutine_threadsafe(target.send(file=file), loop)
                await asyncio.wrap_future(send_future)
            logger.info("[Discord] file uploaded: %s", attachment.filename)
            return True
        except Exception:
            logger.exception("[Discord] failed to upload file: %s", attachment.filename)
            return False

    async def _on_message(self, message) -> None:
        """Turn a Discord message into an inbound bus message.

        Invoked from the client's ``on_message`` event. Messages from bots,
        from guilds outside the allow-list, or with no text are ignored.
        A message that is not already in a thread gets a new thread created for it.
        """
        if not self._running or not self._client:
            return

        if message.author.bot:
            return

        if self._client.user and message.author.id == self._client.user.id:
            return

        guild = message.guild
        if self._allowed_guilds:
            if guild is None or guild.id not in self._allowed_guilds:
                return

        text = (message.content or "").strip()
        if not text:
            return

        if self._discord_module is None:
            return

        if isinstance(message.channel, self._discord_module.Thread):
            chat_id = str(message.channel.parent_id or message.channel.id)
            thread_id = str(message.channel.id)
        else:
            thread = await self._create_thread(message)
            if thread is None:
                return
            chat_id = str(message.channel.id)
            thread_id = str(thread.id)

        msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
        inbound = self._make_inbound(
            chat_id=chat_id,
            user_id=str(message.author.id),
            text=text,
            msg_type=msg_type,
            thread_ts=thread_id,
            metadata={
                "guild_id": str(guild.id) if guild else None,
                "channel_id": str(message.channel.id),
                "message_id": str(message.id),
            },
        )
        inbound.topic_id = thread_id

        main_loop = self._main_loop
        if main_loop is None or not main_loop.is_running():
            logger.warning("[Discord] main loop is not running; dropping inbound message=%s", message.id)
            return

        future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), main_loop)
        future.add_done_callback(self._log_publish_result)

    @staticmethod
    def _log_publish_result(future: Any) -> None:
        """Log the outcome of a ``publish_inbound`` hand-off without raising."""
        # ``exception()`` raises on a cancelled future, so check that first.
        if future.cancelled():
            logger.warning("[Discord] publish_inbound was cancelled")
            return
        exc = future.exception()
        if exc is not None:
            logger.error("[Discord] publish_inbound failed", exc_info=exc)

    def _run_client(self) -> None:
        """Thread target: run the Discord client on a fresh event loop until it exits."""
        client = self._client
        if client is None:
            return

        loop = asyncio.new_event_loop()
        self._discord_loop = loop
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(client.start(self._bot_token))
        except Exception:
            # Errors raised while stopping are expected noise; only report live failures.
            if self._running:
                logger.exception("Discord client error")
        finally:
            try:
                if not client.is_closed():
                    loop.run_until_complete(client.close())
            except Exception:
                logger.exception("Error during Discord shutdown")
            finally:
                # Unpublish before closing so nobody schedules onto a dead loop
                # and then awaits a future that can never complete.
                if self._discord_loop is loop:
                    self._discord_loop = None
                loop.close()

    async def _create_thread(self, message):
        """Create a thread for ``message``; return it, or None when creation fails."""
        try:
            thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
            return await message.create_thread(name=thread_name)
        except Exception:
            logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
            try:
                await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.")
            except Exception:
                # Best-effort notice only; the original failure is already logged above.
                logger.debug("[Discord] could not notify channel about thread failure", exc_info=True)
            return None

    async def _resolve_target(self, msg: OutboundMessage):
        """Return the first resolvable destination, trying ``thread_ts`` before ``chat_id``."""
        if not self._client or not self._discord_loop:
            return None

        target_ids: list[str] = []
        if msg.thread_ts:
            target_ids.append(msg.thread_ts)
        if msg.chat_id and msg.chat_id not in target_ids:
            target_ids.append(msg.chat_id)

        for raw_id in target_ids:
            target = await self._get_channel_or_thread(raw_id)
            if target is not None:
                return target
        return None

    async def _get_channel_or_thread(self, raw_id: str):
        """Look up a channel/thread by its string ID on the client loop; None on any failure."""
        loop = self._discord_loop
        if not self._client or loop is None or loop.is_closed():
            return None

        try:
            target_id = int(raw_id)
        except (TypeError, ValueError):
            return None

        try:
            get_future = asyncio.run_coroutine_threadsafe(self._fetch_channel(target_id), loop)
            return await asyncio.wrap_future(get_future)
        except Exception:
            logger.exception("[Discord] failed to resolve target id=%s", raw_id)
            return None

    async def _fetch_channel(self, target_id: int):
        """Return the channel from the client's cache, falling back to an API fetch."""
        if not self._client:
            return None

        channel = self._client.get_channel(target_id)
        if channel is not None:
            return channel

        try:
            return await self._client.fetch_channel(target_id)
        except Exception:
            logger.debug("[Discord] fetch_channel failed for id=%s", target_id, exc_info=True)
            return None

    @staticmethod
    def _split_text(text: str) -> list[str]:
        """Split ``text`` into chunks no longer than the Discord message limit.

        Prefers to break at the last newline inside the limit and falls back
        to a hard cut. Returns ``[""]`` for empty input.
        """
        if not text:
            return [""]

        chunks: list[str] = []
        remaining = text
        while len(remaining) > _DISCORD_MAX_MESSAGE_LEN:
            split_at = remaining.rfind("\n", 0, _DISCORD_MAX_MESSAGE_LEN)
            if split_at <= 0:
                # No usable newline (or only one at index 0): hard cut at the limit.
                split_at = _DISCORD_MAX_MESSAGE_LEN
            chunks.append(remaining[:split_at])
            remaining = remaining[split_at:].lstrip("\n")

        if remaining:
            chunks.append(remaining)

        return chunks
|
||||
@@ -35,6 +35,7 @@ STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
|
||||
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
|
||||
|
||||
CHANNEL_CAPABILITIES = {
|
||||
"discord": {"supports_streaming": False},
|
||||
"feishu": {"supports_streaming": True},
|
||||
"slack": {"supports_streaming": False},
|
||||
"telegram": {"supports_streaming": False},
|
||||
|
||||
@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Channel name → import path for lazy loading
|
||||
_CHANNEL_REGISTRY: dict[str, str] = {
|
||||
"discord": "app.channels.discord:DiscordChannel",
|
||||
"feishu": "app.channels.feishu:FeishuChannel",
|
||||
"slack": "app.channels.slack:SlackChannel",
|
||||
"telegram": "app.channels.telegram:TelegramChannel",
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
- [x] Add Plan Mode with TodoList middleware
|
||||
- [x] Add vision model support with ViewImageMiddleware
|
||||
- [x] Skills system with SKILL.md format
|
||||
- [x] Replace `time.sleep(5)` with `asyncio.sleep()` in `packages/harness/deerflow/tools/builtins/task_tool.py` (subagent polling)
|
||||
|
||||
## Planned Features
|
||||
|
||||
@@ -21,8 +22,7 @@
|
||||
- [ ] Support for more document formats in upload
|
||||
- [ ] Skill marketplace / remote skill installation
|
||||
- [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario)
|
||||
- Replace `time.sleep(5)` with `asyncio.sleep()` in `packages/harness/deerflow/tools/builtins/task_tool.py` (subagent polling)
|
||||
- Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
|
||||
- [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
|
||||
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
|
||||
- Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
|
||||
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
|
||||
|
||||
@@ -31,6 +31,8 @@ _DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
|
||||
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
|
||||
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
|
||||
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
|
||||
_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type
|
||||
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type
|
||||
|
||||
|
||||
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
|
||||
@@ -125,8 +127,14 @@ def _hash_tool_calls(tool_calls: list[dict]) -> str:
|
||||
|
||||
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
|
||||
|
||||
_TOOL_FREQ_WARNING_MSG = (
|
||||
"[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
|
||||
)
|
||||
|
||||
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
|
||||
|
||||
_TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far."
|
||||
|
||||
|
||||
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Detects and breaks repetitive tool call loops.
|
||||
@@ -140,6 +148,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
Default: 20.
|
||||
max_tracked_threads: Maximum number of threads to track before
|
||||
evicting the least recently used. Default: 100.
|
||||
tool_freq_warn: Number of calls to the same tool *type* (regardless
|
||||
of arguments) before injecting a frequency warning. Catches
|
||||
cross-file read loops that hash-based detection misses.
|
||||
Default: 30.
|
||||
tool_freq_hard_limit: Number of calls to the same tool type before
|
||||
forcing a stop. Default: 50.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -148,16 +162,23 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
hard_limit: int = _DEFAULT_HARD_LIMIT,
|
||||
window_size: int = _DEFAULT_WINDOW_SIZE,
|
||||
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
|
||||
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
|
||||
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
|
||||
):
|
||||
super().__init__()
|
||||
self.warn_threshold = warn_threshold
|
||||
self.hard_limit = hard_limit
|
||||
self.window_size = window_size
|
||||
self.max_tracked_threads = max_tracked_threads
|
||||
self.tool_freq_warn = tool_freq_warn
|
||||
self.tool_freq_hard_limit = tool_freq_hard_limit
|
||||
self._lock = threading.Lock()
|
||||
# Per-thread tracking using OrderedDict for LRU eviction
|
||||
self._history: OrderedDict[str, list[str]] = OrderedDict()
|
||||
self._warned: dict[str, set[str]] = defaultdict(set)
|
||||
# Per-thread, per-tool-type cumulative call counts
|
||||
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
def _get_thread_id(self, runtime: Runtime) -> str:
|
||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||
@@ -174,11 +195,19 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
while len(self._history) > self.max_tracked_threads:
|
||||
evicted_id, _ = self._history.popitem(last=False)
|
||||
self._warned.pop(evicted_id, None)
|
||||
self._tool_freq.pop(evicted_id, None)
|
||||
self._tool_freq_warned.pop(evicted_id, None)
|
||||
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
|
||||
|
||||
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
|
||||
"""Track tool calls and check for loops.
|
||||
|
||||
Two detection layers:
|
||||
1. **Hash-based** (existing): catches identical tool call sets.
|
||||
2. **Frequency-based** (new): catches the same *tool type* being
|
||||
called many times with varying arguments (e.g. ``read_file``
|
||||
on 40 different files).
|
||||
|
||||
Returns:
|
||||
(warning_message_or_none, should_hard_stop)
|
||||
"""
|
||||
@@ -213,6 +242,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
count = history.count(call_hash)
|
||||
tool_names = [tc.get("name", "?") for tc in tool_calls]
|
||||
|
||||
# --- Layer 1: hash-based (identical call sets) ---
|
||||
if count >= self.hard_limit:
|
||||
logger.error(
|
||||
"Loop hard limit reached — forcing stop",
|
||||
@@ -239,8 +269,40 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
},
|
||||
)
|
||||
return _WARNING_MSG, False
|
||||
# Warning already injected for this hash — suppress
|
||||
return None, False
|
||||
|
||||
# --- Layer 2: per-tool-type frequency ---
|
||||
freq = self._tool_freq[thread_id]
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
if not name:
|
||||
continue
|
||||
freq[name] += 1
|
||||
tc_count = freq[name]
|
||||
|
||||
if tc_count >= self.tool_freq_hard_limit:
|
||||
logger.error(
|
||||
"Tool frequency hard limit reached — forcing stop",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"tool_name": name,
|
||||
"count": tc_count,
|
||||
},
|
||||
)
|
||||
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
|
||||
|
||||
if tc_count >= self.tool_freq_warn:
|
||||
warned = self._tool_freq_warned[thread_id]
|
||||
if name not in warned:
|
||||
warned.add(name)
|
||||
logger.warning(
|
||||
"Tool frequency warning — too many calls to same tool type",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"tool_name": name,
|
||||
"count": tc_count,
|
||||
},
|
||||
)
|
||||
return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False
|
||||
|
||||
return None, False
|
||||
|
||||
@@ -271,7 +333,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
stripped_msg = last_msg.model_copy(
|
||||
update={
|
||||
"tool_calls": [],
|
||||
"content": self._append_text(last_msg.content, _HARD_STOP_MSG),
|
||||
"content": self._append_text(last_msg.content, warning),
|
||||
}
|
||||
)
|
||||
return {"messages": [stripped_msg]}
|
||||
@@ -301,6 +363,10 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
if thread_id:
|
||||
self._history.pop(thread_id, None)
|
||||
self._warned.pop(thread_id, None)
|
||||
self._tool_freq.pop(thread_id, None)
|
||||
self._tool_freq_warned.pop(thread_id, None)
|
||||
else:
|
||||
self._history.clear()
|
||||
self._warned.clear()
|
||||
self._tool_freq.clear()
|
||||
self._tool_freq_warned.clear()
|
||||
|
||||
@@ -262,21 +262,25 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
files_message = self._create_files_message(new_files, historical_files)
|
||||
|
||||
# Extract original content - handle both string and list formats
|
||||
original_content = ""
|
||||
if isinstance(last_message.content, str):
|
||||
original_content = last_message.content
|
||||
elif isinstance(last_message.content, list):
|
||||
text_parts = []
|
||||
for block in last_message.content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
original_content = "\n".join(text_parts)
|
||||
original_content = last_message.content
|
||||
if isinstance(original_content, str):
|
||||
# Simple case: string content, just prepend files message
|
||||
updated_content = f"{files_message}\n\n{original_content}"
|
||||
elif isinstance(original_content, list):
|
||||
# Complex case: list content (multimodal), preserve all blocks
|
||||
# Prepend files message as the first text block
|
||||
files_block = {"type": "text", "text": f"{files_message}\n\n"}
|
||||
# Keep all original blocks (including images)
|
||||
updated_content = [files_block, *original_content]
|
||||
else:
|
||||
# Other types, preserve as-is
|
||||
updated_content = original_content
|
||||
|
||||
# Create new message with combined content.
|
||||
# Preserve additional_kwargs (including files metadata) so the frontend
|
||||
# can read structured file info from the streamed message.
|
||||
updated_message = HumanMessage(
|
||||
content=f"{files_message}\n\n{original_content}",
|
||||
content=updated_content,
|
||||
id=last_message.id,
|
||||
additional_kwargs=last_message.additional_kwargs,
|
||||
)
|
||||
|
||||
@@ -20,6 +20,11 @@ class SubagentOverrideConfig(BaseModel):
|
||||
ge=1,
|
||||
description="Maximum turns for this subagent (None = use global or builtin default)",
|
||||
)
|
||||
model: str | None = Field(
|
||||
default=None,
|
||||
min_length=1,
|
||||
description="Model name for this subagent (None = inherit from parent agent)",
|
||||
)
|
||||
|
||||
|
||||
class SubagentsAppConfig(BaseModel):
|
||||
@@ -54,6 +59,20 @@ class SubagentsAppConfig(BaseModel):
|
||||
return override.timeout_seconds
|
||||
return self.timeout_seconds
|
||||
|
||||
def get_model_for(self, agent_name: str) -> str | None:
    """Look up the configured model override for one subagent.

    Args:
        agent_name: The name of the subagent.

    Returns:
        The overriding model name, or None when the agent has no override
        entry or its entry leaves ``model`` unset (the subagent then inherits
        the parent model).
    """
    agent_override = self.agents.get(agent_name)
    if agent_override is None:
        return None
    # ``model`` is itself None when the entry does not override it.
    return agent_override.model
|
||||
|
||||
def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
|
||||
"""Get the effective max_turns for a specific agent."""
|
||||
override = self.agents.get(agent_name)
|
||||
@@ -84,6 +103,8 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
|
||||
parts.append(f"timeout={override.timeout_seconds}s")
|
||||
if override.max_turns is not None:
|
||||
parts.append(f"max_turns={override.max_turns}")
|
||||
if override.model is not None:
|
||||
parts.append(f"model={override.model}")
|
||||
if parts:
|
||||
overrides_summary[name] = ", ".join(parts)
|
||||
|
||||
|
||||
@@ -62,6 +62,9 @@ class LocalSandbox(Sandbox):
|
||||
"""
|
||||
super().__init__(id)
|
||||
self.path_mappings = path_mappings or []
|
||||
# Track files written through write_file so read_file only
|
||||
# reverse-resolves paths in agent-authored content.
|
||||
self._agent_written_paths: set[str] = set()
|
||||
|
||||
def _is_read_only_path(self, resolved_path: str) -> bool:
|
||||
"""Check if a resolved path is under a read-only mount.
|
||||
@@ -205,6 +208,39 @@ class LocalSandbox(Sandbox):
|
||||
|
||||
return pattern.sub(replace_match, command)
|
||||
|
||||
def _resolve_paths_in_content(self, content: str) -> str:
|
||||
"""Resolve container paths to local paths in arbitrary file content.
|
||||
|
||||
Unlike ``_resolve_paths_in_command`` which uses shell-aware boundary
|
||||
characters, this method treats the content as plain text and resolves
|
||||
every occurrence of a container path prefix. Resolved paths are
|
||||
normalized to forward slashes to avoid backslash-escape issues on
|
||||
Windows hosts (e.g. ``C:\\Users\\..`` breaking Python string literals).
|
||||
|
||||
Args:
|
||||
content: File content that may contain container paths.
|
||||
|
||||
Returns:
|
||||
Content with container paths resolved to local paths (forward slashes).
|
||||
"""
|
||||
import re
|
||||
|
||||
sorted_mappings = sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True)
|
||||
if not sorted_mappings:
|
||||
return content
|
||||
|
||||
patterns = [re.escape(m.container_path) + r"(?=/|$|[^\w./-])(?:/[^\s\"';&|<>()]*)?" for m in sorted_mappings]
|
||||
pattern = re.compile("|".join(f"({p})" for p in patterns))
|
||||
|
||||
def replace_match(match: re.Match) -> str:
|
||||
matched_path = match.group(0)
|
||||
resolved = self._resolve_path(matched_path)
|
||||
# Normalize to forward slashes so that Windows backslash paths
|
||||
# don't create invalid escape sequences in source files.
|
||||
return resolved.replace("\\", "/")
|
||||
|
||||
return pattern.sub(replace_match, content)
|
||||
|
||||
@staticmethod
|
||||
def _get_shell() -> str:
|
||||
"""Detect available shell executable with fallback."""
|
||||
@@ -280,7 +316,14 @@ class LocalSandbox(Sandbox):
|
||||
resolved_path = self._resolve_path(path)
|
||||
try:
|
||||
with open(resolved_path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
content = f.read()
|
||||
# Only reverse-resolve paths in files that were previously written
|
||||
# by write_file (agent-authored content). User-uploaded files,
|
||||
# external tool output, and other non-agent content should not be
|
||||
# silently rewritten — see discussion on PR #1935.
|
||||
if resolved_path in self._agent_written_paths:
|
||||
content = self._reverse_resolve_paths_in_output(content)
|
||||
return content
|
||||
except OSError as e:
|
||||
# Re-raise with the original path for clearer error messages, hiding internal resolved paths
|
||||
raise type(e)(e.errno, e.strerror, path) from None
|
||||
@@ -293,9 +336,16 @@ class LocalSandbox(Sandbox):
|
||||
dir_path = os.path.dirname(resolved_path)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
# Resolve container paths in content to local paths
|
||||
# using the content-specific resolver (forward-slash safe)
|
||||
resolved_content = self._resolve_paths_in_content(content)
|
||||
mode = "a" if append else "w"
|
||||
with open(resolved_path, mode, encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
f.write(resolved_content)
|
||||
# Track this path so read_file knows to reverse-resolve on read.
|
||||
# Only agent-written files get reverse-resolved; user uploads and
|
||||
# external tool output are left untouched.
|
||||
self._agent_written_paths.add(resolved_path)
|
||||
except OSError as e:
|
||||
# Re-raise with the original path for clearer error messages, hiding internal resolved paths
|
||||
raise type(e)(e.errno, e.strerror, path) from None
|
||||
|
||||
@@ -39,7 +39,7 @@ def is_host_bash_allowed(config=None) -> bool:
|
||||
|
||||
sandbox_cfg = getattr(config, "sandbox", None)
|
||||
if sandbox_cfg is None:
|
||||
return True
|
||||
return False
|
||||
if not uses_local_sandbox_provider(config):
|
||||
return True
|
||||
return bool(getattr(sandbox_cfg, "allow_host_bash", False))
|
||||
|
||||
@@ -23,7 +23,8 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
|
||||
if config is None:
|
||||
return None
|
||||
|
||||
# Apply timeout override from config.yaml (lazy import to avoid circular deps)
|
||||
# Apply runtime overrides (timeout, max_turns, model) from config.yaml
|
||||
# Lazy import to avoid circular deps.
|
||||
from deerflow.config.subagents_config import get_subagents_app_config
|
||||
|
||||
app_config = get_subagents_app_config()
|
||||
@@ -47,6 +48,15 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
|
||||
effective_max_turns,
|
||||
)
|
||||
overrides["max_turns"] = effective_max_turns
|
||||
effective_model = app_config.get_model_for(name)
|
||||
if effective_model is not None and effective_model != config.model:
|
||||
logger.debug(
|
||||
"Subagent '%s': model overridden by config.yaml (%s -> %s)",
|
||||
name,
|
||||
config.model,
|
||||
effective_model,
|
||||
)
|
||||
overrides["model"] = effective_model
|
||||
if overrides:
|
||||
config = replace(config, **overrides)
|
||||
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Tests for Discord channel integration wiring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.channels.discord import DiscordChannel
|
||||
from app.channels.manager import CHANNEL_CAPABILITIES
|
||||
from app.channels.message_bus import MessageBus
|
||||
from app.channels.service import _CHANNEL_REGISTRY
|
||||
|
||||
|
||||
def test_discord_channel_registered() -> None:
    """The lazy-loading channel registry has an entry for ``discord``."""
    registered_names = list(_CHANNEL_REGISTRY)
    assert "discord" in registered_names
|
||||
|
||||
|
||||
def test_discord_channel_capabilities() -> None:
    """The capability table has an entry for ``discord``."""
    channels_with_capabilities = list(CHANNEL_CAPABILITIES)
    assert "discord" in channels_with_capabilities
|
||||
|
||||
|
||||
def test_discord_channel_init() -> None:
    """A channel built from a minimal config reports the ``discord`` name."""
    discord_channel = DiscordChannel(bus=MessageBus(), config={"bot_token": "token"})

    assert discord_channel.name == "discord"
|
||||
@@ -1,5 +1,6 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from deerflow.sandbox.security import is_host_bash_allowed
|
||||
from deerflow.tools.tools import get_available_tools
|
||||
|
||||
|
||||
@@ -79,3 +80,8 @@ def test_get_available_tools_keeps_bash_for_aio_sandbox(monkeypatch):
|
||||
|
||||
assert "bash" in names
|
||||
assert "ls" in names
|
||||
|
||||
|
||||
def test_is_host_bash_allowed_defaults_false_when_sandbox_missing():
    """Host bash is denied when the config has no ``sandbox`` section or it is None."""
    config_without_sandbox = SimpleNamespace()
    config_with_null_sandbox = SimpleNamespace(sandbox=None)

    assert is_host_bash_allowed(config_without_sandbox) is False
    assert is_host_bash_allowed(config_with_null_sandbox) is False
|
||||
|
||||
@@ -363,6 +363,98 @@ class TestLocalSandboxProviderMounts:
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
||||
|
||||
def test_write_file_resolves_container_paths_in_content(self, tmp_path):
    """write_file should replace container paths in file content with local paths."""
    data_dir = tmp_path / "data"
    data_dir.mkdir()
    mapping = PathMapping(container_path="/mnt/data", local_path=str(data_dir))
    sandbox = LocalSandbox("test", [mapping])

    script_source = 'import pathlib\npath = "/mnt/data/output"\nprint(path)'
    sandbox.write_file("/mnt/data/script.py", script_source)

    written = (data_dir / "script.py").read_text()
    # The local directory must appear with forward slashes, and the container path must be gone.
    local_dir_posix = str(data_dir).replace("\\", "/")
    assert local_dir_posix in written
    assert "/mnt/data/output" not in written
|
||||
|
||||
def test_write_file_uses_forward_slashes_on_windows_paths(self, tmp_path):
    """Resolved paths in content should always use forward slashes."""
    data_dir = tmp_path / "data"
    data_dir.mkdir()
    mapping = PathMapping(container_path="/mnt/data", local_path=str(data_dir))
    sandbox = LocalSandbox("test", [mapping])

    sandbox.write_file("/mnt/data/config.py", 'DATA_DIR = "/mnt/data/files"')

    written = (data_dir / "config.py").read_text()
    # Inspect only the assigned value on the first line after ``DATA_DIR = ``.
    after_assignment = written.split("DATA_DIR = ")[1]
    assigned_value = after_assignment.split("\n")[0]
    # Must not contain backslashes that could break escape sequences.
    assert "\\" not in assigned_value
|
||||
|
||||
def test_read_file_reverse_resolves_local_paths_in_agent_written_files(self, tmp_path):
    """read_file should convert local paths back to container paths in agent-written files."""
    data_dir = tmp_path / "data"
    data_dir.mkdir()
    mapping = PathMapping(container_path="/mnt/data", local_path=str(data_dir))
    sandbox = LocalSandbox("test", [mapping])

    # Going through write_file marks the path as agent-written.
    container_file = "/mnt/data/info.txt"
    sandbox.write_file(container_file, "File located at: /mnt/data/info.txt")

    content = sandbox.read_file(container_file)
    assert "/mnt/data/info.txt" in content
|
||||
|
||||
def test_read_file_does_not_reverse_resolve_non_agent_files(self, tmp_path):
    """Files not created through write_file must be returned without path rewriting."""
    data_dir = tmp_path / "data"
    data_dir.mkdir()
    mapping = PathMapping(container_path="/mnt/data", local_path=str(data_dir))
    sandbox = LocalSandbox("test", [mapping])

    # Bypass the sandbox and write straight to disk, standing in for a user
    # upload or the output of an external tool.
    local_path = str(data_dir).replace("\\", "/")
    external_file = data_dir / "config.yml"
    external_file.write_text(f"output_dir: {local_path}/outputs")

    content = sandbox.read_file("/mnt/data/config.yml")

    # The local path must still be present, i.e. it was not mapped back.
    assert local_path in content
|
||||
|
||||
def test_write_then_read_roundtrip(self, tmp_path):
    """A container path written with write_file is still present after read_file."""
    data_dir = tmp_path / "data"
    data_dir.mkdir()
    mapping = PathMapping(container_path="/mnt/data", local_path=str(data_dir))
    sandbox = LocalSandbox("test", [mapping])

    target = "/mnt/data/settings.py"
    original = 'cfg = {"path": "/mnt/data/config.json", "flag": true}'
    sandbox.write_file(target, original)

    # What comes back must still mention the container path.
    result = sandbox.read_file(target)
    assert "/mnt/data/config.json" in result
|
||||
|
||||
def test_setup_path_mappings_normalizes_container_path_trailing_slash(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
|
||||
@@ -280,6 +280,8 @@ class TestLoopDetection:
|
||||
mw._apply(_make_state(tool_calls=call), runtime_new)
|
||||
|
||||
assert "thread-0" not in mw._history
|
||||
assert "thread-0" not in mw._tool_freq
|
||||
assert "thread-0" not in mw._tool_freq_warned
|
||||
assert "thread-new" in mw._history
|
||||
assert len(mw._history) == 3
|
||||
|
||||
@@ -410,3 +412,188 @@ class TestHardStopWithListContent:
|
||||
assert isinstance(msg.content, str)
|
||||
assert msg.content.startswith("thinking...")
|
||||
assert _HARD_STOP_MSG in msg.content
|
||||
|
||||
|
||||
class TestToolFrequencyDetection:
    """Exercise the per-tool-type frequency layer (Layer 2) of loop detection.

    These scenarios repeat one tool type with *different* arguments
    (e.g. read_file on many distinct files), a pattern that bypasses
    hash-based detection.
    """

    def _read_call(self, path):
        """Build a ``read_file`` tool-call dict targeting ``path``."""
        return {"name": "read_file", "id": f"call_read_{path}", "args": {"path": path}}

    def _apply_each(self, mw, runtime, paths):
        """Apply one single-call read_file state per path; return the last result."""
        outcome = None
        for path in paths:
            outcome = mw._apply(_make_state(tool_calls=[self._read_call(path)]), runtime)
        return outcome

    def test_below_freq_warn_returns_none(self):
        """Four calls with tool_freq_warn=5 never produce an injected message."""
        mw = LoopDetectionMiddleware(tool_freq_warn=5, tool_freq_hard_limit=10)
        runtime = _make_runtime()

        for idx in range(4):
            state = _make_state(tool_calls=[self._read_call(f"/file_{idx}.py")])
            assert mw._apply(state, runtime) is None

    def test_freq_warn_at_threshold(self):
        """The call that reaches tool_freq_warn yields a HumanMessage warning."""
        mw = LoopDetectionMiddleware(tool_freq_warn=5, tool_freq_hard_limit=10)
        runtime = _make_runtime()

        self._apply_each(mw, runtime, [f"/file_{idx}.py" for idx in range(4)])

        # Fifth read_file call (a new file each time) reaches the warn threshold.
        result = self._apply_each(mw, runtime, ["/file_4.py"])
        assert result is not None
        warning = result["messages"][0]
        assert isinstance(warning, HumanMessage)
        assert "read_file" in warning.content
        assert "LOOP DETECTED" in warning.content

    def test_freq_warn_only_injected_once(self):
        """After the warning fires for a tool, the next call does not warn again."""
        mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
        runtime = _make_runtime()

        self._apply_each(mw, runtime, [f"/file_{idx}.py" for idx in range(2)])

        # Third call reaches the threshold and warns.
        third = self._apply_each(mw, runtime, ["/file_2.py"])
        assert third is not None
        assert "LOOP DETECTED" in third["messages"][0].content

        # Fourth call: read_file was already warned about, so nothing is injected.
        fourth = self._apply_each(mw, runtime, ["/file_3.py"])
        assert fourth is None

    def test_freq_hard_stop_at_limit(self):
        """Reaching tool_freq_hard_limit returns an AIMessage with no tool calls."""
        mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=6)
        runtime = _make_runtime()

        self._apply_each(mw, runtime, [f"/file_{idx}.py" for idx in range(5)])

        # Sixth call reaches the hard limit.
        result = self._apply_each(mw, runtime, ["/file_5.py"])
        assert result is not None
        stop_msg = result["messages"][0]
        assert isinstance(stop_msg, AIMessage)
        assert stop_msg.tool_calls == []
        assert "FORCED STOP" in stop_msg.content
        assert "read_file" in stop_msg.content

    def test_different_tools_tracked_independently(self):
        """read_file and bash should have independent frequency counters."""
        mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
        runtime = _make_runtime()

        # Two read_file calls first.
        self._apply_each(mw, runtime, [f"/file_{idx}.py" for idx in range(2)])

        # Two bash calls: each tool is now at 2, so neither reaches the threshold.
        for idx in range(2):
            bash_state = _make_state(tool_calls=[_bash_call(f"cmd_{idx}")])
            assert mw._apply(bash_state, runtime) is None

        # The third read_file call brings read_file alone to 3.
        result = self._apply_each(mw, runtime, ["/file_2.py"])
        assert result is not None
        assert "read_file" in result["messages"][0].content

    def test_freq_reset_clears_state(self):
        """A full reset() restarts the frequency count."""
        mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
        runtime = _make_runtime()

        self._apply_each(mw, runtime, [f"/file_{idx}.py" for idx in range(2)])

        mw.reset()

        # Counting starts over, so this call stays under the threshold.
        after_reset = self._apply_each(mw, runtime, ["/file_new.py"])
        assert after_reset is None

    def test_freq_reset_per_thread_clears_only_target(self):
        """reset(thread_id=...) should clear frequency state for that thread only."""
        mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
        runtime_a = _make_runtime("thread-A")
        runtime_b = _make_runtime("thread-B")

        # Two calls per thread, alternating A then B.
        for idx in range(2):
            mw._apply(_make_state(tool_calls=[self._read_call(f"/a_{idx}.py")]), runtime_a)
            mw._apply(_make_state(tool_calls=[self._read_call(f"/b_{idx}.py")]), runtime_b)

        # Drop the state of thread-A only.
        mw.reset(thread_id="thread-A")

        for store in (mw._tool_freq, mw._tool_freq_warned):
            assert "thread-A" not in store

        # thread-B kept its count, so its third call warns.
        result_b = self._apply_each(mw, runtime_b, ["/b_2.py"])
        assert result_b is not None
        assert "LOOP DETECTED" in result_b["messages"][0].content

        # thread-A starts again from zero and stays silent.
        result_a = self._apply_each(mw, runtime_a, ["/a_new.py"])
        assert result_a is None

    def test_freq_per_thread_isolation(self):
        """Frequency counts should be independent per thread."""
        mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
        runtime_a = _make_runtime("thread-A")
        runtime_b = _make_runtime("thread-B")

        # Two calls on thread A.
        self._apply_each(mw, runtime_a, [f"/file_{idx}.py" for idx in range(2)])

        # Two calls on thread B; these must not count towards thread A.
        self._apply_each(mw, runtime_b, [f"/other_{idx}.py" for idx in range(2)])

        # Third call on thread A reaches the threshold for thread A alone.
        result = self._apply_each(mw, runtime_a, ["/file_2.py"])
        assert result is not None
        assert "LOOP DETECTED" in result["messages"][0].content

    def test_multi_tool_single_response_counted(self):
        """When a single response has multiple tool calls, each is counted."""
        mw = LoopDetectionMiddleware(tool_freq_warn=5, tool_freq_hard_limit=10)
        runtime = _make_runtime()

        # Two responses of two read_file calls each: count goes 2, then 4.
        for first, second in (("/a.py", "/b.py"), ("/c.py", "/d.py")):
            batch = [self._read_call(first), self._read_call(second)]
            assert mw._apply(_make_state(tool_calls=batch), runtime) is None

        # One more call brings the count to 5 and triggers the warning.
        result = self._apply_each(mw, runtime, ["/e.py"])
        assert result is not None
        assert "read_file" in result["messages"][0].content

    def test_hash_detection_takes_priority(self):
        """Hash-based hard stop fires before frequency check for identical calls."""
        mw = LoopDetectionMiddleware(
            warn_threshold=2,
            hard_limit=3,
            tool_freq_warn=100,
            tool_freq_hard_limit=200,
        )
        runtime = _make_runtime()
        same_call = [self._read_call("/same_file.py")]

        mw._apply(_make_state(tool_calls=same_call), runtime)
        mw._apply(_make_state(tool_calls=same_call), runtime)

        # Third identical call: the hash hard_limit=3 fires; the frequency
        # thresholds (100 / 200) are nowhere near reached.
        result = mw._apply(_make_state(tool_calls=same_call), runtime)
        assert result is not None
        stop_msg = result["messages"][0]
        assert isinstance(stop_msg, AIMessage)
        assert _HARD_STOP_MSG in stop_msg.content
|
||||
|
||||
@@ -50,11 +50,19 @@ class TestSubagentOverrideConfig:
|
||||
override = SubagentOverrideConfig()
|
||||
assert override.timeout_seconds is None
|
||||
assert override.max_turns is None
|
||||
assert override.model is None
|
||||
|
||||
def test_explicit_value(self):
    """Explicit timeout_seconds, max_turns and model are stored unchanged."""
    # Build the override once, with all three fields. The earlier two-field
    # construction was overwritten immediately and never asserted on.
    override = SubagentOverrideConfig(timeout_seconds=300, max_turns=42, model="gpt-5.4")
    assert override.timeout_seconds == 300
    assert override.max_turns == 42
    assert override.model == "gpt-5.4"
|
||||
|
||||
def test_model_accepts_any_non_empty_string(self):
    """An arbitrary non-empty model name is accepted and stored as given.

    Per the original note, checking the name against the `models:` section
    is deferred to registry lookup time.
    """
    model_name = "any-arbitrary-model-name"
    override = SubagentOverrideConfig(model=model_name)
    assert override.model == "any-arbitrary-model-name"
|
||||
|
||||
def test_rejects_zero(self):
|
||||
with pytest.raises(ValueError):
|
||||
@@ -68,6 +76,13 @@ class TestSubagentOverrideConfig:
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(max_turns=-1)
|
||||
|
||||
def test_rejects_empty_model(self):
    """An empty-string model must raise ValueError when the override is built.

    Per the original note: an empty string would pass an `is not None`
    check and only fail later at `create_chat_model(name="")`, so it is
    rejected at load time, mirroring the `ge=1` guard on
    timeout_seconds / max_turns.
    """
    empty_name = ""
    with pytest.raises(ValueError):
        SubagentOverrideConfig(model=empty_name)
|
||||
|
||||
def test_minimum_valid_value(self):
|
||||
override = SubagentOverrideConfig(timeout_seconds=1, max_turns=1)
|
||||
assert override.timeout_seconds == 1
|
||||
@@ -165,6 +180,42 @@ class TestRuntimeResolution:
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 200
|
||||
assert config.get_max_turns_for("bash", 60) == 80
|
||||
|
||||
def test_get_model_for_returns_none_when_no_override(self):
    """With no per-agent overrides, get_model_for returns None for any agent name."""
    config = SubagentsAppConfig(timeout_seconds=900)
    # Callers are expected to fall back to the builtin/parent model on None.
    for agent_name in ("general-purpose", "bash", "unknown-agent"):
        assert config.get_model_for(agent_name) is None
|
||||
|
||||
def test_get_model_for_returns_override_when_set(self):
    """Each agent with a model override gets its own configured model name back."""
    overrides = {
        "general-purpose": SubagentOverrideConfig(model="qwen3.5-35b-a3b"),
        "bash": SubagentOverrideConfig(model="gpt-5.4"),
    }
    config = SubagentsAppConfig(timeout_seconds=900, agents=overrides)

    assert config.get_model_for("general-purpose") == "qwen3.5-35b-a3b"
    assert config.get_model_for("bash") == "gpt-5.4"
|
||||
|
||||
def test_get_model_for_returns_none_for_omitted_agent(self):
    """An agent missing from `agents` yields None even if another agent has a model."""
    bash_only = {"bash": SubagentOverrideConfig(model="gpt-5.4")}
    config = SubagentsAppConfig(timeout_seconds=900, agents=bash_only)

    resolved = config.get_model_for("general-purpose")
    assert resolved is None
|
||||
|
||||
def test_get_model_for_handles_explicit_none(self):
    """Passing model=None explicitly behaves like leaving the model out."""
    bash_override = SubagentOverrideConfig(timeout_seconds=300, model=None)
    config = SubagentsAppConfig(timeout_seconds=900, agents={"bash": bash_override})

    assert config.get_model_for("bash") is None
    # The timeout from the same override entry must still take effect.
    assert config.get_timeout_for("bash") == 300
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_subagents_config_from_dict / get_subagents_app_config singleton
|
||||
@@ -211,6 +262,22 @@ class TestLoadSubagentsConfig:
|
||||
assert cfg.get_max_turns_for("general-purpose", 100) == 100
|
||||
assert cfg.get_max_turns_for("bash", 60) == 70
|
||||
|
||||
def test_load_with_model_overrides(self):
    """Model overrides given as a plain dict are visible on the loaded config."""
    raw_config = {
        "timeout_seconds": 900,
        "agents": {
            "general-purpose": {"model": "qwen3.5-35b-a3b"},
            "bash": {"model": "gpt-5.4", "timeout_seconds": 300},
        },
    }
    load_subagents_config_from_dict(raw_config)

    cfg = get_subagents_app_config()
    assert cfg.get_model_for("general-purpose") == "qwen3.5-35b-a3b"
    assert cfg.get_model_for("bash") == "gpt-5.4"
    # A sibling field on the same agent entry has to load as well.
    assert cfg.get_timeout_for("bash") == 300
|
||||
|
||||
def test_load_empty_dict_uses_defaults(self):
|
||||
load_subagents_config_from_dict({})
|
||||
cfg = get_subagents_app_config()
|
||||
@@ -296,6 +363,97 @@ class TestRegistryGetSubagentConfig:
|
||||
assert gp_config.timeout_seconds == 900
|
||||
assert gp_config.max_turns == 120
|
||||
|
||||
def test_per_agent_model_override_applied(self):
    """A per-agent `model` entry shows up on the config returned by the registry."""
    from deerflow.subagents.registry import get_subagent_config

    raw_config = {
        "timeout_seconds": 900,
        "agents": {"bash": {"model": "gpt-5.4-mini"}},
    }
    load_subagents_config_from_dict(raw_config)

    resolved = get_subagent_config("bash")
    assert resolved.model == "gpt-5.4-mini"
|
||||
|
||||
def test_omitted_model_keeps_builtin_value(self):
    """If an agent's entry has no `model` key, the builtin model is kept."""
    from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
    from deerflow.subagents.registry import get_subagent_config

    # Capture the builtin model before any config is loaded.
    builtin_bash_model = BUILTIN_SUBAGENTS["bash"].model

    raw_config = {
        "timeout_seconds": 900,
        "agents": {"bash": {"timeout_seconds": 300}},
    }
    load_subagents_config_from_dict(raw_config)

    resolved = get_subagent_config("bash")
    assert resolved.model == builtin_bash_model
|
||||
|
||||
def test_explicit_null_model_keeps_builtin_value(self):
    """`model: None` in the loaded dict is treated like an absent key: builtin wins."""
    from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
    from deerflow.subagents.registry import get_subagent_config

    # Capture the builtin model before any config is loaded.
    builtin_bash_model = BUILTIN_SUBAGENTS["bash"].model

    raw_config = {
        "timeout_seconds": 900,
        "agents": {"bash": {"model": None}},
    }
    load_subagents_config_from_dict(raw_config)

    resolved = get_subagent_config("bash")
    assert resolved.model == builtin_bash_model
|
||||
|
||||
def test_model_override_does_not_affect_other_agents(self):
    """Overriding the bash model leaves the general-purpose model at its builtin value."""
    from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
    from deerflow.subagents.registry import get_subagent_config

    builtin_gp_model = BUILTIN_SUBAGENTS["general-purpose"].model

    raw_config = {
        "timeout_seconds": 900,
        "agents": {"bash": {"model": "gpt-5.4"}},
    }
    load_subagents_config_from_dict(raw_config)

    resolved = get_subagent_config("general-purpose")
    assert resolved.model == builtin_gp_model
|
||||
|
||||
def test_model_override_preserves_other_fields(self):
    """A model-only override changes `model` and nothing else on the agent config."""
    from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
    from deerflow.subagents.registry import get_subagent_config

    original = BUILTIN_SUBAGENTS["bash"]

    raw_config = {
        "timeout_seconds": 900,
        "agents": {"bash": {"model": "gpt-5.4-mini"}},
    }
    load_subagents_config_from_dict(raw_config)

    overridden = get_subagent_config("bash")
    assert overridden.model == "gpt-5.4-mini"
    assert overridden.name == original.name
    assert overridden.description == original.description
    # With no per-agent timeout / max_turns, the global timeout (900) and the
    # builtin max_turns apply.
    assert overridden.timeout_seconds == 900
    assert overridden.max_turns == original.max_turns
|
||||
|
||||
def test_model_override_does_not_mutate_builtin(self):
    """Resolving an overridden config must leave the BUILTIN_SUBAGENTS entry unchanged."""
    from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
    from deerflow.subagents.registry import get_subagent_config

    original_bash_model = BUILTIN_SUBAGENTS["bash"].model

    raw_config = {
        "timeout_seconds": 900,
        "agents": {"bash": {"model": "gpt-5.4-mini"}},
    }
    load_subagents_config_from_dict(raw_config)

    # Only the lookup's effect on the builtin matters here; the result is unused.
    get_subagent_config("bash")
    assert BUILTIN_SUBAGENTS["bash"].model == original_bash_model
|
||||
|
||||
def test_builtin_config_object_is_not_mutated(self):
|
||||
"""Registry must return a new object, leaving the builtin default intact."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
|
||||
@@ -256,8 +256,10 @@ class TestBeforeAgent:
|
||||
|
||||
assert result is not None
|
||||
updated_msg = result["messages"][-1]
|
||||
assert "<uploaded_files>" in updated_msg.content
|
||||
assert "analyse this" in updated_msg.content
|
||||
assert isinstance(updated_msg.content, list)
|
||||
combined_text = "\n".join(block.get("text", "") for block in updated_msg.content if isinstance(block, dict))
|
||||
assert "<uploaded_files>" in combined_text
|
||||
assert "analyse this" in combined_text
|
||||
|
||||
def test_preserves_additional_kwargs_on_updated_message(self, tmp_path):
|
||||
mw = _middleware(tmp_path)
|
||||
|
||||
Generated
+34
-4
@@ -749,6 +749,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
ollama = [
|
||||
{ name = "langchain-ollama" },
|
||||
]
|
||||
pymupdf = [
|
||||
{ name = "pymupdf4llm" },
|
||||
]
|
||||
@@ -769,6 +772,7 @@ requires-dist = [
|
||||
{ name = "langchain-deepseek", specifier = ">=1.0.1" },
|
||||
{ name = "langchain-google-genai", specifier = ">=4.2.1" },
|
||||
{ name = "langchain-mcp-adapters", specifier = ">=0.1.0" },
|
||||
{ name = "langchain-ollama", marker = "extra == 'ollama'", specifier = ">=0.3.0" },
|
||||
{ name = "langchain-openai", specifier = ">=1.1.7" },
|
||||
{ name = "langfuse", specifier = ">=3.4.1" },
|
||||
{ name = "langgraph", specifier = ">=1.0.6,<1.0.10" },
|
||||
@@ -786,7 +790,7 @@ requires-dist = [
|
||||
{ name = "tavily-python", specifier = ">=0.7.17" },
|
||||
{ name = "tiktoken", specifier = ">=0.8.0" },
|
||||
]
|
||||
provides-extras = ["pymupdf"]
|
||||
provides-extras = ["ollama", "pymupdf"]
|
||||
|
||||
[[package]]
|
||||
name = "defusedxml"
|
||||
@@ -1564,7 +1568,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "1.2.17"
|
||||
version = "1.2.28"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
@@ -1576,9 +1580,9 @@ dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "uuid-utils" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1d/93/36226f593df52b871fc24d494c274f3a6b2ac76763a2806e7d35611634a1/langchain_core-1.2.17.tar.gz", hash = "sha256:54aa267f3311e347fb2e50951fe08e53761cebfb999ab80e6748d70525bbe872", size = 836130, upload-time = "2026-03-02T22:47:55.846Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f8/a4/317a1a3ac1df33a64adb3670bf88bbe3b3d5baa274db6863a979db472897/langchain_core-1.2.28.tar.gz", hash = "sha256:271a3d8bd618f795fdeba112b0753980457fc90537c46a0c11998516a74dc2cb", size = 846119, upload-time = "2026-04-08T18:19:34.867Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/be/90/073f33ab383a62908eca7ea699586dfea280e77182176e33199c80ddf22a/langchain_core-1.2.17-py3-none-any.whl", hash = "sha256:bf6bd6ce503874e9c2da1669a69383e967c3de1ea808921d19a9a6bff1a9fbbe", size = 502727, upload-time = "2026-03-02T22:47:54.537Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/92/32f785f077c7e898da97064f113c73fbd9ad55d1e2169cf3a391b183dedb/langchain_core-1.2.28-py3-none-any.whl", hash = "sha256:80764232581eaf8057bcefa71dbf8adc1f6a28d257ebd8b95ba9b8b452e8c6ac", size = 508727, upload-time = "2026-04-08T18:19:32.823Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1623,6 +1627,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/03/81/b2479eb26861ab36be851026d004b2d391d789b7856e44c272b12828ece0/langchain_mcp_adapters-0.2.1-py3-none-any.whl", hash = "sha256:9f96ad4c64230f6757297fec06fde19d772c99dbdfbca987f7b7cfd51ff77240", size = 22708, upload-time = "2025-12-09T16:28:37.877Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-ollama"
|
||||
version = "1.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
{ name = "ollama" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d4/9b/6641afe8a5bf807e454fd464eddfc7eb2f2df53cb0b29744381171f9c609/langchain_ollama-1.1.0.tar.gz", hash = "sha256:f776f56f6782ae4da7692579b94a6575906118318d1023b455d7207f9d059811", size = 133075, upload-time = "2026-04-07T02:48:00.873Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/b2/c2acb076590a98bee2816ed5f285e00df162a34238f9e276e175e14ebc35/langchain_ollama-1.1.0-py3-none-any.whl", hash = "sha256:43ac83a6eacb0f43855810739794dd55019e0d9b17bdcf3ecb3b1991ac3b59dd", size = 31413, upload-time = "2026-04-07T02:47:59.642Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-openai"
|
||||
version = "1.1.7"
|
||||
@@ -2264,6 +2281,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/17/d3/b64c356a907242d719fc668b71befd73324e47ab46c8ebbbede252c154b2/olefile-0.47-py2.py3-none-any.whl", hash = "sha256:543c7da2a7adadf21214938bb79c83ea12b473a4b6ee4ad4bf854e7715e13d1f", size = 114565, upload-time = "2023-12-01T16:22:51.518Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ollama"
|
||||
version = "0.6.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
{ name = "pydantic" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9d/5a/652dac4b7affc2b37b95386f8ae78f22808af09d720689e3d7a86b6ed98e/ollama-0.6.1.tar.gz", hash = "sha256:478c67546836430034b415ed64fa890fd3d1ff91781a9d548b3325274e69d7c6", size = 51620, upload-time = "2025-11-13T23:02:17.416Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/47/4f/4a617ee93d8208d2bcf26b2d8b9402ceaed03e3853c754940e2290fed063/ollama-0.6.1-py3-none-any.whl", hash = "sha256:fc4c984b345735c5486faeee67d8a265214a31cbb828167782dc642ce0a2bf8c", size = 14354, upload-time = "2025-11-13T23:02:16.292Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnxruntime"
|
||||
version = "1.20.1"
|
||||
|
||||
@@ -575,9 +575,14 @@ sandbox:
|
||||
# general-purpose:
|
||||
# timeout_seconds: 1800 # 30 minutes for complex multi-step tasks
|
||||
# max_turns: 160
|
||||
# # model: qwen3:32b # Use a specific model (default: inherit from lead agent)
|
||||
# bash:
|
||||
# timeout_seconds: 300 # 5 minutes for quick command execution
|
||||
# max_turns: 80
|
||||
#
|
||||
# # Model override: by default, subagents inherit the lead agent's model.
|
||||
# # Set `model` to use a different model (e.g., a local Ollama model for cost savings).
|
||||
# # The model name must match a name defined in the `models:` section above.
|
||||
|
||||
# ============================================================================
|
||||
# ACP Agents Configuration
|
||||
|
||||
Generated
+247
-287
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user