Merge pull request #6792 from fermano/feat/agent-selection-tool-resolution-n-framework-integration

Feat/agent selection tool resolution n framework integration
This commit is contained in:
Timothy @aden
2026-03-31 17:38:39 -07:00
committed by GitHub
12 changed files with 563 additions and 30 deletions
+4 -2
View File
@@ -13,6 +13,10 @@ out/
.env
.env.local
.env.*.local
.venv
/venv
tools/src/uv.lock
# User configuration (copied from .example)
config.yaml
@@ -69,8 +73,6 @@ exports/*
.claude/settings.local.json
.venv
docs/github-issues/*
core/tests/*dumps/*
@@ -584,11 +584,19 @@ class CredentialTesterAgent:
self._tool_registry.load_mcp_config(mcp_config_path)
try:
agent_dir = Path(__file__).parent
registry = MCPRegistry()
registry.initialize()
registry_configs = registry.load_agent_selection(Path(__file__).parent)
if (agent_dir / "mcp_registry.json").is_file():
self._tool_registry.set_mcp_registry_agent_path(agent_dir)
registry_configs, selection_max_tools = registry.load_agent_selection(agent_dir)
if registry_configs:
self._tool_registry.load_registry_servers(registry_configs)
self._tool_registry.load_registry_servers(
registry_configs,
preserve_existing_tools=True,
log_collisions=True,
max_tools=selection_max_tools,
)
except Exception:
logger.warning("MCP registry config failed to load", exc_info=True)
+12 -7
View File
@@ -401,14 +401,16 @@ class MCPRegistry:
# ── load_agent_selection ────────────────────────────────────────
def load_agent_selection(self, agent_path: Path) -> list[dict[str, Any]]:
def load_agent_selection(self, agent_path: Path) -> tuple[list[dict[str, Any]], int | None]:
"""Load mcp_registry.json from an agent directory and resolve servers.
Returns list of plain dicts compatible with ToolRegistry.register_mcp_server().
Returns:
(server_config_dicts, max_tools) for :meth:`ToolRegistry.load_registry_servers`.
``max_tools`` is ``None`` when omitted or invalid in JSON.
"""
registry_json_path = agent_path / "mcp_registry.json"
if not registry_json_path.exists():
return []
return [], None
selection = json.loads(registry_json_path.read_text(encoding="utf-8"))
@@ -437,15 +439,16 @@ class MCPRegistry:
continue
validated[field] = value
max_tools = validated.get("max_tools")
configs = self.resolve_for_agent(
include=validated.get("include"),
tags=validated.get("tags"),
exclude=validated.get("exclude"),
profile=validated.get("profile"),
max_tools=validated.get("max_tools"),
max_tools=max_tools,
versions=validated.get("versions"),
)
return [self._server_config_to_dict(c) for c in configs]
return [self._server_config_to_dict(c) for c in configs], max_tools
# ── resolve_for_agent ───────────────────────────────────────────
@@ -552,12 +555,14 @@ class MCPRegistry:
)
continue
# Check tool count cap before adding (FR-56)
# Check tool count cap before adding (FR-56), using manifest tool list when present.
# When ``tools`` is empty (e.g. ``add_local``), counts are unknown here—callers should
# pass the same ``max_tools`` to ToolRegistry.load_registry_servers to cap registration.
manifest_tools = manifest.get("tools", [])
server_tool_count = len(manifest_tools)
if max_tools is not None and server_tool_count == 0:
logger.debug(
"Server '%s' has no declared tools in manifest, skipping max_tools check",
"Server '%s' has no tools list in manifest; max_tools enforced at registration",
name,
)
elif max_tools is not None and total_tools + server_tool_count > max_tools:
@@ -0,0 +1,252 @@
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
_CACHE_INDEX_PATH = Path.home() / ".hive" / "mcp_registry" / "cache" / "registry_index.json"
_FIXTURE_INDEX_PATH = Path(__file__).resolve().parent / "fixtures" / "registry_index.json"
def resolve_registry_servers(
*,
include: list[str] | None = None,
tags: list[str] | None = None,
exclude: list[str] | None = None,
profile: str | None = None,
max_tools: int | None = None,
versions: dict[str, str] | None = None,
) -> list[dict[str, Any]]:
"""
Resolve registry-sourced MCP servers for `mcp_registry.json` selection.
This function is written to be mock-friendly during early development:
- If the real `MCPRegistry` core module is present, delegate to it.
- Otherwise, fall back to a cached local index (`~/.hive/.../registry_index.json`)
and then to the repo fixture index.
"""
# `max_tools` is enforced by ToolRegistry. We keep it in the resolver
# signature to match the PRD and future MCPRegistry interfaces.
_ = max_tools
try:
from framework.runner.mcp_registry import MCPRegistry # type: ignore
registry = MCPRegistry()
resolved = registry.resolve_for_agent(
include=include or [],
tags=tags or [],
exclude=exclude or [],
profile=profile,
max_tools=max_tools,
versions=versions or {},
)
# Future-proof: normalize both dicts and typed objects to dicts.
return [_normalize_server_config(x) for x in resolved]
except ImportError:
# Expected while #6349/#6574 is not merged locally.
pass
except Exception as e:
logger.warning("MCPRegistry resolution failed; falling back to cache/fixtures: %s", e)
return _resolve_from_local_index(
include=include,
tags=tags,
exclude=exclude,
profile=profile,
versions=versions or {},
)
def _resolve_from_local_index(
*,
include: list[str] | None,
tags: list[str] | None,
exclude: list[str] | None,
profile: str | None,
versions: dict[str, str],
) -> list[dict[str, Any]]:
index = _load_index_json()
servers = _coerce_index_servers(index)
servers_by_name: dict[str, dict[str, Any]] = {
s["name"]: s for s in servers if isinstance(s, dict) and "name" in s
}
include_list = include or []
tags_list = tags or []
exclude_set = set(exclude or [])
def _profiles_of(entry: dict[str, Any]) -> set[str]:
if isinstance(entry.get("profiles"), list):
return set(entry["profiles"])
hive = entry.get("hive")
if isinstance(hive, dict) and isinstance(hive.get("profiles"), list):
return set(hive["profiles"])
return set()
def _tags_of(entry: dict[str, Any]) -> set[str]:
if isinstance(entry.get("tags"), list):
return set(entry["tags"])
return set()
def _entry_version(entry: dict[str, Any]) -> str | None:
# Prefer flat `version`, but support a few common shapes.
v = entry.get("version")
if isinstance(v, str):
return v
v2 = entry.get("manifest_version")
if isinstance(v2, str):
return v2
hive = entry.get("manifest")
if isinstance(hive, dict) and isinstance(hive.get("version"), str):
return hive["version"]
return None
def _version_allows(server_name: str) -> bool:
if server_name not in versions:
return True
pinned = versions[server_name]
entry = servers_by_name.get(server_name)
if not entry:
return False
return _entry_version(entry) == pinned
resolved_names: list[str] = []
resolved_set: set[str] = set()
# 1) Include-order first
for name in include_list:
if name in exclude_set:
continue
if name in servers_by_name and _version_allows(name) and name not in resolved_set:
resolved_names.append(name)
resolved_set.add(name)
# 2) Then tag/profile matches, alphabetical
profile_candidates = set()
if profile:
for name, entry in servers_by_name.items():
if name in exclude_set or not _version_allows(name):
continue
if profile in _profiles_of(entry):
profile_candidates.add(name)
tag_candidates = set()
if tags_list:
tags_set = set(tags_list)
for name, entry in servers_by_name.items():
if name in exclude_set or not _version_allows(name):
continue
if _tags_of(entry).intersection(tags_set):
tag_candidates.add(name)
tag_profile_names = sorted((profile_candidates | tag_candidates) - resolved_set)
resolved_names.extend(tag_profile_names)
# Missing requested servers should warn (FR-54).
for name in include_list:
if name in exclude_set:
continue
if name not in resolved_set:
if name not in servers_by_name:
logger.warning(
"Server '%s' requested by mcp_registry.json but not found in index. "
"Run: hive mcp install %s",
name,
name,
)
elif name in versions:
logger.warning(
"Server '%s' was requested but pinned version '%s' was not found in index. "
"Run: hive mcp update %s or change the pin in mcp_registry.json",
name,
versions[name],
name,
)
else:
logger.warning(
"Server '%s' requested by mcp_registry.json was not selected. "
"Check selection filters/exclude lists.",
name,
)
resolved_configs: list[dict[str, Any]] = []
repo_root = Path(__file__).resolve().parents[3]
for name in resolved_names:
entry = servers_by_name.get(name)
if not entry:
continue
config = entry.get("mcp_config")
if not isinstance(config, dict):
# Best-effort: allow a direct MCP config shape at top-level.
config = {
k: v
for k, v in entry.items()
if k
in {
"name",
"transport",
"command",
"args",
"env",
"cwd",
"url",
"headers",
"description",
}
}
mcp_config = dict(config)
mcp_config["name"] = name
if mcp_config.get("transport") == "stdio":
_absolutize_stdio_config_in_place(repo_root, mcp_config)
resolved_configs.append(mcp_config)
return resolved_configs
def _load_index_json() -> Any:
if _CACHE_INDEX_PATH.exists():
return json.loads(_CACHE_INDEX_PATH.read_text(encoding="utf-8"))
if _FIXTURE_INDEX_PATH.exists():
logger.info("Using local fixture index because registry cache is missing")
return json.loads(_FIXTURE_INDEX_PATH.read_text(encoding="utf-8"))
logger.warning("No local MCP registry index found (cache and fixture missing)")
return {"servers": []}
def _coerce_index_servers(index: Any) -> list[dict[str, Any]]:
if isinstance(index, list):
return [x for x in index if isinstance(x, dict)]
if isinstance(index, dict):
servers = index.get("servers", [])
if isinstance(servers, list):
return [x for x in servers if isinstance(x, dict)]
return []
def _normalize_server_config(raw: Any) -> dict[str, Any]:
if isinstance(raw, dict):
return dict(raw)
# Future-proof object-to-dict normalization.
for attr in ("to_dict", "model_dump"):
maybe = getattr(raw, attr, None)
if callable(maybe):
return dict(maybe())
return dict(getattr(raw, "__dict__", {}))
def _absolutize_stdio_config_in_place(repo_root: Path, config: dict[str, Any]) -> None:
cwd = config.get("cwd")
if isinstance(cwd, str) and not Path(cwd).is_absolute():
config["cwd"] = str((repo_root / cwd).resolve())
# We intentionally do not absolutize `args` here.
# For stdio servers, arguments may include the script name relative to
# `cwd` (e.g. "coder_tools_server.py" with cwd="tools"). ToolRegistry's
# stdio resolution logic handles script path checks and platform quirks.
+13 -2
View File
@@ -1429,12 +1429,18 @@ class AgentRunner:
def _load_registry_mcp_servers(self, agent_path: Path) -> None:
"""Load and register MCP servers selected via ``mcp_registry.json``."""
registry_json = agent_path / "mcp_registry.json"
if registry_json.is_file():
self._tool_registry.set_mcp_registry_agent_path(agent_path)
else:
self._tool_registry.set_mcp_registry_agent_path(None)
from framework.runner.mcp_registry import MCPRegistry
try:
registry = MCPRegistry()
registry.initialize()
server_configs = registry.load_agent_selection(agent_path)
server_configs, selection_max_tools = registry.load_agent_selection(agent_path)
except Exception as exc:
logger.warning(
"Failed to load MCP registry servers for '%s': %s",
@@ -1446,7 +1452,12 @@ class AgentRunner:
if not server_configs:
return
results = self._tool_registry.load_registry_servers(server_configs)
results = self._tool_registry.load_registry_servers(
server_configs,
preserve_existing_tools=True,
log_collisions=True,
max_tools=selection_max_tools,
)
loaded = [result for result in results if result["status"] == "loaded"]
skipped = [result for result in results if result["status"] != "loaded"]
+110 -4
View File
@@ -66,6 +66,8 @@ class ToolRegistry:
self._mcp_cred_snapshot: set[str] = set() # Credential filenames at MCP load time
self._mcp_aden_key_snapshot: str | None = None # ADEN_API_KEY value at MCP load time
self._mcp_server_tools: dict[str, set[str]] = {} # server name -> tool names
# Agent dir for re-loading registry MCP after credential resync.
self._mcp_registry_agent_path: Path | None = None
def register(
self,
@@ -490,7 +492,13 @@ class ToolRegistry:
self._resolve_mcp_server_config(server_config, base_dir)
for server_config in server_list
]
self.load_registry_servers(resolved_server_list, log_summary=False)
# Ordered first-wins for duplicate tool names across servers; keep tools.py tools.
self.load_registry_servers(
resolved_server_list,
log_summary=False,
preserve_existing_tools=True,
log_collisions=False,
)
# Snapshot credential files and ADEN_API_KEY so we can detect mid-session changes
self._mcp_cred_snapshot = self._snapshot_credentials()
@@ -499,6 +507,10 @@ class ToolRegistry:
def _register_mcp_server_with_retry(
self,
server_config: dict[str, Any],
*,
preserve_existing_tools: bool = True,
tool_cap: int | None = None,
log_collisions: bool = False,
) -> tuple[bool, int, str | None]:
"""Register a single MCP server with one retry for transient failures."""
name = server_config.get("name", "unknown")
@@ -506,7 +518,12 @@ class ToolRegistry:
for attempt in range(2):
try:
count = self.register_mcp_server(server_config)
count = self.register_mcp_server(
server_config,
preserve_existing_tools=preserve_existing_tools,
tool_cap=tool_cap,
log_collisions=log_collisions,
)
if count > 0:
return True, count, None
last_error = "registered 0 tools"
@@ -532,13 +549,38 @@ class ToolRegistry:
server_list: list[dict[str, Any]],
*,
log_summary: bool = True,
preserve_existing_tools: bool = True,
max_tools: int | None = None,
log_collisions: bool = False,
) -> list[dict[str, Any]]:
"""Register resolved registry-selected MCP servers with retry and status tracking."""
"""Register MCP servers from a resolved config list (registry and/or static).
``preserve_existing_tools`` enforces first-wins tool names (FR-100): later
servers skip names already taken including tools from ``mcp_servers.json``
or ``tools.py`` when those were loaded first.
``max_tools`` caps how many *new* tool names are registered across this batch
(collisions do not consume the cap). When ``log_collisions`` is True, skipped
duplicate names emit a warning (FR-101).
"""
results: list[dict[str, Any]] = []
tools_added_batch = 0
for server_config in server_list:
remaining: int | None = None
if max_tools is not None:
remaining = max_tools - tools_added_batch
if remaining <= 0:
break
name = server_config.get("name", "unknown")
success, tools_loaded, error = self._register_mcp_server_with_retry(server_config)
success, tools_loaded, error = self._register_mcp_server_with_retry(
server_config,
preserve_existing_tools=preserve_existing_tools,
tool_cap=remaining,
log_collisions=log_collisions,
)
tools_added_batch += tools_loaded
result = {
"server": name,
"status": "loaded" if success else "skipped",
@@ -565,6 +607,10 @@ class ToolRegistry:
self,
server_config: dict[str, Any],
use_connection_manager: bool = True,
*,
preserve_existing_tools: bool = True,
tool_cap: int | None = None,
log_collisions: bool = False,
) -> int:
"""
Register an MCP server and discover its tools.
@@ -581,6 +627,9 @@ class ToolRegistry:
- headers: HTTP headers (for http)
- description: Server description (optional)
use_connection_manager: When True, reuse a shared client keyed by server name
preserve_existing_tools: If True, do not replace tools already in the registry.
tool_cap: Max tools to newly register from this server (None = unlimited).
log_collisions: If True, log when this server skips a tool name already taken.
Returns:
Number of tools registered from this server
@@ -623,6 +672,23 @@ class ToolRegistry:
self._mcp_server_tools[server_name] = set()
count = 0
for mcp_tool in client.list_tools():
if tool_cap is not None and count >= tool_cap:
break
if preserve_existing_tools and mcp_tool.name in self._tools:
if log_collisions:
origin_server = (
self._find_mcp_origin_server_for_tool(mcp_tool.name) or "<existing>"
)
logger.warning(
"MCP tool '%s' from '%s' shadowed by '%s' (loaded first)",
mcp_tool.name,
server_name,
origin_server,
)
# Skip registration; do not update MCP tool bookkeeping for this server.
continue
# Convert MCP tool to framework Tool (strips context params from LLM schema)
tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
@@ -700,6 +766,12 @@ class ToolRegistry:
)
return 0
def _find_mcp_origin_server_for_tool(self, tool_name: str) -> str | None:
for server_name, tool_names in self._mcp_server_tools.items():
if tool_name in tool_names:
return server_name
return None
def _convert_mcp_tool_to_framework_tool(self, mcp_tool: Any) -> Tool:
"""
Convert an MCP tool to a framework Tool.
@@ -787,6 +859,37 @@ class ToolRegistry:
# MCP credential resync
# ------------------------------------------------------------------
def set_mcp_registry_agent_path(self, agent_path: Path | None) -> None:
"""Remember agent dir so registry MCP servers reload after credential resync."""
self._mcp_registry_agent_path = None if agent_path is None else Path(agent_path)
def reload_registry_mcp_servers_after_resync(self) -> None:
"""Re-run ``mcp_registry.json`` resolution and register servers (post-resync)."""
if self._mcp_registry_agent_path is None:
return
from framework.runner.mcp_registry import MCPRegistry
try:
reg = MCPRegistry()
reg.initialize()
configs, selection_max_tools = reg.load_agent_selection(self._mcp_registry_agent_path)
except Exception as exc:
logger.warning(
"Failed to reload MCP registry servers after resync for '%s': %s",
self._mcp_registry_agent_path.name,
exc,
)
return
if not configs:
return
self.load_registry_servers(
configs,
log_summary=True,
preserve_existing_tools=True,
log_collisions=True,
max_tools=selection_max_tools,
)
def _snapshot_credentials(self) -> set[str]:
"""Return the set of credential filenames currently on disk."""
try:
@@ -832,9 +935,12 @@ class ToolRegistry:
for name in self._mcp_tool_names:
self._tools.pop(name, None)
self._mcp_tool_names.clear()
self._mcp_server_tools.clear()
# 3. Re-load MCP servers (spawns fresh subprocesses with new credentials)
self.load_mcp_config(self._mcp_config_path)
if self._mcp_registry_agent_path is not None:
self.reload_registry_mcp_servers_after_resync()
logger.info("MCP server resync complete")
return True
+9 -2
View File
@@ -90,9 +90,16 @@ async def create_queen(
try:
registry = MCPRegistry()
registry.initialize()
registry_configs = registry.load_agent_selection(queen_pkg_dir)
if (queen_pkg_dir / "mcp_registry.json").is_file():
queen_registry.set_mcp_registry_agent_path(queen_pkg_dir)
registry_configs, selection_max_tools = registry.load_agent_selection(queen_pkg_dir)
if registry_configs:
results = queen_registry.load_registry_servers(registry_configs)
results = queen_registry.load_registry_servers(
registry_configs,
preserve_existing_tools=True,
log_collisions=True,
max_tools=selection_max_tools,
)
logger.info("Queen: loaded MCP registry servers: %s", results)
except Exception:
logger.warning("Queen: MCP registry config failed to load", exc_info=True)
+5 -5
View File
@@ -32,7 +32,7 @@ class _FakeRegistry:
def load_agent_selection(self, agent_path: Path):
self.loaded_paths.append(agent_path)
return list(self._returned_configs)
return list(self._returned_configs), None
def test_agent_runner_loads_registry_selected_servers(tmp_path, monkeypatch):
@@ -61,7 +61,7 @@ def test_agent_runner_loads_registry_selected_servers(tmp_path, monkeypatch):
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
lambda self, server_config, use_connection_manager=True: (
lambda self, server_config, use_connection_manager=True, **kwargs: (
registered.append(server_config) or 1
),
)
@@ -95,7 +95,7 @@ def test_agent_runner_skips_registry_when_no_servers_selected(tmp_path, monkeypa
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
lambda self, server_config, use_connection_manager=True: (
lambda self, server_config, use_connection_manager=True, **kwargs: (
registered.append(server_config) or 1
),
)
@@ -135,7 +135,7 @@ def test_agent_runner_logs_actual_registry_load_results(tmp_path, monkeypatch):
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.load_registry_servers",
lambda self, server_configs: [
lambda self, server_configs, **kwargs: [
{"server": "jira", "status": "loaded", "tools_loaded": 2, "skipped_reason": None},
{
"server": "slack",
@@ -223,7 +223,7 @@ def test_integration_real_registry_to_agent_runner(tmp_path, monkeypatch):
registered: list[dict] = []
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
lambda self, server_config, use_connection_manager=True: (
lambda self, server_config, use_connection_manager=True, **kwargs: (
registered.append(server_config) or 1
),
)
+6 -5
View File
@@ -619,8 +619,9 @@ def test_load_agent_selection(tmp_path: Path):
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
(agent_dir / "mcp_registry.json").write_text(json.dumps({"include": ["jira", "slack"]}))
dicts = registry.load_agent_selection(agent_dir)
assert len(dicts) == 2 and all("transport" in d for d in dicts)
dicts, max_tools = registry.load_agent_selection(agent_dir)
assert len(dicts) == 2 and max_tools is None
assert all("transport" in d for d in dicts)
def test_load_agent_selection_no_file(tmp_path: Path):
@@ -628,7 +629,7 @@ def test_load_agent_selection_no_file(tmp_path: Path):
registry.initialize()
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
assert registry.load_agent_selection(agent_dir) == []
assert registry.load_agent_selection(agent_dir) == ([], None)
@pytest.mark.parametrize(
@@ -648,9 +649,9 @@ def test_load_agent_selection_rejects_wrong_types(tmp_path: Path, field, bad_val
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
(agent_dir / "mcp_registry.json").write_text(json.dumps({field: bad_value}))
configs = registry.load_agent_selection(agent_dir)
configs, max_tools = registry.load_agent_selection(agent_dir)
# All bad fields are dropped, so resolve_for_agent gets no criteria and returns []
assert configs == []
assert configs == [] and max_tools is None
# ── run_health_check ────────────────────────────────────────────────
+140
View File
@@ -0,0 +1,140 @@
from __future__ import annotations
from typing import Any
from framework.runner.mcp_client import MCPTool
from framework.runner.tool_registry import ToolRegistry
def _patch_connection_manager_for_fake_stdio(monkeypatch, tool_map: dict[str, list[str]]) -> None:
"""Avoid spawning real stdio MCP processes; return in-memory clients per server name."""
class FakeMCPClient:
def __init__(self, config: Any):
self.config = config
def connect(self) -> None:
return
def disconnect(self) -> None:
return
def list_tools(self) -> list[MCPTool]:
names = tool_map.get(self.config.name, [])
return [_make_tool(n, self.config.name) for n in names]
def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
raise NotImplementedError
class FakeManager:
def acquire(self, config: Any) -> FakeMCPClient:
return FakeMCPClient(config)
def release(self, _server_name: str) -> None:
return
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
lambda: FakeManager(),
)
def _make_tool(name: str, server_name: str) -> MCPTool:
return MCPTool(
name=name,
description=f"{name} from {server_name}",
input_schema={"type": "object", "properties": {}, "required": []},
server_name=server_name,
)
def test_registry_first_wins_collisions(monkeypatch):
"""
When multiple registry servers expose the same tool name, the first server
in load order should win and later servers should not overwrite it.
"""
tool_map: dict[str, list[str]] = {
"s1": ["tool_common", "tool_hive"],
"s2": ["tool_common", "tool_coder"],
}
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
resolved_servers = [
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
]
registry = ToolRegistry()
registry.load_registry_servers(
resolved_servers,
log_summary=False,
preserve_existing_tools=True,
log_collisions=True,
)
assert registry.has_tool("tool_common") is True
assert registry.has_tool("tool_hive") is True
assert registry.has_tool("tool_coder") is True
assert registry.get_server_tool_names("s1") == {"tool_common", "tool_hive"}
assert registry.get_server_tool_names("s2") == {"tool_coder"}
def test_registry_precedence_over_existing_mcp_servers(monkeypatch):
"""Registry-loaded tools should not overwrite already registered MCP tools."""
tool_map: dict[str, list[str]] = {
"pre": ["tool_common", "tool_pre"],
"s1": ["tool_common", "tool_hive"],
"s2": ["tool_common", "tool_coder"],
}
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
resolved_servers = [
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
]
registry = ToolRegistry()
registry.register_mcp_server(
{"name": "pre", "transport": "stdio", "command": "fake", "args": [], "cwd": None}
)
registry.load_registry_servers(
resolved_servers,
log_summary=False,
preserve_existing_tools=True,
log_collisions=True,
)
assert registry.get_server_tool_names("pre") == {"tool_common", "tool_pre"}
assert registry.get_server_tool_names("s1") == {"tool_hive"}
assert registry.get_server_tool_names("s2") == {"tool_coder"}
def test_registry_max_tools_cap(monkeypatch):
"""max_tools caps the total number of newly added tools from registry servers."""
tool_map: dict[str, list[str]] = {
"s1": ["tool_a", "tool_b"],
"s2": ["tool_c"],
}
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
resolved_servers = [
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
]
registry = ToolRegistry()
registry.load_registry_servers(
resolved_servers,
log_summary=False,
preserve_existing_tools=True,
max_tools=2,
)
assert registry.has_tool("tool_a") is True
assert registry.has_tool("tool_b") is True
assert registry.has_tool("tool_c") is False
+1 -1
View File
@@ -214,7 +214,7 @@ def test_load_registry_servers_retries_when_registration_returns_zero(monkeypatc
registry = ToolRegistry()
attempts = {"count": 0}
def fake_register(server_config, use_connection_manager=True):
def fake_register(server_config, use_connection_manager=True, **kwargs):
attempts["count"] += 1
return 0 if attempts["count"] == 1 else 2
@@ -0,0 +1 @@
{ "include": ["hive-tools"] }