"""Tool discovery and registration for agent runner.""" import asyncio import contextvars import importlib.util import inspect import json import logging import os import re from collections.abc import Callable from dataclasses import dataclass from pathlib import Path from typing import Any from framework.llm.provider import Tool, ToolResult, ToolUse logger = logging.getLogger(__name__) _INPUT_LOG_MAX_LEN = 500 # Tools whose names match this pattern are assumed to return ImageContent. # Matched against the bare tool name (case-insensitive). Used to mark MCP # tools with produces_image=True so they can be filtered out for text-only # models before the schema is ever shown to the LLM (avoids wasted calls # and "screenshot failed" entries polluting memory). _IMAGE_TOOL_NAME_RE = re.compile( r"(screenshot|screen_capture|capture_image|render_image|get_image|snapshot_image)", re.IGNORECASE, ) # Per-execution context overrides. Each asyncio task (and thus each # concurrent graph execution) gets its own copy, so there are no races # when multiple ExecutionStreams run in parallel. _execution_context: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar( "_execution_context", default=None ) @dataclass class RegisteredTool: """A tool with its executor function.""" tool: Tool executor: Callable[[dict], Any] class ToolRegistry: """ Manages tool discovery and registration. Tool Discovery Order: 1. Built-in tools (if any) 2. tools.py in agent folder 3. MCP servers 4. Manually registered tools """ # Framework-internal context keys injected into tool calls. # Stripped from LLM-facing schemas (the LLM doesn't know these values) # and auto-injected at call time for tools that accept them. CONTEXT_PARAMS = frozenset({"agent_id", "data_dir", "profile"}) # Tools that perform no filesystem/process/network writes and are safe # to run concurrently with other safe tools in the same assistant turn. # Unknown tools default to unsafe (serialized) - adding a name here is # an explicit promise about that tool's side effects. Keep this list # conservative: anything that mutates state, writes to disk, issues # POST/PUT/DELETE requests, or drives a browser MUST NOT be listed. CONCURRENCY_SAFE_TOOLS = frozenset( { # File system reads "read_file", "list_directory", "grep", "glob", # Web reads "web_search", "web_fetch", # Browser read-only snapshots (mutate-free observations) "browser_screenshot", "browser_snapshot", "browser_console", "browser_get_text", # Background bash polling - reads output buffers only, does # not touch the subprocess itself. "bash_output", } ) # Credential directory used for change detection _CREDENTIAL_DIR = Path("~/.hive/credentials/credentials").expanduser() def __init__(self): self._tools: dict[str, RegisteredTool] = {} self._mcp_clients: list[Any] = [] # List of MCPClient instances self._mcp_client_servers: dict[int, str] = {} # client id -> server name self._mcp_managed_clients: set[int] = set() # client ids acquired from the manager self._session_context: dict[str, Any] = {} # Auto-injected context for tools self._provider_index: dict[str, set[str]] = {} # provider -> tool names # MCP resync tracking self._mcp_config_path: Path | None = None # Path used for initial load self._mcp_tool_names: set[str] = set() # Tool names registered from MCP self._mcp_cred_snapshot: set[str] = set() # Credential filenames at MCP load time self._mcp_aden_key_snapshot: str | None = None # ADEN_API_KEY value at MCP load time self._mcp_server_tools: dict[str, set[str]] = {} # server name -> tool names # tool name -> owning MCPClient (for force-kill on timeout) self._mcp_tool_clients: dict[str, Any] = {} # Per-agent env injected into every MCP server config.env. Kept # here (not on the process-wide os.environ) so parallel workers # in the same interpreter don't clobber each other's identity. self._mcp_extra_env: dict[str, str] = {} # Agent dir for re-loading registry MCP after credential resync. self._mcp_registry_agent_path: Path | None = None def set_mcp_extra_env(self, env: dict[str, str]) -> None: """Attach per-agent env vars to every MCPServerConfig this registry builds. Use this instead of mutating ``os.environ`` — the global env dict is shared across all workers in a single interpreter, so writes from one worker race with MCP spawns from another. """ self._mcp_extra_env = dict(env) def register( self, name: str, tool: Tool, executor: Callable[[dict], Any], ) -> None: """ Register a single tool with its executor. Args: name: Tool name (must match tool.name) tool: Tool definition executor: Function that takes tool input dict and returns result """ self._tools[name] = RegisteredTool(tool=tool, executor=executor) def register_function( self, func: Callable, name: str | None = None, description: str | None = None, ) -> None: """ Register a function as a tool, auto-generating the Tool definition. Args: func: Function to register name: Tool name (defaults to function name) description: Tool description (defaults to docstring) """ tool_name = name or func.__name__ tool_desc = description or func.__doc__ or f"Execute {tool_name}" # Generate parameters from function signature sig = inspect.signature(func) properties = {} required = [] for param_name, param in sig.parameters.items(): if param_name in ("self", "cls"): continue param_type = "string" # Default if param.annotation != inspect.Parameter.empty: if param.annotation is int: param_type = "integer" elif param.annotation is float: param_type = "number" elif param.annotation is bool: param_type = "boolean" elif param.annotation is dict: param_type = "object" elif param.annotation is list: param_type = "array" properties[param_name] = {"type": param_type} if param.default == inspect.Parameter.empty: required.append(param_name) tool = Tool( name=tool_name, description=tool_desc, parameters={ "type": "object", "properties": properties, "required": required, }, concurrency_safe=tool_name in self.CONCURRENCY_SAFE_TOOLS, ) def executor(inputs: dict) -> Any: return func(**inputs) self.register(tool_name, tool, executor) def discover_from_module(self, module_path: Path) -> int: """ Load tools from a Python module file. Looks for: - TOOLS: dict[str, Tool] - tool definitions - tool_executor(tool_use: ToolUse) -> ToolResult - unified executor - Functions decorated with @tool Args: module_path: Path to tools.py file Returns: Number of tools discovered """ if not module_path.exists(): return 0 # Load the module dynamically spec = importlib.util.spec_from_file_location("agent_tools", module_path) if spec is None or spec.loader is None: return 0 module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) count = 0 # Check for TOOLS dict if hasattr(module, "TOOLS"): tools_dict = module.TOOLS executor_func = getattr(module, "tool_executor", None) for name, tool in tools_dict.items(): if executor_func: # Use unified executor def make_executor(tool_name: str): def executor(inputs: dict) -> Any: tool_use = ToolUse( id=f"call_{tool_name}", name=tool_name, input=inputs, ) result = executor_func(tool_use) if isinstance(result, ToolResult): # ToolResult.content is expected to be JSON, but tools may # sometimes return invalid JSON. Guard against crashes here # and surface a structured error instead. if not result.content: return {} try: return json.loads(result.content) except json.JSONDecodeError as e: logger.warning( "Tool '%s' returned invalid JSON: %s", tool_name, str(e), ) return { "error": (f"Invalid JSON response from tool '{tool_name}': {str(e)}"), "raw_content": result.content, } return result return executor self.register(name, tool, make_executor(name)) else: # Register tool without executor (will use mock) self.register(name, tool, lambda inputs: {"mock": True, "inputs": inputs}) count += 1 # Check for @tool decorated functions for name in dir(module): obj = getattr(module, name) if callable(obj) and hasattr(obj, "_tool_metadata"): metadata = obj._tool_metadata self.register_function( obj, name=metadata.get("name", name), description=metadata.get("description"), ) count += 1 return count def get_tools(self) -> dict[str, Tool]: """Get all registered Tool objects.""" return {name: rt.tool for name, rt in self._tools.items()} def get_executor(self) -> Callable[[ToolUse], ToolResult]: """ Get unified tool executor function. Returns a function that dispatches to the appropriate tool executor. Handles both sync and async tool implementations — async results are wrapped so that ``EventLoopNode._execute_tool`` can await them. """ def _wrap_result(tool_use_id: str, result: Any) -> ToolResult: if isinstance(result, ToolResult): return result # MCP client returns dict with _images when image content is present if isinstance(result, dict) and "_images" in result: return ToolResult( tool_use_id=tool_use_id, content=result.get("_text", ""), image_content=result["_images"], ) return ToolResult( tool_use_id=tool_use_id, content=json.dumps(result) if not isinstance(result, str) else result, is_error=False, ) registry_ref = self def executor(tool_use: ToolUse) -> ToolResult: # Check if credential files changed (lightweight dir listing). # If new OAuth tokens appeared, restarts MCP servers to pick them up. registry_ref.resync_mcp_servers_if_needed() if tool_use.name not in registry_ref._tools: return ToolResult( tool_use_id=tool_use.id, content=json.dumps({"error": f"Unknown tool: {tool_use.name}"}), is_error=True, ) registered = registry_ref._tools[tool_use.name] try: result = registered.executor(tool_use.input) # Async tool: wrap the awaitable so the caller can await it if asyncio.iscoroutine(result) or asyncio.isfuture(result): async def _await_and_wrap(): try: r = await result return _wrap_result(tool_use.id, r) except Exception as exc: inputs_str = json.dumps(tool_use.input, default=str) if len(inputs_str) > _INPUT_LOG_MAX_LEN: inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)" logger.error( "Async tool '%s' failed (tool_use_id=%s): %s\nInputs: %s", tool_use.name, tool_use.id, exc, inputs_str, exc_info=True, ) return ToolResult( tool_use_id=tool_use.id, content=json.dumps({"error": str(exc)}), is_error=True, ) return _await_and_wrap() return _wrap_result(tool_use.id, result) except Exception as e: inputs_str = json.dumps(tool_use.input, default=str) if len(inputs_str) > _INPUT_LOG_MAX_LEN: inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)" logger.error( "Tool '%s' execution failed for tool_use_id=%s: %s\nInputs: %s", tool_use.name, tool_use.id, e, inputs_str, exc_info=True, ) return ToolResult( tool_use_id=tool_use.id, content=json.dumps({"error": str(e)}), is_error=True, ) # Expose force-kill hook so the timeout handler can tear down a # hung MCP subprocess (asyncio.wait_for alone cannot). executor.kill_for_tool = registry_ref.kill_mcp_for_tool # type: ignore[attr-defined] return executor def get_registered_names(self) -> list[str]: """Get list of registered tool names.""" return list(self._tools.keys()) def has_tool(self, name: str) -> bool: """Check if a tool is registered.""" return name in self._tools def get_server_tool_names(self, server_name: str) -> set[str]: """Return tool names registered from a specific MCP server.""" return set(self._mcp_server_tools.get(server_name, set())) def set_session_context(self, **context) -> None: """ Set session context to auto-inject into tool calls. Args: **context: Key-value pairs to inject (e.g., workspace_id, agent_id, session_id) """ self._session_context.update(context) @staticmethod def set_execution_context(**context) -> contextvars.Token: """Set per-execution context overrides (concurrency-safe via contextvars). Values set here take precedence over session context. Each asyncio task gets its own copy, so concurrent executions don't interfere. Returns a token that must be passed to :meth:`reset_execution_context` to restore the previous state. """ current = _execution_context.get() or {} return _execution_context.set({**current, **context}) @staticmethod def reset_execution_context(token: contextvars.Token) -> None: """Restore execution context to its previous state.""" _execution_context.reset(token) @staticmethod def resolve_mcp_stdio_config(server_config: dict[str, Any], base_dir: Path) -> dict[str, Any]: """Resolve cwd and script paths for MCP stdio config (Windows compatibility). Use this when building MCPServerConfig from a config file (e.g. in list_agent_tools, discover_mcp_tools) so hive_tools and other servers work on Windows. Call with base_dir = directory containing the config. """ registry = ToolRegistry() return registry._resolve_mcp_server_config(server_config, base_dir) def _resolve_mcp_server_config(self, server_config: dict[str, Any], base_dir: Path) -> dict[str, Any]: """Resolve cwd and script paths for MCP stdio servers (Windows compatibility). On Windows, passing cwd to subprocess can cause WinError 267. We use cwd=None and absolute script paths when the server runs a .py script from the tools dir. If the resolved cwd doesn't exist (e.g. config from ~/.hive/agents/), fall back to Path.cwd() / "tools". """ config = dict(server_config) if config.get("transport") != "stdio": return config cwd = config.get("cwd") args = list(config.get("args", [])) if not cwd and not args: return config # Resolve cwd relative to base_dir resolved_cwd: Path | None = None if cwd: if Path(cwd).is_absolute(): resolved_cwd = Path(cwd) else: resolved_cwd = (base_dir / cwd).resolve() # Find .py script in args (e.g. coder_tools_server.py, files_server.py) script_name = None for i, arg in enumerate(args): if isinstance(arg, str) and arg.endswith(".py"): script_name = arg script_idx = i break if resolved_cwd is None: return config # If resolved cwd doesn't exist or (when we have a script) doesn't contain it, # try fallback tools_fallback = Path.cwd() / "tools" need_fallback = not resolved_cwd.is_dir() if script_name and not need_fallback: need_fallback = not (resolved_cwd / script_name).exists() if need_fallback: fallback_ok = tools_fallback.is_dir() if script_name: fallback_ok = fallback_ok and (tools_fallback / script_name).exists() else: # No script (e.g. GCU); just need tools dir to exist pass if fallback_ok: resolved_cwd = tools_fallback logger.debug( "MCP server '%s': using fallback tools dir %s", config.get("name", "?"), resolved_cwd, ) else: config["cwd"] = str(resolved_cwd) return config if not script_name: # No .py script (e.g. GCU uses -m gcu.server); just set cwd config["cwd"] = str(resolved_cwd) return config # For coder_tools_server, inject --project-root so writes go to the expected workspace if script_name and "coder_tools" in script_name: project_root = str(resolved_cwd.parent.resolve()) args = list(args) if "--project-root" not in args: args.extend(["--project-root", project_root]) config["args"] = args if os.name == "nt": # Windows: cwd=None avoids WinError 267; use absolute script path config["cwd"] = None abs_script = str((resolved_cwd / script_name).resolve()) args = list(config["args"]) args[script_idx] = abs_script config["args"] = args else: config["cwd"] = str(resolved_cwd) return config def load_mcp_config(self, config_path: Path) -> None: """ Load and register MCP servers from a config file. Resolves relative ``cwd`` paths against the config file's parent directory so callers never need to handle path resolution themselves. Args: config_path: Path to an ``mcp_servers.json`` file. """ # Remember config path for potential resync later self._mcp_config_path = Path(config_path) try: with open(config_path, encoding="utf-8") as f: config = json.load(f) except Exception as e: logger.warning(f"Failed to load MCP config from {config_path}: {e}") return base_dir = config_path.parent # Support both formats: # {"servers": [{"name": "x", ...}]} (list format) # {"server-name": {"transport": ...}, ...} (dict format) server_list = config.get("servers", []) if not server_list and "servers" not in config: # Treat top-level keys as server names server_list = [{"name": name, **cfg} for name, cfg in config.items()] resolved_server_list = [ self._resolve_mcp_server_config(server_config, base_dir) for server_config in server_list ] # Ordered first-wins for duplicate tool names across servers; keep tools.py tools. self.load_registry_servers( resolved_server_list, log_summary=False, preserve_existing_tools=True, log_collisions=False, ) # Snapshot credential files and ADEN_API_KEY so we can detect mid-session changes self._mcp_cred_snapshot = self._snapshot_credentials() self._mcp_aden_key_snapshot = os.environ.get("ADEN_API_KEY") self._log_registry_snapshot("after load_mcp_config") def _register_mcp_server_with_retry( self, server_config: dict[str, Any], *, preserve_existing_tools: bool = True, tool_cap: int | None = None, log_collisions: bool = False, ) -> tuple[bool, int, str | None]: """Register a single MCP server with one retry for transient failures.""" name = server_config.get("name", "unknown") last_error: str | None = None for attempt in range(2): try: count = self.register_mcp_server( server_config, preserve_existing_tools=preserve_existing_tools, tool_cap=tool_cap, log_collisions=log_collisions, ) if count > 0: return True, count, None last_error = "registered 0 tools" except Exception as exc: last_error = str(exc) if attempt == 0: logger.warning( "MCP server '%s' failed to register, retrying in 2s: %s", name, last_error, ) import time time.sleep(2) else: logger.warning("MCP server '%s' failed after retry: %s", name, last_error) return False, 0, last_error def load_registry_servers( self, server_list: list[dict[str, Any]], *, log_summary: bool = True, preserve_existing_tools: bool = True, max_tools: int | None = None, log_collisions: bool = False, ) -> list[dict[str, Any]]: """Register MCP servers from a resolved config list (registry and/or static). ``preserve_existing_tools`` enforces first-wins tool names (FR-100): later servers skip names already taken— including tools from ``mcp_servers.json`` or ``tools.py`` when those were loaded first. ``max_tools`` caps how many *new* tool names are registered across this batch (collisions do not consume the cap). When ``log_collisions`` is True, skipped duplicate names emit a warning (FR-101). """ results: list[dict[str, Any]] = [] tools_added_batch = 0 for server_config in server_list: remaining: int | None = None if max_tools is not None: remaining = max_tools - tools_added_batch if remaining <= 0: break name = server_config.get("name", "unknown") success, tools_loaded, error = self._register_mcp_server_with_retry( server_config, preserve_existing_tools=preserve_existing_tools, tool_cap=remaining, log_collisions=log_collisions, ) tools_added_batch += tools_loaded result = { "server": name, "status": "loaded" if success else "skipped", "tools_loaded": tools_loaded, "skipped_reason": None if success else (error or "unknown error"), } results.append(result) if log_summary: logger.info( "MCP registry server resolution", extra={ "event": "mcp_registry_server_resolution", "server": result["server"], "status": result["status"], "tools_loaded": result["tools_loaded"], "skipped_reason": result["skipped_reason"], }, ) return results def register_mcp_server( self, server_config: dict[str, Any], use_connection_manager: bool = True, *, preserve_existing_tools: bool = True, tool_cap: int | None = None, log_collisions: bool = False, ) -> int: """ Register an MCP server and discover its tools. Args: server_config: MCP server configuration dict with keys: - name: Server name (required) - transport: "stdio" or "http" (required) - command: Command to run (for stdio) - args: Command arguments (for stdio) - env: Environment variables (for stdio) - cwd: Working directory (for stdio) - url: Server URL (for http) - headers: HTTP headers (for http) - description: Server description (optional) use_connection_manager: When True, reuse a shared client keyed by server name preserve_existing_tools: If True, do not replace tools already in the registry. tool_cap: Max tools to newly register from this server (None = unlimited). log_collisions: If True, log when this server skips a tool name already taken. Returns: Number of tools registered from this server """ try: from framework.loader.mcp_client import MCPClient, MCPServerConfig from framework.loader.mcp_connection_manager import MCPConnectionManager # Build config object. Merge per-agent env on top of the # server's own env so MCP subprocesses receive the identity # of the worker that spawned them (instead of whichever # worker most recently wrote to os.environ). merged_env = {**self._mcp_extra_env, **(server_config.get("env") or {})} config = MCPServerConfig( name=server_config["name"], transport=server_config["transport"], command=server_config.get("command"), args=server_config.get("args", []), env=merged_env, cwd=server_config.get("cwd"), url=server_config.get("url"), headers=server_config.get("headers", {}), socket_path=server_config.get("socket_path"), description=server_config.get("description", ""), ) # Create and connect client if use_connection_manager: client = MCPConnectionManager.get_instance().acquire(config) else: client = MCPClient(config) client.connect() # Store client for cleanup self._mcp_clients.append(client) client_id = id(client) self._mcp_client_servers[client_id] = config.name if use_connection_manager: self._mcp_managed_clients.add(client_id) # Register each tool server_name = server_config["name"] if server_name not in self._mcp_server_tools: self._mcp_server_tools[server_name] = set() # Build admission gate: only admit MCP tools that are either # (a) credential-backed *and* have a configured account, or # (b) credential-less *and* listed in the verified manifest. # Servers that don't expose `__aden_verified_manifest` (third-party # MCP servers) bypass the gate entirely — preserves prior behavior. admit = self._build_mcp_admission_gate(client) count = 0 admitted_names: list[str] = [] for mcp_tool in client.list_tools(): if not admit(mcp_tool.name): continue if tool_cap is not None and count >= tool_cap: break if preserve_existing_tools and mcp_tool.name in self._tools: if log_collisions: origin_server = self._find_mcp_origin_server_for_tool(mcp_tool.name) or "" logger.warning( "MCP tool '%s' from '%s' shadowed by '%s' (loaded first)", mcp_tool.name, server_name, origin_server, ) # Skip registration; do not update MCP tool bookkeeping for this server. continue # Convert MCP tool to framework Tool (strips context params from LLM schema) tool = self._convert_mcp_tool_to_framework_tool(mcp_tool) # Create executor that calls the MCP server def make_mcp_executor( client_ref: MCPClient, tool_name: str, registry_ref, tool_params: set[str], ): def executor(inputs: dict) -> Any: try: # Build base context: session < execution (execution wins) base_context = dict(registry_ref._session_context) exec_ctx = _execution_context.get() if exec_ctx: base_context.update(exec_ctx) # Only inject context params the tool accepts filtered_context = {k: v for k, v in base_context.items() if k in tool_params} # Strip context params from LLM inputs — the framework # values are authoritative (prevents the LLM from passing # e.g. data_dir="/data" and overriding the real path). clean_inputs = {k: v for k, v in inputs.items() if k not in registry_ref.CONTEXT_PARAMS} merged_inputs = {**clean_inputs, **filtered_context} result = client_ref.call_tool(tool_name, merged_inputs) # MCP client already extracts content (returns str # or {"_text": ..., "_images": ...} for image results). # Handle legacy list format from HTTP transport. if isinstance(result, list) and len(result) > 0: if isinstance(result[0], dict) and "text" in result[0]: return result[0]["text"] return result[0] return result except Exception as e: inputs_str = json.dumps(inputs, default=str) if len(inputs_str) > _INPUT_LOG_MAX_LEN: inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)" logger.error( "MCP tool '%s' execution failed: %s\nInputs: %s", tool_name, e, inputs_str, exc_info=True, ) return {"error": str(e)} return executor tool_params = set(mcp_tool.input_schema.get("properties", {}).keys()) self.register( mcp_tool.name, tool, make_mcp_executor(client, mcp_tool.name, self, tool_params), ) self._mcp_tool_names.add(mcp_tool.name) self._mcp_tool_clients[mcp_tool.name] = client self._mcp_server_tools[server_name].add(mcp_tool.name) admitted_names.append(mcp_tool.name) count += 1 logger.info( "MCP Registry Load", extra={ "server": config.name, "status": "success", "tools_loaded": count, "skipped_reason": None, }, ) logger.info( "MCP server '%s' admitted %d tool(s): %s", config.name, len(admitted_names), sorted(admitted_names), ) return count except Exception as e: logger.error( "MCP Registry Load", extra={ "server": server_config.get("name", "unknown"), "status": "failed", "tools_loaded": 0, "skipped_reason": str(e), }, ) if "Connection closed" in str(e) and os.name == "nt": logger.debug( "On Windows, check that the MCP subprocess starts (e.g. uv in PATH, " "script path correct). Worker config uses base_dir = mcp_servers.json parent." ) return 0 def _find_mcp_origin_server_for_tool(self, tool_name: str) -> str | None: for server_name, tool_names in self._mcp_server_tools.items(): if tool_name in tool_names: return server_name return None def _log_registry_snapshot(self, context: str) -> None: """Emit a one-line summary of the current tool registry. Called after every tool-list mutation (initial load + resync) so that operators can correlate "what tools does the queen have right now" with credential changes and MCP server lifecycle events. Per-server contents are already logged by `register_mcp_server`; this is just the rollup so the resync path also gets a single anchor line. """ per_server_counts = {server: len(names) for server, names in self._mcp_server_tools.items()} non_mcp_count = len(self._tools) - len(self._mcp_tool_names) logger.info( "ToolRegistry snapshot (%s): total=%d, mcp=%d, non_mcp=%d, per_server=%s", context, len(self._tools), len(self._mcp_tool_names), non_mcp_count, per_server_counts, ) _MCP_VERIFIED_MANIFEST_TOOL = "__aden_verified_manifest" def _build_mcp_admission_gate(self, client: Any) -> Callable[[str], bool]: """Build a per-server predicate that filters MCP tools at registration. Rules: * The sentinel manifest tool itself is never admitted. * Credential-backed tools (provider in `tool_provider_map`) are admitted only when at least one account exists for that provider. * Credential-less tools are admitted only when they appear in the server's verified manifest. * Servers that don't expose a manifest bypass the verified gate entirely (third-party MCP servers behave as before). """ verified_names: set[str] = set() manifest_present = False # Only probe the sentinel when the server actually advertises it. # Calling ``__aden_verified_manifest`` unconditionally on every # MCP server at registration time (a) causes a bogus tool call # round-trip to every third-party server, (b) pollutes any # call-capturing fakes in tests, and (c) risks side effects on # servers that eagerly execute unknown tool names. Listing is # cheap and cached by the client; this keeps the manifest gate # active for aden-flavoured servers without penalising others. sentinel_advertised = False try: for t in client.list_tools(): if getattr(t, "name", None) == self._MCP_VERIFIED_MANIFEST_TOOL: sentinel_advertised = True break except Exception: sentinel_advertised = False if sentinel_advertised: try: raw = client.call_tool(self._MCP_VERIFIED_MANIFEST_TOOL, {}) parsed: Any = raw if isinstance(raw, str): try: parsed = json.loads(raw) except json.JSONDecodeError: parsed = None # Only treat the response as a manifest when it's a list # of strings. A malformed response shouldn't flip the gate # on and silently hide every real tool from the server. if isinstance(parsed, list) and all(isinstance(n, str) for n in parsed): verified_names = set(parsed) manifest_present = True except Exception: # Server advertised the sentinel but errored when called # — treat as no manifest; fall back to third-party bypass. pass tool_provider_map: dict[str, str] = {} live_providers: set[str] = set() try: from aden_tools.credentials.store_adapter import CredentialStoreAdapter adapter = CredentialStoreAdapter.default() tool_provider_map = adapter.get_tool_provider_map() live_providers = {a.get("provider", "") for a in adapter.get_all_account_info() if a.get("provider")} except Exception: logger.debug("Credential snapshot unavailable for MCP gate", exc_info=True) def admit(tool_name: str) -> bool: if tool_name == self._MCP_VERIFIED_MANIFEST_TOOL: return False provider = tool_provider_map.get(tool_name) if provider: # Credentialed tool — needs an account. return provider in live_providers if not manifest_present: # Third-party MCP server: preserve legacy "admit everything". return True return tool_name in verified_names return admit def _convert_mcp_tool_to_framework_tool(self, mcp_tool: Any) -> Tool: """ Convert an MCP tool to a framework Tool. Args: mcp_tool: MCPTool object Returns: Framework Tool object """ # Extract parameters from MCP input schema input_schema = mcp_tool.input_schema properties = input_schema.get("properties", {}) required = input_schema.get("required", []) # Strip framework-internal context params from LLM-facing schema. # The LLM can't know these values; they're auto-injected at call time. properties = {k: v for k, v in properties.items() if k not in self.CONTEXT_PARAMS} required = [r for r in required if r not in self.CONTEXT_PARAMS] # Convert to framework Tool format tool = Tool( name=mcp_tool.name, description=mcp_tool.description, parameters={ "type": "object", "properties": properties, "required": required, }, produces_image=bool(_IMAGE_TOOL_NAME_RE.search(mcp_tool.name or "")), concurrency_safe=mcp_tool.name in self.CONCURRENCY_SAFE_TOOLS, ) return tool # ------------------------------------------------------------------ # Provider-based tool filtering # ------------------------------------------------------------------ def build_provider_index(self) -> None: """Build provider -> tool-name mapping from CREDENTIAL_SPECS. Populates ``_provider_index`` so :meth:`get_by_provider` works. Safe to call even if ``aden_tools`` is not installed (silently no-ops). """ try: from aden_tools.credentials import CREDENTIAL_SPECS except ImportError: logger.debug("aden_tools not available, skipping provider index") return self._provider_index.clear() for spec in CREDENTIAL_SPECS.values(): provider = spec.aden_provider_name if provider: if provider not in self._provider_index: self._provider_index[provider] = set() self._provider_index[provider].update(spec.tools) def get_by_provider(self, provider: str) -> dict[str, Tool]: """Return registered tools that belong to *provider*. Lazily builds the provider index on first call. """ if not self._provider_index: self.build_provider_index() tool_names = self._provider_index.get(provider, set()) return {name: rt.tool for name, rt in self._tools.items() if name in tool_names} def get_tool_names_by_provider(self, provider: str) -> list[str]: """Return sorted registered tool names for *provider*.""" if not self._provider_index: self.build_provider_index() tool_names = self._provider_index.get(provider, set()) return sorted(name for name in self._tools if name in tool_names) def get_all_provider_tool_names(self) -> list[str]: """Return sorted names of all registered tools that belong to any provider.""" if not self._provider_index: self.build_provider_index() all_names: set[str] = set() for names in self._provider_index.values(): all_names.update(names) return sorted(name for name in self._tools if name in all_names) # ------------------------------------------------------------------ # MCP credential resync # ------------------------------------------------------------------ def set_mcp_registry_agent_path(self, agent_path: Path | None) -> None: """Remember agent dir so registry MCP servers reload after credential resync.""" self._mcp_registry_agent_path = None if agent_path is None else Path(agent_path) def reload_registry_mcp_servers_after_resync(self) -> None: """Re-run ``mcp_registry.json`` resolution and register servers (post-resync).""" if self._mcp_registry_agent_path is None: return from framework.loader.mcp_registry import MCPRegistry try: reg = MCPRegistry() reg.initialize() configs, selection_max_tools = reg.load_agent_selection(self._mcp_registry_agent_path) except Exception as exc: logger.warning( "Failed to reload MCP registry servers after resync for '%s': %s", self._mcp_registry_agent_path.name, exc, ) return if not configs: return self.load_registry_servers( configs, log_summary=True, preserve_existing_tools=True, log_collisions=True, max_tools=selection_max_tools, ) def _snapshot_credentials(self) -> set[str]: """Return the set of credential filenames currently on disk.""" try: return set(self._CREDENTIAL_DIR.iterdir()) if self._CREDENTIAL_DIR.is_dir() else set() except OSError: return set() def resync_mcp_servers_if_needed(self) -> bool: """Restart MCP servers if credential files changed since last load. Compares the current credential directory listing against the snapshot taken when MCP servers were first loaded. If new files appeared (e.g. user connected an OAuth account mid-session), disconnects all MCP clients and re-loads them so the new subprocess picks up the fresh credentials. Note: Individual credential TTL/refresh is handled by the MCP server process internally -- it resolves tokens from the credential store on every tool call, not at startup. This method only handles the case where entirely new credential files appear. Returns True if a resync was performed, False otherwise. """ if not self._mcp_clients or self._mcp_config_path is None: return False current = self._snapshot_credentials() current_aden_key = os.environ.get("ADEN_API_KEY") files_changed = current != self._mcp_cred_snapshot aden_key_changed = current_aden_key != self._mcp_aden_key_snapshot if not files_changed and not aden_key_changed: return False reason = ( "Credential files and ADEN_API_KEY changed" if files_changed and aden_key_changed else "ADEN_API_KEY changed" if aden_key_changed else "Credential files changed" ) logger.info("%s — resyncing MCP servers", reason) # 1. Disconnect existing MCP clients self._cleanup_mcp_clients("during resync") # 2. Remove MCP-registered tools for name in self._mcp_tool_names: self._tools.pop(name, None) self._mcp_tool_names.clear() self._mcp_server_tools.clear() # 3. Re-load MCP servers (spawns fresh subprocesses with new credentials) self.load_mcp_config(self._mcp_config_path) if self._mcp_registry_agent_path is not None: self.reload_registry_mcp_servers_after_resync() logger.info("MCP server resync complete") self._log_registry_snapshot("after resync_mcp_servers_if_needed") return True def cleanup(self) -> None: """Clean up all MCP client connections.""" self._cleanup_mcp_clients() def _cleanup_mcp_clients(self, context: str = "") -> None: """Disconnect or release all tracked MCP clients for this registry.""" if context: context = f" {context}" for client in self._mcp_clients: client_id = id(client) server_name = self._mcp_client_servers.get(client_id, client.config.name) try: if client_id in self._mcp_managed_clients: from framework.loader.mcp_connection_manager import MCPConnectionManager MCPConnectionManager.get_instance().release(server_name) else: client.disconnect() except Exception as e: logger.warning(f"Error disconnecting MCP client{context}: {e}") self._mcp_clients.clear() self._mcp_client_servers.clear() self._mcp_managed_clients.clear() self._mcp_tool_clients.clear() def kill_mcp_for_tool(self, tool_name: str) -> bool: """Force-disconnect the MCP client that owns *tool_name*. Called from the timeout handler in ``execute_tool`` when a tool call hangs. Plain ``asyncio.wait_for`` cancellation cannot stop a sync executor running inside a thread pool (and therefore cannot stop the MCP subprocess), so we reach through to the client here and tear it down. The next ``call_tool`` triggers an automatic reconnect. Returns True if a client was found and disconnect was attempted. """ client = self._mcp_tool_clients.get(tool_name) if client is None: return False try: logger.warning( "Force-disconnecting MCP client for hung tool '%s' on server '%s'", tool_name, getattr(client.config, "name", "?"), ) client.disconnect() except Exception as exc: logger.warning("Error force-disconnecting MCP client for '%s': %s", tool_name, exc) return True def __del__(self): """Destructor to ensure cleanup.""" self.cleanup() def tool( description: str | None = None, name: str | None = None, ) -> Callable: """ Decorator to mark a function as a tool. Usage: @tool(description="Fetch lead from GTM table") def gtm_fetch_lead(lead_id: str) -> dict: return {"lead_data": {...}} """ def decorator(func: Callable) -> Callable: func._tool_metadata = { "name": name or func.__name__, "description": description or func.__doc__, } return func return decorator