Feature: #6351 - Agent selection, tool resolution & framework integration -- first version with mocked MCPRegistry
This commit is contained in:
@@ -70,6 +70,7 @@ exports/*
|
||||
.claude/settings.local.json
|
||||
|
||||
.venv
|
||||
/venv
|
||||
|
||||
docs/github-issues/*
|
||||
core/tests/*dumps/*
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"profile": "all"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
{
|
||||
"servers": [
|
||||
{
|
||||
"name": "hive-tools",
|
||||
"version": "1.0.0",
|
||||
"tags": ["core", "productivity"],
|
||||
"profiles": ["all", "core", "productivity"],
|
||||
"mcp_config": {
|
||||
"transport": "stdio",
|
||||
"command": "uv",
|
||||
"args": ["run", "python", "mcp_server.py", "--stdio"],
|
||||
"cwd": "tools",
|
||||
"description": "Hive tools MCP server providing core utilities"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "coder-tools",
|
||||
"version": "1.0.0",
|
||||
"tags": ["coding", "productivity"],
|
||||
"profiles": ["all", "coding", "productivity"],
|
||||
"mcp_config": {
|
||||
"transport": "stdio",
|
||||
"command": "uv",
|
||||
"args": ["run", "python", "coder_tools_server.py", "--stdio"],
|
||||
"cwd": "tools",
|
||||
"description": "Unsandboxed file/code tools for code generation"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "tools",
|
||||
"version": "1.0.0",
|
||||
"tags": ["web", "productivity"],
|
||||
"profiles": ["all", "general"],
|
||||
"mcp_config": {
|
||||
"transport": "stdio",
|
||||
"command": "uv",
|
||||
"args": ["run", "python", "mcp_server.py", "--stdio"],
|
||||
"cwd": "tools",
|
||||
"description": "Aden tools MCP server providing web/file utilities"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -0,0 +1,252 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CACHE_INDEX_PATH = Path.home() / ".hive" / "mcp_registry" / "cache" / "registry_index.json"
|
||||
_FIXTURE_INDEX_PATH = Path(__file__).resolve().parent / "fixtures" / "registry_index.json"
|
||||
|
||||
|
||||
def resolve_registry_servers(
|
||||
*,
|
||||
include: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
exclude: list[str] | None = None,
|
||||
profile: str | None = None,
|
||||
max_tools: int | None = None,
|
||||
versions: dict[str, str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Resolve registry-sourced MCP servers for `mcp_registry.json` selection.
|
||||
|
||||
This function is written to be mock-friendly during early development:
|
||||
- If the real `MCPRegistry` core module is present, delegate to it.
|
||||
- Otherwise, fall back to a cached local index (`~/.hive/.../registry_index.json`)
|
||||
and then to the repo fixture index.
|
||||
"""
|
||||
|
||||
# `max_tools` is enforced by ToolRegistry. We keep it in the resolver
|
||||
# signature to match the PRD and future MCPRegistry interfaces.
|
||||
_ = max_tools
|
||||
|
||||
try:
|
||||
from framework.runner.mcp_registry import MCPRegistry # type: ignore
|
||||
|
||||
registry = MCPRegistry()
|
||||
resolved = registry.resolve_for_agent(
|
||||
include=include or [],
|
||||
tags=tags or [],
|
||||
exclude=exclude or [],
|
||||
profile=profile,
|
||||
max_tools=max_tools,
|
||||
versions=versions or {},
|
||||
)
|
||||
# Future-proof: normalize both dicts and typed objects to dicts.
|
||||
return [_normalize_server_config(x) for x in resolved]
|
||||
except ImportError:
|
||||
# Expected while #6349/#6574 is not merged locally.
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("MCPRegistry resolution failed; falling back to cache/fixtures: %s", e)
|
||||
|
||||
return _resolve_from_local_index(
|
||||
include=include,
|
||||
tags=tags,
|
||||
exclude=exclude,
|
||||
profile=profile,
|
||||
versions=versions or {},
|
||||
)
|
||||
|
||||
|
||||
def _resolve_from_local_index(
|
||||
*,
|
||||
include: list[str] | None,
|
||||
tags: list[str] | None,
|
||||
exclude: list[str] | None,
|
||||
profile: str | None,
|
||||
versions: dict[str, str],
|
||||
) -> list[dict[str, Any]]:
|
||||
index = _load_index_json()
|
||||
servers = _coerce_index_servers(index)
|
||||
servers_by_name: dict[str, dict[str, Any]] = {
|
||||
s["name"]: s for s in servers if isinstance(s, dict) and "name" in s
|
||||
}
|
||||
|
||||
include_list = include or []
|
||||
tags_list = tags or []
|
||||
exclude_set = set(exclude or [])
|
||||
|
||||
def _profiles_of(entry: dict[str, Any]) -> set[str]:
|
||||
if isinstance(entry.get("profiles"), list):
|
||||
return set(entry["profiles"])
|
||||
hive = entry.get("hive")
|
||||
if isinstance(hive, dict) and isinstance(hive.get("profiles"), list):
|
||||
return set(hive["profiles"])
|
||||
return set()
|
||||
|
||||
def _tags_of(entry: dict[str, Any]) -> set[str]:
|
||||
if isinstance(entry.get("tags"), list):
|
||||
return set(entry["tags"])
|
||||
return set()
|
||||
|
||||
def _entry_version(entry: dict[str, Any]) -> str | None:
|
||||
# Prefer flat `version`, but support a few common shapes.
|
||||
v = entry.get("version")
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
v2 = entry.get("manifest_version")
|
||||
if isinstance(v2, str):
|
||||
return v2
|
||||
hive = entry.get("manifest")
|
||||
if isinstance(hive, dict) and isinstance(hive.get("version"), str):
|
||||
return hive["version"]
|
||||
return None
|
||||
|
||||
def _version_allows(server_name: str) -> bool:
|
||||
if server_name not in versions:
|
||||
return True
|
||||
pinned = versions[server_name]
|
||||
entry = servers_by_name.get(server_name)
|
||||
if not entry:
|
||||
return False
|
||||
return _entry_version(entry) == pinned
|
||||
|
||||
resolved_names: list[str] = []
|
||||
resolved_set: set[str] = set()
|
||||
|
||||
# 1) Include-order first
|
||||
for name in include_list:
|
||||
if name in exclude_set:
|
||||
continue
|
||||
if name in servers_by_name and _version_allows(name) and name not in resolved_set:
|
||||
resolved_names.append(name)
|
||||
resolved_set.add(name)
|
||||
|
||||
# 2) Then tag/profile matches, alphabetical
|
||||
profile_candidates = set()
|
||||
if profile:
|
||||
for name, entry in servers_by_name.items():
|
||||
if name in exclude_set or not _version_allows(name):
|
||||
continue
|
||||
if profile in _profiles_of(entry):
|
||||
profile_candidates.add(name)
|
||||
|
||||
tag_candidates = set()
|
||||
if tags_list:
|
||||
tags_set = set(tags_list)
|
||||
for name, entry in servers_by_name.items():
|
||||
if name in exclude_set or not _version_allows(name):
|
||||
continue
|
||||
if _tags_of(entry).intersection(tags_set):
|
||||
tag_candidates.add(name)
|
||||
|
||||
tag_profile_names = sorted((profile_candidates | tag_candidates) - resolved_set)
|
||||
resolved_names.extend(tag_profile_names)
|
||||
|
||||
# Missing requested servers should warn (FR-54).
|
||||
for name in include_list:
|
||||
if name in exclude_set:
|
||||
continue
|
||||
if name not in resolved_set:
|
||||
if name not in servers_by_name:
|
||||
logger.warning(
|
||||
"Server '%s' requested by mcp_registry.json but not found in index. "
|
||||
"Run: hive mcp install %s",
|
||||
name,
|
||||
name,
|
||||
)
|
||||
elif name in versions:
|
||||
logger.warning(
|
||||
"Server '%s' was requested but pinned version '%s' was not found in index. "
|
||||
"Run: hive mcp update %s or change the pin in mcp_registry.json",
|
||||
name,
|
||||
versions[name],
|
||||
name,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Server '%s' requested by mcp_registry.json was not selected. "
|
||||
"Check selection filters/exclude lists.",
|
||||
name,
|
||||
)
|
||||
|
||||
resolved_configs: list[dict[str, Any]] = []
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
for name in resolved_names:
|
||||
entry = servers_by_name.get(name)
|
||||
if not entry:
|
||||
continue
|
||||
config = entry.get("mcp_config")
|
||||
if not isinstance(config, dict):
|
||||
# Best-effort: allow a direct MCP config shape at top-level.
|
||||
config = {
|
||||
k: v
|
||||
for k, v in entry.items()
|
||||
if k
|
||||
in {
|
||||
"name",
|
||||
"transport",
|
||||
"command",
|
||||
"args",
|
||||
"env",
|
||||
"cwd",
|
||||
"url",
|
||||
"headers",
|
||||
"description",
|
||||
}
|
||||
}
|
||||
mcp_config = dict(config)
|
||||
mcp_config["name"] = name
|
||||
if mcp_config.get("transport") == "stdio":
|
||||
_absolutize_stdio_config_in_place(repo_root, mcp_config)
|
||||
resolved_configs.append(mcp_config)
|
||||
|
||||
return resolved_configs
|
||||
|
||||
|
||||
def _load_index_json() -> Any:
|
||||
if _CACHE_INDEX_PATH.exists():
|
||||
return json.loads(_CACHE_INDEX_PATH.read_text(encoding="utf-8"))
|
||||
if _FIXTURE_INDEX_PATH.exists():
|
||||
logger.info("Using local fixture index because registry cache is missing")
|
||||
return json.loads(_FIXTURE_INDEX_PATH.read_text(encoding="utf-8"))
|
||||
logger.warning("No local MCP registry index found (cache and fixture missing)")
|
||||
return {"servers": []}
|
||||
|
||||
|
||||
def _coerce_index_servers(index: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(index, list):
|
||||
return [x for x in index if isinstance(x, dict)]
|
||||
if isinstance(index, dict):
|
||||
servers = index.get("servers", [])
|
||||
if isinstance(servers, list):
|
||||
return [x for x in servers if isinstance(x, dict)]
|
||||
return []
|
||||
|
||||
|
||||
def _normalize_server_config(raw: Any) -> dict[str, Any]:
|
||||
if isinstance(raw, dict):
|
||||
return dict(raw)
|
||||
|
||||
# Future-proof object-to-dict normalization.
|
||||
for attr in ("to_dict", "model_dump"):
|
||||
maybe = getattr(raw, attr, None)
|
||||
if callable(maybe):
|
||||
return dict(maybe())
|
||||
|
||||
return dict(getattr(raw, "__dict__", {}))
|
||||
|
||||
|
||||
def _absolutize_stdio_config_in_place(repo_root: Path, config: dict[str, Any]) -> None:
|
||||
cwd = config.get("cwd")
|
||||
if isinstance(cwd, str) and not Path(cwd).is_absolute():
|
||||
config["cwd"] = str((repo_root / cwd).resolve())
|
||||
|
||||
# We intentionally do not absolutize `args` here.
|
||||
# For stdio servers, arguments may include the script name relative to
|
||||
# `cwd` (e.g. "coder_tools_server.py" with cwd="tools"). ToolRegistry's
|
||||
# stdio resolution logic handles script path checks and platform quirks.
|
||||
@@ -808,6 +808,29 @@ class AgentRunner:
|
||||
if mcp_config_path.exists():
|
||||
self._load_mcp_servers_from_config(mcp_config_path)
|
||||
|
||||
# Optional: load additional MCP servers selected via mcp_registry.json.
|
||||
# This is backward-compatible: if the file doesn't exist, nothing changes.
|
||||
mcp_registry_path = agent_path / "mcp_registry.json"
|
||||
if mcp_registry_path.exists():
|
||||
try:
|
||||
raw_selection = json.loads(mcp_registry_path.read_text(encoding="utf-8"))
|
||||
if not isinstance(raw_selection, dict):
|
||||
raise TypeError("mcp_registry.json must be a JSON object")
|
||||
|
||||
allowed_keys = {
|
||||
"include",
|
||||
"tags",
|
||||
"exclude",
|
||||
"profile",
|
||||
"max_tools",
|
||||
"versions",
|
||||
}
|
||||
selection = {k: v for k, v in raw_selection.items() if k in allowed_keys}
|
||||
|
||||
self._tool_registry.load_registry_servers(**selection)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load mcp_registry.json: %s", e)
|
||||
|
||||
@staticmethod
|
||||
def _import_agent_module(agent_path: Path):
|
||||
"""Import an agent package from its directory path.
|
||||
|
||||
@@ -64,6 +64,7 @@ class ToolRegistry:
|
||||
self._mcp_cred_snapshot: set[str] = set() # Credential filenames at MCP load time
|
||||
self._mcp_aden_key_snapshot: str | None = None # ADEN_API_KEY value at MCP load time
|
||||
self._mcp_server_tools: dict[str, set[str]] = {} # server name -> tool names
|
||||
self._mcp_registry_selection: dict[str, Any] | None = None # for resync
|
||||
|
||||
def register(
|
||||
self,
|
||||
@@ -479,10 +480,85 @@ class ToolRegistry:
|
||||
self._mcp_cred_snapshot = self._snapshot_credentials()
|
||||
self._mcp_aden_key_snapshot = os.environ.get("ADEN_API_KEY")
|
||||
|
||||
def load_registry_servers(
|
||||
self,
|
||||
*,
|
||||
include: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
exclude: list[str] | None = None,
|
||||
profile: str | None = None,
|
||||
max_tools: int | None = None,
|
||||
versions: dict[str, str] | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Resolve and load MCP servers based on `mcp_registry.json` selection.
|
||||
|
||||
Implements:
|
||||
- deterministic server order (resolved by resolver)
|
||||
- first-wins tool collisions (existing tools are preserved)
|
||||
- `max_tools` cap on *newly registered* tools from registry servers
|
||||
"""
|
||||
|
||||
from framework.runner.mcp_registry_resolver import resolve_registry_servers
|
||||
|
||||
self._mcp_registry_selection = {
|
||||
"include": include,
|
||||
"tags": tags,
|
||||
"exclude": exclude,
|
||||
"profile": profile,
|
||||
"max_tools": max_tools,
|
||||
"versions": versions,
|
||||
}
|
||||
|
||||
resolved_servers = resolve_registry_servers(
|
||||
include=include,
|
||||
tags=tags,
|
||||
exclude=exclude,
|
||||
profile=profile,
|
||||
max_tools=max_tools,
|
||||
versions=versions or {},
|
||||
)
|
||||
if not resolved_servers:
|
||||
logger.warning("MCP registry selection resolved to 0 servers; nothing to load")
|
||||
return 0
|
||||
|
||||
tools_added = 0
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
for server_cfg in resolved_servers:
|
||||
if not isinstance(server_cfg, dict):
|
||||
continue
|
||||
|
||||
if max_tools is not None and tools_added >= max_tools:
|
||||
break
|
||||
|
||||
# Normalize stdio config so scripts/cwd behave like mcp_servers.json loading.
|
||||
server_cfg = self._resolve_mcp_server_config(server_cfg, repo_root)
|
||||
|
||||
remaining = None
|
||||
if max_tools is not None:
|
||||
remaining = max_tools - tools_added
|
||||
if remaining <= 0:
|
||||
break
|
||||
|
||||
added = self.register_mcp_server(
|
||||
server_cfg,
|
||||
preserve_existing_tools=True,
|
||||
tool_cap=remaining,
|
||||
log_collisions=True,
|
||||
)
|
||||
tools_added += added
|
||||
|
||||
return tools_added
|
||||
|
||||
def register_mcp_server(
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
use_connection_manager: bool = False,
|
||||
*,
|
||||
preserve_existing_tools: bool = False,
|
||||
tool_cap: int | None = None,
|
||||
log_collisions: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Register an MCP server and discover its tools.
|
||||
@@ -540,6 +616,23 @@ class ToolRegistry:
|
||||
self._mcp_server_tools[server_name] = set()
|
||||
count = 0
|
||||
for mcp_tool in client.list_tools():
|
||||
if tool_cap is not None and count >= tool_cap:
|
||||
break
|
||||
|
||||
if preserve_existing_tools and mcp_tool.name in self._tools:
|
||||
if log_collisions:
|
||||
origin_server = (
|
||||
self._find_mcp_origin_server_for_tool(mcp_tool.name) or "<existing>"
|
||||
)
|
||||
logger.warning(
|
||||
"MCP tool '%s' from '%s' shadowed by '%s' (loaded first)",
|
||||
mcp_tool.name,
|
||||
server_name,
|
||||
origin_server,
|
||||
)
|
||||
# Skip registration; do not update MCP tool bookkeeping for this server.
|
||||
continue
|
||||
|
||||
# Convert MCP tool to framework Tool (strips context params from LLM schema)
|
||||
tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
|
||||
|
||||
@@ -606,6 +699,12 @@ class ToolRegistry:
|
||||
)
|
||||
return 0
|
||||
|
||||
def _find_mcp_origin_server_for_tool(self, tool_name: str) -> str | None:
|
||||
for server_name, tool_names in self._mcp_server_tools.items():
|
||||
if tool_name in tool_names:
|
||||
return server_name
|
||||
return None
|
||||
|
||||
def _convert_mcp_tool_to_framework_tool(self, mcp_tool: Any) -> Tool:
|
||||
"""
|
||||
Convert an MCP tool to a framework Tool.
|
||||
@@ -738,9 +837,13 @@ class ToolRegistry:
|
||||
for name in self._mcp_tool_names:
|
||||
self._tools.pop(name, None)
|
||||
self._mcp_tool_names.clear()
|
||||
self._mcp_server_tools.clear()
|
||||
|
||||
# 3. Re-load MCP servers (spawns fresh subprocesses with new credentials)
|
||||
self.load_mcp_config(self._mcp_config_path)
|
||||
if self._mcp_registry_selection is not None:
|
||||
# Re-apply the same registry selection (preserves tool collision semantics).
|
||||
self.load_registry_servers(**self._mcp_registry_selection)
|
||||
|
||||
logger.info("MCP server resync complete")
|
||||
return True
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from framework.runner.mcp_client import MCPTool
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
|
||||
|
||||
def _make_tool(name: str, server_name: str) -> MCPTool:
|
||||
return MCPTool(
|
||||
name=name,
|
||||
description=f"{name} from {server_name}",
|
||||
input_schema={"type": "object", "properties": {}, "required": []},
|
||||
server_name=server_name,
|
||||
)
|
||||
|
||||
|
||||
def test_registry_first_wins_collisions(monkeypatch):
|
||||
"""
|
||||
When multiple registry servers expose the same tool name, the first server
|
||||
in load order should win and later servers should not overwrite it.
|
||||
"""
|
||||
|
||||
tool_map: dict[str, list[str]] = {
|
||||
"s1": ["tool_common", "tool_hive"],
|
||||
"s2": ["tool_common", "tool_coder"],
|
||||
}
|
||||
|
||||
from framework.runner import mcp_client as mcp_client_mod
|
||||
|
||||
class FakeMCPClient:
|
||||
def __init__(self, config: Any):
|
||||
self.config = config
|
||||
self._tools: list[MCPTool] = []
|
||||
|
||||
def connect(self) -> None:
|
||||
return
|
||||
|
||||
def disconnect(self) -> None:
|
||||
return
|
||||
|
||||
def list_tools(self) -> list[MCPTool]:
|
||||
names = tool_map.get(self.config.name, [])
|
||||
return [_make_tool(n, self.config.name) for n in names]
|
||||
|
||||
def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
monkeypatch.setattr(mcp_client_mod, "MCPClient", FakeMCPClient)
|
||||
|
||||
from framework.runner import mcp_registry_resolver as resolver_mod
|
||||
|
||||
# Return server configs in the desired deterministic order.
|
||||
resolved_servers = [
|
||||
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
resolver_mod,
|
||||
"resolve_registry_servers",
|
||||
lambda **kwargs: resolved_servers,
|
||||
)
|
||||
|
||||
registry = ToolRegistry()
|
||||
added = registry.load_registry_servers(include=["s1"], tags=None, exclude=None, profile=None)
|
||||
|
||||
assert added == 3 # tool_common + tool_hive + tool_coder
|
||||
assert registry.has_tool("tool_common") is True
|
||||
assert registry.has_tool("tool_hive") is True
|
||||
assert registry.has_tool("tool_coder") is True
|
||||
|
||||
assert registry.get_server_tool_names("s1") == {"tool_common", "tool_hive"}
|
||||
assert registry.get_server_tool_names("s2") == {"tool_coder"}
|
||||
|
||||
|
||||
def test_registry_precedence_over_existing_mcp_servers(monkeypatch):
|
||||
"""Registry-loaded tools should not overwrite already registered MCP tools."""
|
||||
|
||||
tool_map: dict[str, list[str]] = {
|
||||
"pre": ["tool_common", "tool_pre"],
|
||||
"s1": ["tool_common", "tool_hive"],
|
||||
"s2": ["tool_common", "tool_coder"],
|
||||
}
|
||||
|
||||
from framework.runner import mcp_client as mcp_client_mod
|
||||
|
||||
class FakeMCPClient:
|
||||
def __init__(self, config: Any):
|
||||
self.config = config
|
||||
|
||||
def connect(self) -> None:
|
||||
return
|
||||
|
||||
def disconnect(self) -> None:
|
||||
return
|
||||
|
||||
def list_tools(self) -> list[MCPTool]:
|
||||
names = tool_map.get(self.config.name, [])
|
||||
return [_make_tool(n, self.config.name) for n in names]
|
||||
|
||||
def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
monkeypatch.setattr(mcp_client_mod, "MCPClient", FakeMCPClient)
|
||||
|
||||
from framework.runner import mcp_registry_resolver as resolver_mod
|
||||
|
||||
resolved_servers = [
|
||||
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
resolver_mod,
|
||||
"resolve_registry_servers",
|
||||
lambda **kwargs: resolved_servers,
|
||||
)
|
||||
|
||||
registry = ToolRegistry()
|
||||
registry.register_mcp_server(
|
||||
{"name": "pre", "transport": "stdio", "command": "fake", "args": [], "cwd": None}
|
||||
)
|
||||
|
||||
added = registry.load_registry_servers(include=None, tags=None, exclude=None, profile=None)
|
||||
|
||||
assert added == 2 # only tool_hive + tool_coder; tool_common is preserved from "pre"
|
||||
assert registry.get_server_tool_names("pre") == {"tool_common", "tool_pre"}
|
||||
assert registry.get_server_tool_names("s1") == {"tool_hive"}
|
||||
assert registry.get_server_tool_names("s2") == {"tool_coder"}
|
||||
|
||||
|
||||
def test_registry_max_tools_cap(monkeypatch):
|
||||
"""max_tools caps the total number of newly added tools from registry servers."""
|
||||
|
||||
tool_map: dict[str, list[str]] = {
|
||||
"s1": ["tool_a", "tool_b"],
|
||||
"s2": ["tool_c"],
|
||||
}
|
||||
|
||||
from framework.runner import mcp_client as mcp_client_mod
|
||||
|
||||
class FakeMCPClient:
|
||||
def __init__(self, config: Any):
|
||||
self.config = config
|
||||
|
||||
def connect(self) -> None:
|
||||
return
|
||||
|
||||
def disconnect(self) -> None:
|
||||
return
|
||||
|
||||
def list_tools(self) -> list[MCPTool]:
|
||||
names = tool_map.get(self.config.name, [])
|
||||
return [_make_tool(n, self.config.name) for n in names]
|
||||
|
||||
def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
monkeypatch.setattr(mcp_client_mod, "MCPClient", FakeMCPClient)
|
||||
|
||||
from framework.runner import mcp_registry_resolver as resolver_mod
|
||||
|
||||
resolved_servers = [
|
||||
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
resolver_mod,
|
||||
"resolve_registry_servers",
|
||||
lambda **kwargs: resolved_servers,
|
||||
)
|
||||
|
||||
registry = ToolRegistry()
|
||||
added = registry.load_registry_servers(
|
||||
include=None,
|
||||
tags=None,
|
||||
exclude=None,
|
||||
profile=None,
|
||||
max_tools=2,
|
||||
)
|
||||
|
||||
assert added == 2
|
||||
assert registry.has_tool("tool_a") is True
|
||||
assert registry.has_tool("tool_b") is True
|
||||
assert registry.has_tool("tool_c") is False
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"include": ["jira"],
|
||||
"max_tools": 10
|
||||
}
|
||||
Reference in New Issue
Block a user