diff --git a/.gitignore b/.gitignore index 3a38fd8d..5dc3aadd 100644 --- a/.gitignore +++ b/.gitignore @@ -70,6 +70,7 @@ exports/* .claude/settings.local.json .venv +/venv docs/github-issues/* core/tests/*dumps/* diff --git a/core/framework/agents/queen/mcp_registry.json b/core/framework/agents/queen/mcp_registry.json new file mode 100644 index 00000000..591e3a02 --- /dev/null +++ b/core/framework/agents/queen/mcp_registry.json @@ -0,0 +1,4 @@ +{ + "profile": "all" +} + diff --git a/core/framework/runner/fixtures/registry_index.json b/core/framework/runner/fixtures/registry_index.json new file mode 100644 index 00000000..7a891e3e --- /dev/null +++ b/core/framework/runner/fixtures/registry_index.json @@ -0,0 +1,44 @@ +{ + "servers": [ + { + "name": "hive-tools", + "version": "1.0.0", + "tags": ["core", "productivity"], + "profiles": ["all", "core", "productivity"], + "mcp_config": { + "transport": "stdio", + "command": "uv", + "args": ["run", "python", "mcp_server.py", "--stdio"], + "cwd": "tools", + "description": "Hive tools MCP server providing core utilities" + } + }, + { + "name": "coder-tools", + "version": "1.0.0", + "tags": ["coding", "productivity"], + "profiles": ["all", "coding", "productivity"], + "mcp_config": { + "transport": "stdio", + "command": "uv", + "args": ["run", "python", "coder_tools_server.py", "--stdio"], + "cwd": "tools", + "description": "Unsandboxed file/code tools for code generation" + } + }, + { + "name": "tools", + "version": "1.0.0", + "tags": ["web", "productivity"], + "profiles": ["all", "general"], + "mcp_config": { + "transport": "stdio", + "command": "uv", + "args": ["run", "python", "mcp_server.py", "--stdio"], + "cwd": "tools", + "description": "Aden tools MCP server providing web/file utilities" + } + } + ] +} + diff --git a/core/framework/runner/mcp_registry_resolver.py b/core/framework/runner/mcp_registry_resolver.py new file mode 100644 index 00000000..e8e6f7bb --- /dev/null +++ 
b/core/framework/runner/mcp_registry_resolver.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + +_CACHE_INDEX_PATH = Path.home() / ".hive" / "mcp_registry" / "cache" / "registry_index.json" +_FIXTURE_INDEX_PATH = Path(__file__).resolve().parent / "fixtures" / "registry_index.json" + + +def resolve_registry_servers( + *, + include: list[str] | None = None, + tags: list[str] | None = None, + exclude: list[str] | None = None, + profile: str | None = None, + max_tools: int | None = None, + versions: dict[str, str] | None = None, +) -> list[dict[str, Any]]: + """ + Resolve registry-sourced MCP servers for `mcp_registry.json` selection. + + This function is written to be mock-friendly during early development: + - If the real `MCPRegistry` core module is present, delegate to it. + - Otherwise, fall back to a cached local index (`~/.hive/.../registry_index.json`) + and then to the repo fixture index. + """ + + # `max_tools` is enforced by ToolRegistry. We keep it in the resolver + # signature to match the PRD and future MCPRegistry interfaces. + _ = max_tools + + try: + from framework.runner.mcp_registry import MCPRegistry # type: ignore + + registry = MCPRegistry() + resolved = registry.resolve_for_agent( + include=include or [], + tags=tags or [], + exclude=exclude or [], + profile=profile, + max_tools=max_tools, + versions=versions or {}, + ) + # Future-proof: normalize both dicts and typed objects to dicts. + return [_normalize_server_config(x) for x in resolved] + except ImportError: + # Expected while #6349/#6574 is not merged locally. 
+ pass + except Exception as e: + logger.warning("MCPRegistry resolution failed; falling back to cache/fixtures: %s", e) + + return _resolve_from_local_index( + include=include, + tags=tags, + exclude=exclude, + profile=profile, + versions=versions or {}, + ) + + +def _resolve_from_local_index( + *, + include: list[str] | None, + tags: list[str] | None, + exclude: list[str] | None, + profile: str | None, + versions: dict[str, str], +) -> list[dict[str, Any]]: + index = _load_index_json() + servers = _coerce_index_servers(index) + servers_by_name: dict[str, dict[str, Any]] = { + s["name"]: s for s in servers if isinstance(s, dict) and "name" in s + } + + include_list = include or [] + tags_list = tags or [] + exclude_set = set(exclude or []) + + def _profiles_of(entry: dict[str, Any]) -> set[str]: + if isinstance(entry.get("profiles"), list): + return set(entry["profiles"]) + hive = entry.get("hive") + if isinstance(hive, dict) and isinstance(hive.get("profiles"), list): + return set(hive["profiles"]) + return set() + + def _tags_of(entry: dict[str, Any]) -> set[str]: + if isinstance(entry.get("tags"), list): + return set(entry["tags"]) + return set() + + def _entry_version(entry: dict[str, Any]) -> str | None: + # Prefer flat `version`, but support a few common shapes. 
+ v = entry.get("version") + if isinstance(v, str): + return v + v2 = entry.get("manifest_version") + if isinstance(v2, str): + return v2 + hive = entry.get("manifest") + if isinstance(hive, dict) and isinstance(hive.get("version"), str): + return hive["version"] + return None + + def _version_allows(server_name: str) -> bool: + if server_name not in versions: + return True + pinned = versions[server_name] + entry = servers_by_name.get(server_name) + if not entry: + return False + return _entry_version(entry) == pinned + + resolved_names: list[str] = [] + resolved_set: set[str] = set() + + # 1) Include-order first + for name in include_list: + if name in exclude_set: + continue + if name in servers_by_name and _version_allows(name) and name not in resolved_set: + resolved_names.append(name) + resolved_set.add(name) + + # 2) Then tag/profile matches, alphabetical + profile_candidates = set() + if profile: + for name, entry in servers_by_name.items(): + if name in exclude_set or not _version_allows(name): + continue + if profile in _profiles_of(entry): + profile_candidates.add(name) + + tag_candidates = set() + if tags_list: + tags_set = set(tags_list) + for name, entry in servers_by_name.items(): + if name in exclude_set or not _version_allows(name): + continue + if _tags_of(entry).intersection(tags_set): + tag_candidates.add(name) + + tag_profile_names = sorted((profile_candidates | tag_candidates) - resolved_set) + resolved_names.extend(tag_profile_names) + + # Missing requested servers should warn (FR-54). + for name in include_list: + if name in exclude_set: + continue + if name not in resolved_set: + if name not in servers_by_name: + logger.warning( + "Server '%s' requested by mcp_registry.json but not found in index. " + "Run: hive mcp install %s", + name, + name, + ) + elif name in versions: + logger.warning( + "Server '%s' was requested but pinned version '%s' was not found in index. 
" + "Run: hive mcp update %s or change the pin in mcp_registry.json", + name, + versions[name], + name, + ) + else: + logger.warning( + "Server '%s' requested by mcp_registry.json was not selected. " + "Check selection filters/exclude lists.", + name, + ) + + resolved_configs: list[dict[str, Any]] = [] + repo_root = Path(__file__).resolve().parents[3] + for name in resolved_names: + entry = servers_by_name.get(name) + if not entry: + continue + config = entry.get("mcp_config") + if not isinstance(config, dict): + # Best-effort: allow a direct MCP config shape at top-level. + config = { + k: v + for k, v in entry.items() + if k + in { + "name", + "transport", + "command", + "args", + "env", + "cwd", + "url", + "headers", + "description", + } + } + mcp_config = dict(config) + mcp_config["name"] = name + if mcp_config.get("transport") == "stdio": + _absolutize_stdio_config_in_place(repo_root, mcp_config) + resolved_configs.append(mcp_config) + + return resolved_configs + + +def _load_index_json() -> Any: + if _CACHE_INDEX_PATH.exists(): + return json.loads(_CACHE_INDEX_PATH.read_text(encoding="utf-8")) + if _FIXTURE_INDEX_PATH.exists(): + logger.info("Using local fixture index because registry cache is missing") + return json.loads(_FIXTURE_INDEX_PATH.read_text(encoding="utf-8")) + logger.warning("No local MCP registry index found (cache and fixture missing)") + return {"servers": []} + + +def _coerce_index_servers(index: Any) -> list[dict[str, Any]]: + if isinstance(index, list): + return [x for x in index if isinstance(x, dict)] + if isinstance(index, dict): + servers = index.get("servers", []) + if isinstance(servers, list): + return [x for x in servers if isinstance(x, dict)] + return [] + + +def _normalize_server_config(raw: Any) -> dict[str, Any]: + if isinstance(raw, dict): + return dict(raw) + + # Future-proof object-to-dict normalization. 
+ for attr in ("to_dict", "model_dump"): + maybe = getattr(raw, attr, None) + if callable(maybe): + return dict(maybe()) + + return dict(getattr(raw, "__dict__", {})) + + +def _absolutize_stdio_config_in_place(repo_root: Path, config: dict[str, Any]) -> None: + cwd = config.get("cwd") + if isinstance(cwd, str) and not Path(cwd).is_absolute(): + config["cwd"] = str((repo_root / cwd).resolve()) + + # We intentionally do not absolutize `args` here. + # For stdio servers, arguments may include the script name relative to + # `cwd` (e.g. "coder_tools_server.py" with cwd="tools"). ToolRegistry's + # stdio resolution logic handles script path checks and platform quirks. diff --git a/core/framework/runner/runner.py b/core/framework/runner/runner.py index 83e3777e..708cf782 100644 --- a/core/framework/runner/runner.py +++ b/core/framework/runner/runner.py @@ -808,6 +808,29 @@ class AgentRunner: if mcp_config_path.exists(): self._load_mcp_servers_from_config(mcp_config_path) + # Optional: load additional MCP servers selected via mcp_registry.json. + # This is backward-compatible: if the file doesn't exist, nothing changes. + mcp_registry_path = agent_path / "mcp_registry.json" + if mcp_registry_path.exists(): + try: + raw_selection = json.loads(mcp_registry_path.read_text(encoding="utf-8")) + if not isinstance(raw_selection, dict): + raise TypeError("mcp_registry.json must be a JSON object") + + allowed_keys = { + "include", + "tags", + "exclude", + "profile", + "max_tools", + "versions", + } + selection = {k: v for k, v in raw_selection.items() if k in allowed_keys} + + self._tool_registry.load_registry_servers(**selection) + except Exception as e: + logger.warning("Failed to load mcp_registry.json: %s", e) + @staticmethod def _import_agent_module(agent_path: Path): """Import an agent package from its directory path. 
def load_registry_servers(
    self,
    *,
    include: list[str] | None = None,
    tags: list[str] | None = None,
    exclude: list[str] | None = None,
    profile: str | None = None,
    max_tools: int | None = None,
    versions: dict[str, str] | None = None,
) -> int:
    """
    Resolve the `mcp_registry.json` selection and register the resulting MCP servers.

    Behavior:
    - servers are registered in the order returned by the resolver
    - tools that already exist are kept (first-wins); collisions are logged
    - `max_tools` limits how many tools this call may *newly* register in total

    The selection is remembered on the instance so a later resync can re-apply it.

    Returns:
        The number of tools newly registered by this call.
    """
    from framework.runner.mcp_registry_resolver import resolve_registry_servers

    selection: dict[str, Any] = dict(
        include=include,
        tags=tags,
        exclude=exclude,
        profile=profile,
        max_tools=max_tools,
        versions=versions,
    )
    self._mcp_registry_selection = selection  # for resync

    # The resolver always receives a dict for `versions`, never None.
    servers = resolve_registry_servers(**{**selection, "versions": versions or {}})
    if not servers:
        logger.warning("MCP registry selection resolved to 0 servers; nothing to load")
        return 0

    root = Path(__file__).resolve().parents[3]
    total_added = 0

    for cfg in servers:
        if not isinstance(cfg, dict):
            continue

        # Remaining budget for this server; None means "no cap".
        budget = None if max_tools is None else max_tools - total_added
        if budget is not None and budget <= 0:
            break

        # Normalize stdio config so scripts/cwd behave like mcp_servers.json loading.
        normalized = self._resolve_mcp_server_config(cfg, root)
        total_added += self.register_mcp_server(
            normalized,
            preserve_existing_tools=True,
            tool_cap=budget,
            log_collisions=True,
        )

    return total_added
def _find_mcp_origin_server_for_tool(self, tool_name: str) -> str | None:
    """Return the first server in `_mcp_server_tools` that lists *tool_name*, else None."""
    return next(
        (
            server
            for server, owned_tools in self._mcp_server_tools.items()
            if tool_name in owned_tools
        ),
        None,
    )
+ self.load_registry_servers(**self._mcp_registry_selection) logger.info("MCP server resync complete") return True diff --git a/core/tests/test_mcp_registry_loader.py b/core/tests/test_mcp_registry_loader.py new file mode 100644 index 00000000..e9ff66df --- /dev/null +++ b/core/tests/test_mcp_registry_loader.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from typing import Any + +from framework.runner.mcp_client import MCPTool +from framework.runner.tool_registry import ToolRegistry + + +def _make_tool(name: str, server_name: str) -> MCPTool: + return MCPTool( + name=name, + description=f"{name} from {server_name}", + input_schema={"type": "object", "properties": {}, "required": []}, + server_name=server_name, + ) + + +def test_registry_first_wins_collisions(monkeypatch): + """ + When multiple registry servers expose the same tool name, the first server + in load order should win and later servers should not overwrite it. + """ + + tool_map: dict[str, list[str]] = { + "s1": ["tool_common", "tool_hive"], + "s2": ["tool_common", "tool_coder"], + } + + from framework.runner import mcp_client as mcp_client_mod + + class FakeMCPClient: + def __init__(self, config: Any): + self.config = config + self._tools: list[MCPTool] = [] + + def connect(self) -> None: + return + + def disconnect(self) -> None: + return + + def list_tools(self) -> list[MCPTool]: + names = tool_map.get(self.config.name, []) + return [_make_tool(n, self.config.name) for n in names] + + def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: + raise NotImplementedError + + monkeypatch.setattr(mcp_client_mod, "MCPClient", FakeMCPClient) + + from framework.runner import mcp_registry_resolver as resolver_mod + + # Return server configs in the desired deterministic order. 
+ resolved_servers = [ + {"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None}, + {"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None}, + ] + monkeypatch.setattr( + resolver_mod, + "resolve_registry_servers", + lambda **kwargs: resolved_servers, + ) + + registry = ToolRegistry() + added = registry.load_registry_servers(include=["s1"], tags=None, exclude=None, profile=None) + + assert added == 3 # tool_common + tool_hive + tool_coder + assert registry.has_tool("tool_common") is True + assert registry.has_tool("tool_hive") is True + assert registry.has_tool("tool_coder") is True + + assert registry.get_server_tool_names("s1") == {"tool_common", "tool_hive"} + assert registry.get_server_tool_names("s2") == {"tool_coder"} + + +def test_registry_precedence_over_existing_mcp_servers(monkeypatch): + """Registry-loaded tools should not overwrite already registered MCP tools.""" + + tool_map: dict[str, list[str]] = { + "pre": ["tool_common", "tool_pre"], + "s1": ["tool_common", "tool_hive"], + "s2": ["tool_common", "tool_coder"], + } + + from framework.runner import mcp_client as mcp_client_mod + + class FakeMCPClient: + def __init__(self, config: Any): + self.config = config + + def connect(self) -> None: + return + + def disconnect(self) -> None: + return + + def list_tools(self) -> list[MCPTool]: + names = tool_map.get(self.config.name, []) + return [_make_tool(n, self.config.name) for n in names] + + def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: + raise NotImplementedError + + monkeypatch.setattr(mcp_client_mod, "MCPClient", FakeMCPClient) + + from framework.runner import mcp_registry_resolver as resolver_mod + + resolved_servers = [ + {"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None}, + {"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None}, + ] + monkeypatch.setattr( + resolver_mod, + "resolve_registry_servers", + lambda **kwargs: 
resolved_servers, + ) + + registry = ToolRegistry() + registry.register_mcp_server( + {"name": "pre", "transport": "stdio", "command": "fake", "args": [], "cwd": None} + ) + + added = registry.load_registry_servers(include=None, tags=None, exclude=None, profile=None) + + assert added == 2 # only tool_hive + tool_coder; tool_common is preserved from "pre" + assert registry.get_server_tool_names("pre") == {"tool_common", "tool_pre"} + assert registry.get_server_tool_names("s1") == {"tool_hive"} + assert registry.get_server_tool_names("s2") == {"tool_coder"} + + +def test_registry_max_tools_cap(monkeypatch): + """max_tools caps the total number of newly added tools from registry servers.""" + + tool_map: dict[str, list[str]] = { + "s1": ["tool_a", "tool_b"], + "s2": ["tool_c"], + } + + from framework.runner import mcp_client as mcp_client_mod + + class FakeMCPClient: + def __init__(self, config: Any): + self.config = config + + def connect(self) -> None: + return + + def disconnect(self) -> None: + return + + def list_tools(self) -> list[MCPTool]: + names = tool_map.get(self.config.name, []) + return [_make_tool(n, self.config.name) for n in names] + + def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: + raise NotImplementedError + + monkeypatch.setattr(mcp_client_mod, "MCPClient", FakeMCPClient) + + from framework.runner import mcp_registry_resolver as resolver_mod + + resolved_servers = [ + {"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None}, + {"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None}, + ] + monkeypatch.setattr( + resolver_mod, + "resolve_registry_servers", + lambda **kwargs: resolved_servers, + ) + + registry = ToolRegistry() + added = registry.load_registry_servers( + include=None, + tags=None, + exclude=None, + profile=None, + max_tools=2, + ) + + assert added == 2 + assert registry.has_tool("tool_a") is True + assert registry.has_tool("tool_b") is True + assert 
registry.has_tool("tool_c") is False diff --git a/examples/templates/vulnerability_assessment/mcp_registry.json b/examples/templates/vulnerability_assessment/mcp_registry.json new file mode 100644 index 00000000..aa558fc0 --- /dev/null +++ b/examples/templates/vulnerability_assessment/mcp_registry.json @@ -0,0 +1,4 @@ +{ + "include": ["jira"], + "max_tools": 10 +} \ No newline at end of file