Merge pull request #6792 from fermano/feat/agent-selection-tool-resolution-n-framework-integration
Feat/agent selection tool resolution n framework integration
This commit is contained in:
+4
-2
@@ -13,6 +13,10 @@ out/
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
.venv
|
||||
/venv
|
||||
tools/src/uv.lock
|
||||
|
||||
|
||||
# User configuration (copied from .example)
|
||||
config.yaml
|
||||
@@ -69,8 +73,6 @@ exports/*
|
||||
|
||||
.claude/settings.local.json
|
||||
|
||||
.venv
|
||||
|
||||
docs/github-issues/*
|
||||
core/tests/*dumps/*
|
||||
|
||||
|
||||
@@ -584,11 +584,19 @@ class CredentialTesterAgent:
|
||||
self._tool_registry.load_mcp_config(mcp_config_path)
|
||||
|
||||
try:
|
||||
agent_dir = Path(__file__).parent
|
||||
registry = MCPRegistry()
|
||||
registry.initialize()
|
||||
registry_configs = registry.load_agent_selection(Path(__file__).parent)
|
||||
if (agent_dir / "mcp_registry.json").is_file():
|
||||
self._tool_registry.set_mcp_registry_agent_path(agent_dir)
|
||||
registry_configs, selection_max_tools = registry.load_agent_selection(agent_dir)
|
||||
if registry_configs:
|
||||
self._tool_registry.load_registry_servers(registry_configs)
|
||||
self._tool_registry.load_registry_servers(
|
||||
registry_configs,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("MCP registry config failed to load", exc_info=True)
|
||||
|
||||
|
||||
@@ -401,14 +401,16 @@ class MCPRegistry:
|
||||
|
||||
# ── load_agent_selection ────────────────────────────────────────
|
||||
|
||||
def load_agent_selection(self, agent_path: Path) -> list[dict[str, Any]]:
|
||||
def load_agent_selection(self, agent_path: Path) -> tuple[list[dict[str, Any]], int | None]:
|
||||
"""Load mcp_registry.json from an agent directory and resolve servers.
|
||||
|
||||
Returns list of plain dicts compatible with ToolRegistry.register_mcp_server().
|
||||
Returns:
|
||||
(server_config_dicts, max_tools) for :meth:`ToolRegistry.load_registry_servers`.
|
||||
``max_tools`` is ``None`` when omitted or invalid in JSON.
|
||||
"""
|
||||
registry_json_path = agent_path / "mcp_registry.json"
|
||||
if not registry_json_path.exists():
|
||||
return []
|
||||
return [], None
|
||||
|
||||
selection = json.loads(registry_json_path.read_text(encoding="utf-8"))
|
||||
|
||||
@@ -437,15 +439,16 @@ class MCPRegistry:
|
||||
continue
|
||||
validated[field] = value
|
||||
|
||||
max_tools = validated.get("max_tools")
|
||||
configs = self.resolve_for_agent(
|
||||
include=validated.get("include"),
|
||||
tags=validated.get("tags"),
|
||||
exclude=validated.get("exclude"),
|
||||
profile=validated.get("profile"),
|
||||
max_tools=validated.get("max_tools"),
|
||||
max_tools=max_tools,
|
||||
versions=validated.get("versions"),
|
||||
)
|
||||
return [self._server_config_to_dict(c) for c in configs]
|
||||
return [self._server_config_to_dict(c) for c in configs], max_tools
|
||||
|
||||
# ── resolve_for_agent ───────────────────────────────────────────
|
||||
|
||||
@@ -552,12 +555,14 @@ class MCPRegistry:
|
||||
)
|
||||
continue
|
||||
|
||||
# Check tool count cap before adding (FR-56)
|
||||
# Check tool count cap before adding (FR-56), using manifest tool list when present.
|
||||
# When ``tools`` is empty (e.g. ``add_local``), counts are unknown here—callers should
|
||||
# pass the same ``max_tools`` to ToolRegistry.load_registry_servers to cap registration.
|
||||
manifest_tools = manifest.get("tools", [])
|
||||
server_tool_count = len(manifest_tools)
|
||||
if max_tools is not None and server_tool_count == 0:
|
||||
logger.debug(
|
||||
"Server '%s' has no declared tools in manifest, skipping max_tools check",
|
||||
"Server '%s' has no tools list in manifest; max_tools enforced at registration",
|
||||
name,
|
||||
)
|
||||
elif max_tools is not None and total_tools + server_tool_count > max_tools:
|
||||
|
||||
@@ -0,0 +1,252 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CACHE_INDEX_PATH = Path.home() / ".hive" / "mcp_registry" / "cache" / "registry_index.json"
|
||||
_FIXTURE_INDEX_PATH = Path(__file__).resolve().parent / "fixtures" / "registry_index.json"
|
||||
|
||||
|
||||
def resolve_registry_servers(
|
||||
*,
|
||||
include: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
exclude: list[str] | None = None,
|
||||
profile: str | None = None,
|
||||
max_tools: int | None = None,
|
||||
versions: dict[str, str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Resolve registry-sourced MCP servers for `mcp_registry.json` selection.
|
||||
|
||||
This function is written to be mock-friendly during early development:
|
||||
- If the real `MCPRegistry` core module is present, delegate to it.
|
||||
- Otherwise, fall back to a cached local index (`~/.hive/.../registry_index.json`)
|
||||
and then to the repo fixture index.
|
||||
"""
|
||||
|
||||
# `max_tools` is enforced by ToolRegistry. We keep it in the resolver
|
||||
# signature to match the PRD and future MCPRegistry interfaces.
|
||||
_ = max_tools
|
||||
|
||||
try:
|
||||
from framework.runner.mcp_registry import MCPRegistry # type: ignore
|
||||
|
||||
registry = MCPRegistry()
|
||||
resolved = registry.resolve_for_agent(
|
||||
include=include or [],
|
||||
tags=tags or [],
|
||||
exclude=exclude or [],
|
||||
profile=profile,
|
||||
max_tools=max_tools,
|
||||
versions=versions or {},
|
||||
)
|
||||
# Future-proof: normalize both dicts and typed objects to dicts.
|
||||
return [_normalize_server_config(x) for x in resolved]
|
||||
except ImportError:
|
||||
# Expected while #6349/#6574 is not merged locally.
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("MCPRegistry resolution failed; falling back to cache/fixtures: %s", e)
|
||||
|
||||
return _resolve_from_local_index(
|
||||
include=include,
|
||||
tags=tags,
|
||||
exclude=exclude,
|
||||
profile=profile,
|
||||
versions=versions or {},
|
||||
)
|
||||
|
||||
|
||||
def _resolve_from_local_index(
|
||||
*,
|
||||
include: list[str] | None,
|
||||
tags: list[str] | None,
|
||||
exclude: list[str] | None,
|
||||
profile: str | None,
|
||||
versions: dict[str, str],
|
||||
) -> list[dict[str, Any]]:
|
||||
index = _load_index_json()
|
||||
servers = _coerce_index_servers(index)
|
||||
servers_by_name: dict[str, dict[str, Any]] = {
|
||||
s["name"]: s for s in servers if isinstance(s, dict) and "name" in s
|
||||
}
|
||||
|
||||
include_list = include or []
|
||||
tags_list = tags or []
|
||||
exclude_set = set(exclude or [])
|
||||
|
||||
def _profiles_of(entry: dict[str, Any]) -> set[str]:
|
||||
if isinstance(entry.get("profiles"), list):
|
||||
return set(entry["profiles"])
|
||||
hive = entry.get("hive")
|
||||
if isinstance(hive, dict) and isinstance(hive.get("profiles"), list):
|
||||
return set(hive["profiles"])
|
||||
return set()
|
||||
|
||||
def _tags_of(entry: dict[str, Any]) -> set[str]:
|
||||
if isinstance(entry.get("tags"), list):
|
||||
return set(entry["tags"])
|
||||
return set()
|
||||
|
||||
def _entry_version(entry: dict[str, Any]) -> str | None:
|
||||
# Prefer flat `version`, but support a few common shapes.
|
||||
v = entry.get("version")
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
v2 = entry.get("manifest_version")
|
||||
if isinstance(v2, str):
|
||||
return v2
|
||||
hive = entry.get("manifest")
|
||||
if isinstance(hive, dict) and isinstance(hive.get("version"), str):
|
||||
return hive["version"]
|
||||
return None
|
||||
|
||||
def _version_allows(server_name: str) -> bool:
|
||||
if server_name not in versions:
|
||||
return True
|
||||
pinned = versions[server_name]
|
||||
entry = servers_by_name.get(server_name)
|
||||
if not entry:
|
||||
return False
|
||||
return _entry_version(entry) == pinned
|
||||
|
||||
resolved_names: list[str] = []
|
||||
resolved_set: set[str] = set()
|
||||
|
||||
# 1) Include-order first
|
||||
for name in include_list:
|
||||
if name in exclude_set:
|
||||
continue
|
||||
if name in servers_by_name and _version_allows(name) and name not in resolved_set:
|
||||
resolved_names.append(name)
|
||||
resolved_set.add(name)
|
||||
|
||||
# 2) Then tag/profile matches, alphabetical
|
||||
profile_candidates = set()
|
||||
if profile:
|
||||
for name, entry in servers_by_name.items():
|
||||
if name in exclude_set or not _version_allows(name):
|
||||
continue
|
||||
if profile in _profiles_of(entry):
|
||||
profile_candidates.add(name)
|
||||
|
||||
tag_candidates = set()
|
||||
if tags_list:
|
||||
tags_set = set(tags_list)
|
||||
for name, entry in servers_by_name.items():
|
||||
if name in exclude_set or not _version_allows(name):
|
||||
continue
|
||||
if _tags_of(entry).intersection(tags_set):
|
||||
tag_candidates.add(name)
|
||||
|
||||
tag_profile_names = sorted((profile_candidates | tag_candidates) - resolved_set)
|
||||
resolved_names.extend(tag_profile_names)
|
||||
|
||||
# Missing requested servers should warn (FR-54).
|
||||
for name in include_list:
|
||||
if name in exclude_set:
|
||||
continue
|
||||
if name not in resolved_set:
|
||||
if name not in servers_by_name:
|
||||
logger.warning(
|
||||
"Server '%s' requested by mcp_registry.json but not found in index. "
|
||||
"Run: hive mcp install %s",
|
||||
name,
|
||||
name,
|
||||
)
|
||||
elif name in versions:
|
||||
logger.warning(
|
||||
"Server '%s' was requested but pinned version '%s' was not found in index. "
|
||||
"Run: hive mcp update %s or change the pin in mcp_registry.json",
|
||||
name,
|
||||
versions[name],
|
||||
name,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Server '%s' requested by mcp_registry.json was not selected. "
|
||||
"Check selection filters/exclude lists.",
|
||||
name,
|
||||
)
|
||||
|
||||
resolved_configs: list[dict[str, Any]] = []
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
for name in resolved_names:
|
||||
entry = servers_by_name.get(name)
|
||||
if not entry:
|
||||
continue
|
||||
config = entry.get("mcp_config")
|
||||
if not isinstance(config, dict):
|
||||
# Best-effort: allow a direct MCP config shape at top-level.
|
||||
config = {
|
||||
k: v
|
||||
for k, v in entry.items()
|
||||
if k
|
||||
in {
|
||||
"name",
|
||||
"transport",
|
||||
"command",
|
||||
"args",
|
||||
"env",
|
||||
"cwd",
|
||||
"url",
|
||||
"headers",
|
||||
"description",
|
||||
}
|
||||
}
|
||||
mcp_config = dict(config)
|
||||
mcp_config["name"] = name
|
||||
if mcp_config.get("transport") == "stdio":
|
||||
_absolutize_stdio_config_in_place(repo_root, mcp_config)
|
||||
resolved_configs.append(mcp_config)
|
||||
|
||||
return resolved_configs
|
||||
|
||||
|
||||
def _load_index_json() -> Any:
|
||||
if _CACHE_INDEX_PATH.exists():
|
||||
return json.loads(_CACHE_INDEX_PATH.read_text(encoding="utf-8"))
|
||||
if _FIXTURE_INDEX_PATH.exists():
|
||||
logger.info("Using local fixture index because registry cache is missing")
|
||||
return json.loads(_FIXTURE_INDEX_PATH.read_text(encoding="utf-8"))
|
||||
logger.warning("No local MCP registry index found (cache and fixture missing)")
|
||||
return {"servers": []}
|
||||
|
||||
|
||||
def _coerce_index_servers(index: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(index, list):
|
||||
return [x for x in index if isinstance(x, dict)]
|
||||
if isinstance(index, dict):
|
||||
servers = index.get("servers", [])
|
||||
if isinstance(servers, list):
|
||||
return [x for x in servers if isinstance(x, dict)]
|
||||
return []
|
||||
|
||||
|
||||
def _normalize_server_config(raw: Any) -> dict[str, Any]:
|
||||
if isinstance(raw, dict):
|
||||
return dict(raw)
|
||||
|
||||
# Future-proof object-to-dict normalization.
|
||||
for attr in ("to_dict", "model_dump"):
|
||||
maybe = getattr(raw, attr, None)
|
||||
if callable(maybe):
|
||||
return dict(maybe())
|
||||
|
||||
return dict(getattr(raw, "__dict__", {}))
|
||||
|
||||
|
||||
def _absolutize_stdio_config_in_place(repo_root: Path, config: dict[str, Any]) -> None:
|
||||
cwd = config.get("cwd")
|
||||
if isinstance(cwd, str) and not Path(cwd).is_absolute():
|
||||
config["cwd"] = str((repo_root / cwd).resolve())
|
||||
|
||||
# We intentionally do not absolutize `args` here.
|
||||
# For stdio servers, arguments may include the script name relative to
|
||||
# `cwd` (e.g. "coder_tools_server.py" with cwd="tools"). ToolRegistry's
|
||||
# stdio resolution logic handles script path checks and platform quirks.
|
||||
@@ -1429,12 +1429,18 @@ class AgentRunner:
|
||||
|
||||
def _load_registry_mcp_servers(self, agent_path: Path) -> None:
|
||||
"""Load and register MCP servers selected via ``mcp_registry.json``."""
|
||||
registry_json = agent_path / "mcp_registry.json"
|
||||
if registry_json.is_file():
|
||||
self._tool_registry.set_mcp_registry_agent_path(agent_path)
|
||||
else:
|
||||
self._tool_registry.set_mcp_registry_agent_path(None)
|
||||
|
||||
from framework.runner.mcp_registry import MCPRegistry
|
||||
|
||||
try:
|
||||
registry = MCPRegistry()
|
||||
registry.initialize()
|
||||
server_configs = registry.load_agent_selection(agent_path)
|
||||
server_configs, selection_max_tools = registry.load_agent_selection(agent_path)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to load MCP registry servers for '%s': %s",
|
||||
@@ -1446,7 +1452,12 @@ class AgentRunner:
|
||||
if not server_configs:
|
||||
return
|
||||
|
||||
results = self._tool_registry.load_registry_servers(server_configs)
|
||||
results = self._tool_registry.load_registry_servers(
|
||||
server_configs,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
loaded = [result for result in results if result["status"] == "loaded"]
|
||||
skipped = [result for result in results if result["status"] != "loaded"]
|
||||
|
||||
|
||||
@@ -66,6 +66,8 @@ class ToolRegistry:
|
||||
self._mcp_cred_snapshot: set[str] = set() # Credential filenames at MCP load time
|
||||
self._mcp_aden_key_snapshot: str | None = None # ADEN_API_KEY value at MCP load time
|
||||
self._mcp_server_tools: dict[str, set[str]] = {} # server name -> tool names
|
||||
# Agent dir for re-loading registry MCP after credential resync.
|
||||
self._mcp_registry_agent_path: Path | None = None
|
||||
|
||||
def register(
|
||||
self,
|
||||
@@ -490,7 +492,13 @@ class ToolRegistry:
|
||||
self._resolve_mcp_server_config(server_config, base_dir)
|
||||
for server_config in server_list
|
||||
]
|
||||
self.load_registry_servers(resolved_server_list, log_summary=False)
|
||||
# Ordered first-wins for duplicate tool names across servers; keep tools.py tools.
|
||||
self.load_registry_servers(
|
||||
resolved_server_list,
|
||||
log_summary=False,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=False,
|
||||
)
|
||||
|
||||
# Snapshot credential files and ADEN_API_KEY so we can detect mid-session changes
|
||||
self._mcp_cred_snapshot = self._snapshot_credentials()
|
||||
@@ -499,6 +507,10 @@ class ToolRegistry:
|
||||
def _register_mcp_server_with_retry(
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
*,
|
||||
preserve_existing_tools: bool = True,
|
||||
tool_cap: int | None = None,
|
||||
log_collisions: bool = False,
|
||||
) -> tuple[bool, int, str | None]:
|
||||
"""Register a single MCP server with one retry for transient failures."""
|
||||
name = server_config.get("name", "unknown")
|
||||
@@ -506,7 +518,12 @@ class ToolRegistry:
|
||||
|
||||
for attempt in range(2):
|
||||
try:
|
||||
count = self.register_mcp_server(server_config)
|
||||
count = self.register_mcp_server(
|
||||
server_config,
|
||||
preserve_existing_tools=preserve_existing_tools,
|
||||
tool_cap=tool_cap,
|
||||
log_collisions=log_collisions,
|
||||
)
|
||||
if count > 0:
|
||||
return True, count, None
|
||||
last_error = "registered 0 tools"
|
||||
@@ -532,13 +549,38 @@ class ToolRegistry:
|
||||
server_list: list[dict[str, Any]],
|
||||
*,
|
||||
log_summary: bool = True,
|
||||
preserve_existing_tools: bool = True,
|
||||
max_tools: int | None = None,
|
||||
log_collisions: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Register resolved registry-selected MCP servers with retry and status tracking."""
|
||||
"""Register MCP servers from a resolved config list (registry and/or static).
|
||||
|
||||
``preserve_existing_tools`` enforces first-wins tool names (FR-100): later
|
||||
servers skip names already taken— including tools from ``mcp_servers.json``
|
||||
or ``tools.py`` when those were loaded first.
|
||||
|
||||
``max_tools`` caps how many *new* tool names are registered across this batch
|
||||
(collisions do not consume the cap). When ``log_collisions`` is True, skipped
|
||||
duplicate names emit a warning (FR-101).
|
||||
"""
|
||||
results: list[dict[str, Any]] = []
|
||||
tools_added_batch = 0
|
||||
|
||||
for server_config in server_list:
|
||||
remaining: int | None = None
|
||||
if max_tools is not None:
|
||||
remaining = max_tools - tools_added_batch
|
||||
if remaining <= 0:
|
||||
break
|
||||
|
||||
name = server_config.get("name", "unknown")
|
||||
success, tools_loaded, error = self._register_mcp_server_with_retry(server_config)
|
||||
success, tools_loaded, error = self._register_mcp_server_with_retry(
|
||||
server_config,
|
||||
preserve_existing_tools=preserve_existing_tools,
|
||||
tool_cap=remaining,
|
||||
log_collisions=log_collisions,
|
||||
)
|
||||
tools_added_batch += tools_loaded
|
||||
result = {
|
||||
"server": name,
|
||||
"status": "loaded" if success else "skipped",
|
||||
@@ -565,6 +607,10 @@ class ToolRegistry:
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
use_connection_manager: bool = True,
|
||||
*,
|
||||
preserve_existing_tools: bool = True,
|
||||
tool_cap: int | None = None,
|
||||
log_collisions: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Register an MCP server and discover its tools.
|
||||
@@ -581,6 +627,9 @@ class ToolRegistry:
|
||||
- headers: HTTP headers (for http)
|
||||
- description: Server description (optional)
|
||||
use_connection_manager: When True, reuse a shared client keyed by server name
|
||||
preserve_existing_tools: If True, do not replace tools already in the registry.
|
||||
tool_cap: Max tools to newly register from this server (None = unlimited).
|
||||
log_collisions: If True, log when this server skips a tool name already taken.
|
||||
|
||||
Returns:
|
||||
Number of tools registered from this server
|
||||
@@ -623,6 +672,23 @@ class ToolRegistry:
|
||||
self._mcp_server_tools[server_name] = set()
|
||||
count = 0
|
||||
for mcp_tool in client.list_tools():
|
||||
if tool_cap is not None and count >= tool_cap:
|
||||
break
|
||||
|
||||
if preserve_existing_tools and mcp_tool.name in self._tools:
|
||||
if log_collisions:
|
||||
origin_server = (
|
||||
self._find_mcp_origin_server_for_tool(mcp_tool.name) or "<existing>"
|
||||
)
|
||||
logger.warning(
|
||||
"MCP tool '%s' from '%s' shadowed by '%s' (loaded first)",
|
||||
mcp_tool.name,
|
||||
server_name,
|
||||
origin_server,
|
||||
)
|
||||
# Skip registration; do not update MCP tool bookkeeping for this server.
|
||||
continue
|
||||
|
||||
# Convert MCP tool to framework Tool (strips context params from LLM schema)
|
||||
tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
|
||||
|
||||
@@ -700,6 +766,12 @@ class ToolRegistry:
|
||||
)
|
||||
return 0
|
||||
|
||||
def _find_mcp_origin_server_for_tool(self, tool_name: str) -> str | None:
|
||||
for server_name, tool_names in self._mcp_server_tools.items():
|
||||
if tool_name in tool_names:
|
||||
return server_name
|
||||
return None
|
||||
|
||||
def _convert_mcp_tool_to_framework_tool(self, mcp_tool: Any) -> Tool:
|
||||
"""
|
||||
Convert an MCP tool to a framework Tool.
|
||||
@@ -787,6 +859,37 @@ class ToolRegistry:
|
||||
# MCP credential resync
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def set_mcp_registry_agent_path(self, agent_path: Path | None) -> None:
|
||||
"""Remember agent dir so registry MCP servers reload after credential resync."""
|
||||
self._mcp_registry_agent_path = None if agent_path is None else Path(agent_path)
|
||||
|
||||
def reload_registry_mcp_servers_after_resync(self) -> None:
|
||||
"""Re-run ``mcp_registry.json`` resolution and register servers (post-resync)."""
|
||||
if self._mcp_registry_agent_path is None:
|
||||
return
|
||||
from framework.runner.mcp_registry import MCPRegistry
|
||||
|
||||
try:
|
||||
reg = MCPRegistry()
|
||||
reg.initialize()
|
||||
configs, selection_max_tools = reg.load_agent_selection(self._mcp_registry_agent_path)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to reload MCP registry servers after resync for '%s': %s",
|
||||
self._mcp_registry_agent_path.name,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
if not configs:
|
||||
return
|
||||
self.load_registry_servers(
|
||||
configs,
|
||||
log_summary=True,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
|
||||
def _snapshot_credentials(self) -> set[str]:
|
||||
"""Return the set of credential filenames currently on disk."""
|
||||
try:
|
||||
@@ -832,9 +935,12 @@ class ToolRegistry:
|
||||
for name in self._mcp_tool_names:
|
||||
self._tools.pop(name, None)
|
||||
self._mcp_tool_names.clear()
|
||||
self._mcp_server_tools.clear()
|
||||
|
||||
# 3. Re-load MCP servers (spawns fresh subprocesses with new credentials)
|
||||
self.load_mcp_config(self._mcp_config_path)
|
||||
if self._mcp_registry_agent_path is not None:
|
||||
self.reload_registry_mcp_servers_after_resync()
|
||||
|
||||
logger.info("MCP server resync complete")
|
||||
return True
|
||||
|
||||
@@ -90,9 +90,16 @@ async def create_queen(
|
||||
try:
|
||||
registry = MCPRegistry()
|
||||
registry.initialize()
|
||||
registry_configs = registry.load_agent_selection(queen_pkg_dir)
|
||||
if (queen_pkg_dir / "mcp_registry.json").is_file():
|
||||
queen_registry.set_mcp_registry_agent_path(queen_pkg_dir)
|
||||
registry_configs, selection_max_tools = registry.load_agent_selection(queen_pkg_dir)
|
||||
if registry_configs:
|
||||
results = queen_registry.load_registry_servers(registry_configs)
|
||||
results = queen_registry.load_registry_servers(
|
||||
registry_configs,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
logger.info("Queen: loaded MCP registry servers: %s", results)
|
||||
except Exception:
|
||||
logger.warning("Queen: MCP registry config failed to load", exc_info=True)
|
||||
|
||||
@@ -32,7 +32,7 @@ class _FakeRegistry:
|
||||
|
||||
def load_agent_selection(self, agent_path: Path):
|
||||
self.loaded_paths.append(agent_path)
|
||||
return list(self._returned_configs)
|
||||
return list(self._returned_configs), None
|
||||
|
||||
|
||||
def test_agent_runner_loads_registry_selected_servers(tmp_path, monkeypatch):
|
||||
@@ -61,7 +61,7 @@ def test_agent_runner_loads_registry_selected_servers(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
|
||||
lambda self, server_config, use_connection_manager=True: (
|
||||
lambda self, server_config, use_connection_manager=True, **kwargs: (
|
||||
registered.append(server_config) or 1
|
||||
),
|
||||
)
|
||||
@@ -95,7 +95,7 @@ def test_agent_runner_skips_registry_when_no_servers_selected(tmp_path, monkeypa
|
||||
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
|
||||
lambda self, server_config, use_connection_manager=True: (
|
||||
lambda self, server_config, use_connection_manager=True, **kwargs: (
|
||||
registered.append(server_config) or 1
|
||||
),
|
||||
)
|
||||
@@ -135,7 +135,7 @@ def test_agent_runner_logs_actual_registry_load_results(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.tool_registry.ToolRegistry.load_registry_servers",
|
||||
lambda self, server_configs: [
|
||||
lambda self, server_configs, **kwargs: [
|
||||
{"server": "jira", "status": "loaded", "tools_loaded": 2, "skipped_reason": None},
|
||||
{
|
||||
"server": "slack",
|
||||
@@ -223,7 +223,7 @@ def test_integration_real_registry_to_agent_runner(tmp_path, monkeypatch):
|
||||
registered: list[dict] = []
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
|
||||
lambda self, server_config, use_connection_manager=True: (
|
||||
lambda self, server_config, use_connection_manager=True, **kwargs: (
|
||||
registered.append(server_config) or 1
|
||||
),
|
||||
)
|
||||
|
||||
@@ -619,8 +619,9 @@ def test_load_agent_selection(tmp_path: Path):
|
||||
agent_dir = tmp_path / "agent"
|
||||
agent_dir.mkdir()
|
||||
(agent_dir / "mcp_registry.json").write_text(json.dumps({"include": ["jira", "slack"]}))
|
||||
dicts = registry.load_agent_selection(agent_dir)
|
||||
assert len(dicts) == 2 and all("transport" in d for d in dicts)
|
||||
dicts, max_tools = registry.load_agent_selection(agent_dir)
|
||||
assert len(dicts) == 2 and max_tools is None
|
||||
assert all("transport" in d for d in dicts)
|
||||
|
||||
|
||||
def test_load_agent_selection_no_file(tmp_path: Path):
|
||||
@@ -628,7 +629,7 @@ def test_load_agent_selection_no_file(tmp_path: Path):
|
||||
registry.initialize()
|
||||
agent_dir = tmp_path / "agent"
|
||||
agent_dir.mkdir()
|
||||
assert registry.load_agent_selection(agent_dir) == []
|
||||
assert registry.load_agent_selection(agent_dir) == ([], None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -648,9 +649,9 @@ def test_load_agent_selection_rejects_wrong_types(tmp_path: Path, field, bad_val
|
||||
agent_dir = tmp_path / "agent"
|
||||
agent_dir.mkdir()
|
||||
(agent_dir / "mcp_registry.json").write_text(json.dumps({field: bad_value}))
|
||||
configs = registry.load_agent_selection(agent_dir)
|
||||
configs, max_tools = registry.load_agent_selection(agent_dir)
|
||||
# All bad fields are dropped, so resolve_for_agent gets no criteria and returns []
|
||||
assert configs == []
|
||||
assert configs == [] and max_tools is None
|
||||
|
||||
|
||||
# ── run_health_check ────────────────────────────────────────────────
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from framework.runner.mcp_client import MCPTool
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
|
||||
|
||||
def _patch_connection_manager_for_fake_stdio(monkeypatch, tool_map: dict[str, list[str]]) -> None:
|
||||
"""Avoid spawning real stdio MCP processes; return in-memory clients per server name."""
|
||||
|
||||
class FakeMCPClient:
|
||||
def __init__(self, config: Any):
|
||||
self.config = config
|
||||
|
||||
def connect(self) -> None:
|
||||
return
|
||||
|
||||
def disconnect(self) -> None:
|
||||
return
|
||||
|
||||
def list_tools(self) -> list[MCPTool]:
|
||||
names = tool_map.get(self.config.name, [])
|
||||
return [_make_tool(n, self.config.name) for n in names]
|
||||
|
||||
def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
class FakeManager:
|
||||
def acquire(self, config: Any) -> FakeMCPClient:
|
||||
return FakeMCPClient(config)
|
||||
|
||||
def release(self, _server_name: str) -> None:
|
||||
return
|
||||
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
|
||||
lambda: FakeManager(),
|
||||
)
|
||||
|
||||
|
||||
def _make_tool(name: str, server_name: str) -> MCPTool:
|
||||
return MCPTool(
|
||||
name=name,
|
||||
description=f"{name} from {server_name}",
|
||||
input_schema={"type": "object", "properties": {}, "required": []},
|
||||
server_name=server_name,
|
||||
)
|
||||
|
||||
|
||||
def test_registry_first_wins_collisions(monkeypatch):
|
||||
"""
|
||||
When multiple registry servers expose the same tool name, the first server
|
||||
in load order should win and later servers should not overwrite it.
|
||||
"""
|
||||
|
||||
tool_map: dict[str, list[str]] = {
|
||||
"s1": ["tool_common", "tool_hive"],
|
||||
"s2": ["tool_common", "tool_coder"],
|
||||
}
|
||||
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
|
||||
|
||||
resolved_servers = [
|
||||
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
]
|
||||
|
||||
registry = ToolRegistry()
|
||||
registry.load_registry_servers(
|
||||
resolved_servers,
|
||||
log_summary=False,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
)
|
||||
|
||||
assert registry.has_tool("tool_common") is True
|
||||
assert registry.has_tool("tool_hive") is True
|
||||
assert registry.has_tool("tool_coder") is True
|
||||
|
||||
assert registry.get_server_tool_names("s1") == {"tool_common", "tool_hive"}
|
||||
assert registry.get_server_tool_names("s2") == {"tool_coder"}
|
||||
|
||||
|
||||
def test_registry_precedence_over_existing_mcp_servers(monkeypatch):
|
||||
"""Registry-loaded tools should not overwrite already registered MCP tools."""
|
||||
|
||||
tool_map: dict[str, list[str]] = {
|
||||
"pre": ["tool_common", "tool_pre"],
|
||||
"s1": ["tool_common", "tool_hive"],
|
||||
"s2": ["tool_common", "tool_coder"],
|
||||
}
|
||||
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
|
||||
|
||||
resolved_servers = [
|
||||
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
]
|
||||
|
||||
registry = ToolRegistry()
|
||||
registry.register_mcp_server(
|
||||
{"name": "pre", "transport": "stdio", "command": "fake", "args": [], "cwd": None}
|
||||
)
|
||||
|
||||
registry.load_registry_servers(
|
||||
resolved_servers,
|
||||
log_summary=False,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
)
|
||||
|
||||
assert registry.get_server_tool_names("pre") == {"tool_common", "tool_pre"}
|
||||
assert registry.get_server_tool_names("s1") == {"tool_hive"}
|
||||
assert registry.get_server_tool_names("s2") == {"tool_coder"}
|
||||
|
||||
|
||||
def test_registry_max_tools_cap(monkeypatch):
|
||||
"""max_tools caps the total number of newly added tools from registry servers."""
|
||||
|
||||
tool_map: dict[str, list[str]] = {
|
||||
"s1": ["tool_a", "tool_b"],
|
||||
"s2": ["tool_c"],
|
||||
}
|
||||
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
|
||||
|
||||
resolved_servers = [
|
||||
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
]
|
||||
|
||||
registry = ToolRegistry()
|
||||
registry.load_registry_servers(
|
||||
resolved_servers,
|
||||
log_summary=False,
|
||||
preserve_existing_tools=True,
|
||||
max_tools=2,
|
||||
)
|
||||
|
||||
assert registry.has_tool("tool_a") is True
|
||||
assert registry.has_tool("tool_b") is True
|
||||
assert registry.has_tool("tool_c") is False
|
||||
@@ -214,7 +214,7 @@ def test_load_registry_servers_retries_when_registration_returns_zero(monkeypatc
|
||||
registry = ToolRegistry()
|
||||
attempts = {"count": 0}
|
||||
|
||||
def fake_register(server_config, use_connection_manager=True):
|
||||
def fake_register(server_config, use_connection_manager=True, **kwargs):
|
||||
attempts["count"] += 1
|
||||
return 0 if attempts["count"] == 1 else 2
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
{ "include": ["hive-tools"] }
|
||||
Reference in New Issue
Block a user