Files
hive/core/framework/loader/tool_registry.py
T
2026-04-10 22:12:13 -07:00

1179 lines
46 KiB
Python

"""Tool discovery and registration for agent runner."""
import asyncio
import contextvars
import importlib.util
import inspect
import json
import logging
import os
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from framework.llm.provider import Tool, ToolResult, ToolUse
logger = logging.getLogger(__name__)
# Max characters of serialized tool inputs echoed into error logs before truncation.
_INPUT_LOG_MAX_LEN = 500
# Per-execution context overrides. Each asyncio task (and thus each
# concurrent graph execution) gets its own copy, so there are no races
# when multiple ExecutionStreams run in parallel.
_execution_context: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar(
    "_execution_context", default=None
)
@dataclass
class RegisteredTool:
    """A tool with its executor function."""
    # LLM-facing tool definition (name, description, JSON-schema parameters).
    tool: Tool
    # Callable invoked with the tool's input dict. May return any value;
    # a coroutine return is supported (awaited by the unified executor).
    executor: Callable[[dict], Any]
class ToolRegistry:
    """
    Manages tool discovery and registration.

    Tool Discovery Order:
    1. Built-in tools (if any)
    2. tools.py in agent folder
    3. MCP servers
    4. Manually registered tools

    Duplicate tool names are first-wins: once a name is registered, later
    sources skip it (see ``load_registry_servers`` / ``register_mcp_server``).
    """

    # Framework-internal context keys injected into tool calls.
    # Stripped from LLM-facing schemas (the LLM doesn't know these values)
    # and auto-injected at call time for tools that accept them.
    CONTEXT_PARAMS = frozenset({"agent_id", "data_dir", "profile"})

    # Credential directory used for change detection.
    # NOTE(review): the doubled "credentials/credentials" segment looks
    # deliberate (store layout) — confirm against the credential store.
    _CREDENTIAL_DIR = Path("~/.hive/credentials/credentials").expanduser()
def __init__(self):
    """Create an empty registry: no tools, no MCP clients, no injected context."""
    self._tools: dict[str, RegisteredTool] = {}
    self._mcp_clients: list[Any] = []  # List of MCPClient instances
    self._mcp_client_servers: dict[int, str] = {}  # client id -> server name
    self._mcp_managed_clients: set[int] = set()  # client ids acquired from the manager
    self._session_context: dict[str, Any] = {}  # Auto-injected context for tools
    self._provider_index: dict[str, set[str]] = {}  # provider -> tool names
    # MCP resync tracking
    self._mcp_config_path: Path | None = None  # Path used for initial load
    self._mcp_tool_names: set[str] = set()  # Tool names registered from MCP
    self._mcp_cred_snapshot: set[str] = set()  # Credential filenames at MCP load time
    self._mcp_aden_key_snapshot: str | None = None  # ADEN_API_KEY value at MCP load time
    self._mcp_server_tools: dict[str, set[str]] = {}  # server name -> tool names
    # tool name -> owning MCPClient (for force-kill on timeout)
    self._mcp_tool_clients: dict[str, Any] = {}
    # Per-agent env injected into every MCP server config.env. Kept
    # here (not on the process-wide os.environ) so parallel workers
    # in the same interpreter don't clobber each other's identity.
    self._mcp_extra_env: dict[str, str] = {}
    # Agent dir for re-loading registry MCP after credential resync.
    self._mcp_registry_agent_path: Path | None = None
def set_mcp_extra_env(self, env: dict[str, str]) -> None:
    """Attach per-agent env vars to every MCPServerConfig this registry builds.

    Use this instead of mutating ``os.environ`` — the global env dict is
    shared by every worker in a single interpreter, so writes from one
    worker race with MCP spawns from another.
    """
    # Defensive copy: later caller-side mutation must not leak into us.
    self._mcp_extra_env = {**env}
def register(
    self,
    name: str,
    tool: Tool,
    executor: Callable[[dict], Any],
) -> None:
    """
    Register a single tool with its executor.

    Args:
        name: Tool name (must match tool.name)
        tool: Tool definition
        executor: Function that takes tool input dict and returns result
    """
    entry = RegisteredTool(tool=tool, executor=executor)
    self._tools[name] = entry
def register_function(
    self,
    func: Callable,
    name: str | None = None,
    description: str | None = None,
) -> None:
    """
    Register a function as a tool, auto-generating the Tool definition.

    JSON-schema parameter types are inferred from annotations
    (int/float/bool/dict/list; anything else, or no annotation, maps to
    "string"). Parameters without defaults are marked required.

    Args:
        func: Function to register
        name: Tool name (defaults to function name)
        description: Tool description (defaults to docstring)
    """
    tool_name = name or func.__name__
    tool_desc = description or func.__doc__ or f"Execute {tool_name}"
    # Generate parameters from the function signature
    sig = inspect.signature(func)
    properties = {}
    required = []
    for param_name, param in sig.parameters.items():
        if param_name in ("self", "cls"):
            continue
        param_type = "string"  # Default for unannotated / unknown annotations
        # inspect.Parameter.empty is a sentinel object: compare with `is`,
        # not `==`/`!=` — a default value overloading __eq__ (e.g. a numpy
        # array) would break or mis-answer the equality check.
        if param.annotation is not inspect.Parameter.empty:
            if param.annotation is int:
                param_type = "integer"
            elif param.annotation is float:
                param_type = "number"
            elif param.annotation is bool:
                param_type = "boolean"
            elif param.annotation is dict:
                param_type = "object"
            elif param.annotation is list:
                param_type = "array"
        properties[param_name] = {"type": param_type}
        if param.default is inspect.Parameter.empty:
            required.append(param_name)
    tool = Tool(
        name=tool_name,
        description=tool_desc,
        parameters={
            "type": "object",
            "properties": properties,
            "required": required,
        },
    )
    def executor(inputs: dict) -> Any:
        return func(**inputs)
    self.register(tool_name, tool, executor)
def discover_from_module(self, module_path: Path) -> int:
    """
    Load tools from a Python module file.

    Looks for:
    - TOOLS: dict[str, Tool] - tool definitions
    - tool_executor(tool_use: ToolUse) -> ToolResult - unified executor
    - Functions decorated with @tool

    Args:
        module_path: Path to tools.py file

    Returns:
        Number of tools discovered
    """
    if not module_path.exists():
        return 0
    # Load the module dynamically
    spec = importlib.util.spec_from_file_location("agent_tools", module_path)
    if spec is None or spec.loader is None:
        return 0
    module = importlib.util.module_from_spec(spec)
    # NOTE: this executes arbitrary code from the agent folder at load time.
    spec.loader.exec_module(module)
    count = 0
    # Check for TOOLS dict
    if hasattr(module, "TOOLS"):
        tools_dict = module.TOOLS
        executor_func = getattr(module, "tool_executor", None)
        for name, tool in tools_dict.items():
            if executor_func:
                # Use unified executor. The factory binds tool_name by value
                # so each closure doesn't share the loop variable
                # (late-binding pitfall).
                def make_executor(tool_name: str):
                    def executor(inputs: dict) -> Any:
                        tool_use = ToolUse(
                            id=f"call_{tool_name}",
                            name=tool_name,
                            input=inputs,
                        )
                        result = executor_func(tool_use)
                        if isinstance(result, ToolResult):
                            # ToolResult.content is expected to be JSON, but tools may
                            # sometimes return invalid JSON. Guard against crashes here
                            # and surface a structured error instead.
                            if not result.content:
                                return {}
                            try:
                                return json.loads(result.content)
                            except json.JSONDecodeError as e:
                                logger.warning(
                                    "Tool '%s' returned invalid JSON: %s",
                                    tool_name,
                                    str(e),
                                )
                                return {
                                    "error": (
                                        f"Invalid JSON response from tool '{tool_name}': "
                                        f"{str(e)}"
                                    ),
                                    "raw_content": result.content,
                                }
                        return result
                    return executor
                self.register(name, tool, make_executor(name))
            else:
                # Register tool without executor (will use mock)
                self.register(name, tool, lambda inputs: {"mock": True, "inputs": inputs})
            count += 1
    # Check for @tool decorated functions
    for name in dir(module):
        obj = getattr(module, name)
        if callable(obj) and hasattr(obj, "_tool_metadata"):
            metadata = obj._tool_metadata
            self.register_function(
                obj,
                name=metadata.get("name", name),
                description=metadata.get("description"),
            )
            count += 1
    return count
def get_tools(self) -> dict[str, Tool]:
    """Return a name -> Tool mapping for every registered tool."""
    return {tool_name: entry.tool for tool_name, entry in self._tools.items()}
def get_executor(self) -> Callable[[ToolUse], ToolResult]:
    """
    Get unified tool executor function.

    Returns a function that dispatches to the appropriate tool executor.
    Handles both sync and async tool implementations — async results are
    wrapped so that ``EventLoopNode._execute_tool`` can await them.
    """
    def _wrap_result(tool_use_id: str, result: Any) -> ToolResult:
        # Already a ToolResult: pass through unchanged.
        if isinstance(result, ToolResult):
            return result
        # MCP client returns dict with _images when image content is present
        if isinstance(result, dict) and "_images" in result:
            return ToolResult(
                tool_use_id=tool_use_id,
                content=result.get("_text", ""),
                image_content=result["_images"],
            )
        # Anything else: serialize non-strings to JSON for the LLM.
        return ToolResult(
            tool_use_id=tool_use_id,
            content=json.dumps(result) if not isinstance(result, str) else result,
            is_error=False,
        )
    registry_ref = self
    def executor(tool_use: ToolUse) -> ToolResult:
        # Check if credential files changed (lightweight dir listing).
        # If new OAuth tokens appeared, restarts MCP servers to pick them up.
        registry_ref.resync_mcp_servers_if_needed()
        if tool_use.name not in registry_ref._tools:
            return ToolResult(
                tool_use_id=tool_use.id,
                content=json.dumps({"error": f"Unknown tool: {tool_use.name}"}),
                is_error=True,
            )
        registered = registry_ref._tools[tool_use.name]
        try:
            result = registered.executor(tool_use.input)
            # Async tool: wrap the awaitable so the caller can await it
            if asyncio.iscoroutine(result) or asyncio.isfuture(result):
                async def _await_and_wrap():
                    try:
                        r = await result
                        return _wrap_result(tool_use.id, r)
                    except Exception as exc:
                        # Truncate inputs so a huge payload can't flood the log.
                        inputs_str = json.dumps(tool_use.input, default=str)
                        if len(inputs_str) > _INPUT_LOG_MAX_LEN:
                            inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)"
                        logger.error(
                            "Async tool '%s' failed (tool_use_id=%s): %s\nInputs: %s",
                            tool_use.name,
                            tool_use.id,
                            exc,
                            inputs_str,
                            exc_info=True,
                        )
                        return ToolResult(
                            tool_use_id=tool_use.id,
                            content=json.dumps({"error": str(exc)}),
                            is_error=True,
                        )
                return _await_and_wrap()
            return _wrap_result(tool_use.id, result)
        except Exception as e:
            # Sync failure: log (truncated) inputs and return a structured
            # error result rather than raising to the caller.
            inputs_str = json.dumps(tool_use.input, default=str)
            if len(inputs_str) > _INPUT_LOG_MAX_LEN:
                inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)"
            logger.error(
                "Tool '%s' execution failed for tool_use_id=%s: %s\nInputs: %s",
                tool_use.name,
                tool_use.id,
                e,
                inputs_str,
                exc_info=True,
            )
            return ToolResult(
                tool_use_id=tool_use.id,
                content=json.dumps({"error": str(e)}),
                is_error=True,
            )
    # Expose force-kill hook so the timeout handler can tear down a
    # hung MCP subprocess (asyncio.wait_for alone cannot).
    executor.kill_for_tool = registry_ref.kill_mcp_for_tool  # type: ignore[attr-defined]
    return executor
def get_registered_names(self) -> list[str]:
    """Return the names of all registered tools, in registration order."""
    return [*self._tools]
def has_tool(self, name: str) -> bool:
    """Return True when a tool named *name* is registered."""
    return name in self._tools
def get_server_tool_names(self, server_name: str) -> set[str]:
    """Return a copy of the tool names registered from *server_name*."""
    registered = self._mcp_server_tools.get(server_name)
    return set(registered) if registered else set()
def set_session_context(self, **context) -> None:
    """
    Set session context to auto-inject into tool calls.

    Args:
        **context: Key-value pairs to inject (e.g., workspace_id, agent_id, session_id)
    """
    for key, value in context.items():
        self._session_context[key] = value
@staticmethod
def set_execution_context(**context) -> contextvars.Token:
    """Set per-execution context overrides (concurrency-safe via contextvars).

    Values set here take precedence over session context. Each asyncio
    task gets its own copy, so concurrent executions don't interfere.

    Returns:
        Token to pass to :meth:`reset_execution_context` to restore the
        previous state.
    """
    existing = _execution_context.get()
    merged = dict(existing) if existing else {}
    merged.update(context)
    return _execution_context.set(merged)
@staticmethod
def reset_execution_context(token: contextvars.Token) -> None:
    """Restore the execution context captured when *token* was created."""
    _execution_context.reset(token)
@staticmethod
def resolve_mcp_stdio_config(server_config: dict[str, Any], base_dir: Path) -> dict[str, Any]:
    """Resolve cwd and script paths for MCP stdio config (Windows compatibility).

    Convenience wrapper around the instance-level resolver for callers that
    build MCPServerConfig from a config file (e.g. list_agent_tools,
    discover_mcp_tools). Call with base_dir = directory containing the config.
    """
    return ToolRegistry()._resolve_mcp_server_config(server_config, base_dir)
def _resolve_mcp_server_config(
    self, server_config: dict[str, Any], base_dir: Path
) -> dict[str, Any]:
    """Resolve cwd and script paths for MCP stdio servers (Windows compatibility).

    On Windows, passing cwd to subprocess can cause WinError 267. We use cwd=None
    and absolute script paths when the server runs a .py script from the tools dir.
    If the resolved cwd doesn't exist (e.g. config from ~/.hive/agents/), fall back
    to Path.cwd() / "tools".
    """
    # Work on a shallow copy; the caller's dict is never mutated.
    config = dict(server_config)
    if config.get("transport") != "stdio":
        return config
    cwd = config.get("cwd")
    args = list(config.get("args", []))
    if not cwd and not args:
        return config
    # Resolve cwd relative to base_dir
    resolved_cwd: Path | None = None
    if cwd:
        if Path(cwd).is_absolute():
            resolved_cwd = Path(cwd)
        else:
            resolved_cwd = (base_dir / cwd).resolve()
    # Find .py script in args (e.g. coder_tools_server.py, files_server.py)
    script_name = None
    for i, arg in enumerate(args):
        if isinstance(arg, str) and arg.endswith(".py"):
            script_name = arg
            script_idx = i
            break
    if resolved_cwd is None:
        return config
    # If resolved cwd doesn't exist or (when we have a script) doesn't contain it,
    # try fallback
    tools_fallback = Path.cwd() / "tools"
    need_fallback = not resolved_cwd.is_dir()
    if script_name and not need_fallback:
        need_fallback = not (resolved_cwd / script_name).exists()
    if need_fallback:
        fallback_ok = tools_fallback.is_dir()
        if script_name:
            fallback_ok = fallback_ok and (tools_fallback / script_name).exists()
        else:
            # No script (e.g. GCU); just need tools dir to exist
            pass
        if fallback_ok:
            resolved_cwd = tools_fallback
            logger.debug(
                "MCP server '%s': using fallback tools dir %s",
                config.get("name", "?"),
                resolved_cwd,
            )
        else:
            # Fallback unusable: keep the (possibly missing) resolved cwd;
            # the subprocess spawn will surface the real error.
            config["cwd"] = str(resolved_cwd)
            return config
    if not script_name:
        # No .py script (e.g. GCU uses -m gcu.server); just set cwd
        config["cwd"] = str(resolved_cwd)
        return config
    # For coder_tools_server, inject --project-root so writes go to the expected workspace
    if script_name and "coder_tools" in script_name:
        project_root = str(resolved_cwd.parent.resolve())
        args = list(args)
        if "--project-root" not in args:
            args.extend(["--project-root", project_root])
        config["args"] = args
    if os.name == "nt":
        # Windows: cwd=None avoids WinError 267; use absolute script path.
        # script_idx is always bound here: it is set whenever script_name is.
        config["cwd"] = None
        abs_script = str((resolved_cwd / script_name).resolve())
        args = list(config["args"])
        args[script_idx] = abs_script
        config["args"] = args
    else:
        config["cwd"] = str(resolved_cwd)
    return config
def load_mcp_config(self, config_path: Path) -> None:
    """
    Load and register MCP servers from a config file.

    Resolves relative ``cwd`` paths against the config file's parent
    directory so callers never need to handle path resolution themselves.
    Load failures are logged and swallowed (the registry stays usable).

    Args:
        config_path: Path to an ``mcp_servers.json`` file.
    """
    # Remember config path for potential resync later
    self._mcp_config_path = Path(config_path)
    try:
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
    except Exception as e:
        # Lazy %-args (matches the rest of the module): the message is
        # only built if this log record is actually emitted.
        logger.warning("Failed to load MCP config from %s: %s", config_path, e)
        return
    base_dir = config_path.parent
    # Support both formats:
    #   {"servers": [{"name": "x", ...}]}          (list format)
    #   {"server-name": {"transport": ...}, ...}   (dict format)
    server_list = config.get("servers", [])
    if not server_list and "servers" not in config:
        # Treat top-level keys as server names
        server_list = [{"name": name, **cfg} for name, cfg in config.items()]
    resolved_server_list = [
        self._resolve_mcp_server_config(server_config, base_dir)
        for server_config in server_list
    ]
    # Ordered first-wins for duplicate tool names across servers; keep tools.py tools.
    self.load_registry_servers(
        resolved_server_list,
        log_summary=False,
        preserve_existing_tools=True,
        log_collisions=False,
    )
    # Snapshot credential files and ADEN_API_KEY so we can detect mid-session changes
    self._mcp_cred_snapshot = self._snapshot_credentials()
    self._mcp_aden_key_snapshot = os.environ.get("ADEN_API_KEY")
    self._log_registry_snapshot("after load_mcp_config")
def _register_mcp_server_with_retry(
    self,
    server_config: dict[str, Any],
    *,
    preserve_existing_tools: bool = True,
    tool_cap: int | None = None,
    log_collisions: bool = False,
) -> tuple[bool, int, str | None]:
    """Register a single MCP server with one retry for transient failures.

    A server that raises *or* registers zero tools counts as a transient
    failure and is retried once after a 2s backoff.

    Args:
        server_config: Resolved server config dict (see register_mcp_server).
        preserve_existing_tools: Passed through; first-wins name collisions.
        tool_cap: Passed through; max new tools from this server.
        log_collisions: Passed through; warn on skipped duplicate names.

    Returns:
        Tuple of (success, tools_registered, error_message_or_None).
    """
    # Hoisted out of the retry loop (was re-imported on every attempt).
    import time

    name = server_config.get("name", "unknown")
    last_error: str | None = None
    for attempt in range(2):
        try:
            count = self.register_mcp_server(
                server_config,
                preserve_existing_tools=preserve_existing_tools,
                tool_cap=tool_cap,
                log_collisions=log_collisions,
            )
            if count > 0:
                return True, count, None
            last_error = "registered 0 tools"
        except Exception as exc:
            last_error = str(exc)
        if attempt == 0:
            logger.warning(
                "MCP server '%s' failed to register, retrying in 2s: %s",
                name,
                last_error,
            )
            time.sleep(2)
        else:
            logger.warning("MCP server '%s' failed after retry: %s", name, last_error)
    return False, 0, last_error
def load_registry_servers(
    self,
    server_list: list[dict[str, Any]],
    *,
    log_summary: bool = True,
    preserve_existing_tools: bool = True,
    max_tools: int | None = None,
    log_collisions: bool = False,
) -> list[dict[str, Any]]:
    """Register MCP servers from a resolved config list (registry and/or static).

    ``preserve_existing_tools`` enforces first-wins tool names (FR-100):
    later servers skip names already taken — including tools loaded earlier
    from ``mcp_servers.json`` or ``tools.py``.

    ``max_tools`` caps how many *new* tool names this batch may register
    (collisions do not consume the cap). With ``log_collisions`` True,
    skipped duplicate names emit a warning (FR-101).
    """
    outcomes: list[dict[str, Any]] = []
    new_tool_total = 0
    for cfg in server_list:
        # Remaining budget for this server; None means unlimited.
        cap: int | None = None
        if max_tools is not None:
            cap = max_tools - new_tool_total
            if cap <= 0:
                break
        server_name = cfg.get("name", "unknown")
        ok, loaded, err = self._register_mcp_server_with_retry(
            cfg,
            preserve_existing_tools=preserve_existing_tools,
            tool_cap=cap,
            log_collisions=log_collisions,
        )
        new_tool_total += loaded
        entry = {
            "server": server_name,
            "status": "loaded" if ok else "skipped",
            "tools_loaded": loaded,
            "skipped_reason": None if ok else (err or "unknown error"),
        }
        outcomes.append(entry)
        if log_summary:
            logger.info(
                "MCP registry server resolution",
                extra={
                    "event": "mcp_registry_server_resolution",
                    "server": entry["server"],
                    "status": entry["status"],
                    "tools_loaded": entry["tools_loaded"],
                    "skipped_reason": entry["skipped_reason"],
                },
            )
    return outcomes
def register_mcp_server(
    self,
    server_config: dict[str, Any],
    use_connection_manager: bool = True,
    *,
    preserve_existing_tools: bool = True,
    tool_cap: int | None = None,
    log_collisions: bool = False,
) -> int:
    """
    Register an MCP server and discover its tools.

    Args:
        server_config: MCP server configuration dict with keys:
            - name: Server name (required)
            - transport: "stdio" or "http" (required)
            - command: Command to run (for stdio)
            - args: Command arguments (for stdio)
            - env: Environment variables (for stdio)
            - cwd: Working directory (for stdio)
            - url: Server URL (for http)
            - headers: HTTP headers (for http)
            - description: Server description (optional)
        use_connection_manager: When True, reuse a shared client keyed by server name
        preserve_existing_tools: If True, do not replace tools already in the registry.
        tool_cap: Max tools to newly register from this server (None = unlimited).
        log_collisions: If True, log when this server skips a tool name already taken.

    Returns:
        Number of tools registered from this server (0 on failure; errors
        are logged, never raised).
    """
    try:
        from framework.loader.mcp_client import MCPClient, MCPServerConfig
        from framework.loader.mcp_connection_manager import MCPConnectionManager
        # Build config object. Per-agent env is the base layer with the
        # server's own env merged over it, so MCP subprocesses receive
        # the identity of the worker that spawned them (instead of
        # whichever worker most recently wrote to os.environ).
        # NOTE(review): server-specific env keys override per-agent keys
        # on conflict here — confirm that precedence is intended.
        merged_env = {**self._mcp_extra_env, **(server_config.get("env") or {})}
        config = MCPServerConfig(
            name=server_config["name"],
            transport=server_config["transport"],
            command=server_config.get("command"),
            args=server_config.get("args", []),
            env=merged_env,
            cwd=server_config.get("cwd"),
            url=server_config.get("url"),
            headers=server_config.get("headers", {}),
            socket_path=server_config.get("socket_path"),
            description=server_config.get("description", ""),
        )
        # Create and connect client
        if use_connection_manager:
            client = MCPConnectionManager.get_instance().acquire(config)
        else:
            client = MCPClient(config)
            client.connect()
        # Store client for cleanup
        self._mcp_clients.append(client)
        client_id = id(client)
        self._mcp_client_servers[client_id] = config.name
        if use_connection_manager:
            self._mcp_managed_clients.add(client_id)
        # Register each tool
        server_name = server_config["name"]
        if server_name not in self._mcp_server_tools:
            self._mcp_server_tools[server_name] = set()
        # Build admission gate: only admit MCP tools that are either
        # (a) credential-backed *and* have a configured account, or
        # (b) credential-less *and* listed in the verified manifest.
        # Servers that don't expose `__aden_verified_manifest` (third-party
        # MCP servers) bypass the gate entirely — preserves prior behavior.
        admit = self._build_mcp_admission_gate(client)
        count = 0
        admitted_names: list[str] = []
        for mcp_tool in client.list_tools():
            if not admit(mcp_tool.name):
                continue
            if tool_cap is not None and count >= tool_cap:
                break
            if preserve_existing_tools and mcp_tool.name in self._tools:
                if log_collisions:
                    origin_server = (
                        self._find_mcp_origin_server_for_tool(mcp_tool.name) or "<existing>"
                    )
                    logger.warning(
                        "MCP tool '%s' from '%s' shadowed by '%s' (loaded first)",
                        mcp_tool.name,
                        server_name,
                        origin_server,
                    )
                # Skip registration; do not update MCP tool bookkeeping for this server.
                continue
            # Convert MCP tool to framework Tool (strips context params from LLM schema)
            tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
            # Create executor that calls the MCP server. The factory binds
            # per-tool values so closures don't share loop variables.
            def make_mcp_executor(
                client_ref: MCPClient,
                tool_name: str,
                registry_ref,
                tool_params: set[str],
            ):
                def executor(inputs: dict) -> Any:
                    try:
                        # Build base context: session < execution (execution wins)
                        base_context = dict(registry_ref._session_context)
                        exec_ctx = _execution_context.get()
                        if exec_ctx:
                            base_context.update(exec_ctx)
                        # Only inject context params the tool accepts
                        filtered_context = {
                            k: v for k, v in base_context.items() if k in tool_params
                        }
                        # Strip context params from LLM inputs — the framework
                        # values are authoritative (prevents the LLM from passing
                        # e.g. data_dir="/data" and overriding the real path).
                        clean_inputs = {
                            k: v
                            for k, v in inputs.items()
                            if k not in registry_ref.CONTEXT_PARAMS
                        }
                        merged_inputs = {**clean_inputs, **filtered_context}
                        result = client_ref.call_tool(tool_name, merged_inputs)
                        # MCP client already extracts content (returns str
                        # or {"_text": ..., "_images": ...} for image results).
                        # Handle legacy list format from HTTP transport.
                        if isinstance(result, list) and len(result) > 0:
                            if isinstance(result[0], dict) and "text" in result[0]:
                                return result[0]["text"]
                            return result[0]
                        return result
                    except Exception as e:
                        # Log with truncated inputs; return a structured error.
                        inputs_str = json.dumps(inputs, default=str)
                        if len(inputs_str) > _INPUT_LOG_MAX_LEN:
                            inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)"
                        logger.error(
                            "MCP tool '%s' execution failed: %s\nInputs: %s",
                            tool_name,
                            e,
                            inputs_str,
                            exc_info=True,
                        )
                        return {"error": str(e)}
                return executor
            tool_params = set(mcp_tool.input_schema.get("properties", {}).keys())
            self.register(
                mcp_tool.name,
                tool,
                make_mcp_executor(client, mcp_tool.name, self, tool_params),
            )
            self._mcp_tool_names.add(mcp_tool.name)
            self._mcp_tool_clients[mcp_tool.name] = client
            self._mcp_server_tools[server_name].add(mcp_tool.name)
            admitted_names.append(mcp_tool.name)
            count += 1
        logger.info(
            "MCP Registry Load",
            extra={
                "server": config.name,
                "status": "success",
                "tools_loaded": count,
                "skipped_reason": None,
            },
        )
        logger.info(
            "MCP server '%s' admitted %d tool(s): %s",
            config.name,
            len(admitted_names),
            sorted(admitted_names),
        )
        return count
    except Exception as e:
        logger.error(
            "MCP Registry Load",
            extra={
                "server": server_config.get("name", "unknown"),
                "status": "failed",
                "tools_loaded": 0,
                "skipped_reason": str(e),
            },
        )
        if "Connection closed" in str(e) and os.name == "nt":
            logger.debug(
                "On Windows, check that the MCP subprocess starts (e.g. uv in PATH, "
                "script path correct). Worker config uses base_dir = mcp_servers.json parent."
            )
        return 0
def _find_mcp_origin_server_for_tool(self, tool_name: str) -> str | None:
    """Return the MCP server name that registered *tool_name*, or None."""
    return next(
        (server for server, names in self._mcp_server_tools.items() if tool_name in names),
        None,
    )
def _log_registry_snapshot(self, context: str) -> None:
    """Emit a one-line rollup of the current tool registry.

    Called after every tool-list mutation (initial load + resync) so
    operators can correlate "what tools exist right now" with credential
    changes and MCP server lifecycle events. Per-server contents are
    already logged by `register_mcp_server`; this is the single anchor
    line for the resync path.
    """
    total = len(self._tools)
    mcp_total = len(self._mcp_tool_names)
    by_server = {name: len(tools) for name, tools in self._mcp_server_tools.items()}
    logger.info(
        "ToolRegistry snapshot (%s): total=%d, mcp=%d, non_mcp=%d, per_server=%s",
        context,
        total,
        mcp_total,
        total - mcp_total,
        by_server,
    )
_MCP_VERIFIED_MANIFEST_TOOL = "__aden_verified_manifest"
def _build_mcp_admission_gate(self, client: Any) -> Callable[[str], bool]:
    """Build a per-server predicate that filters MCP tools at registration.

    Rules:
    * The sentinel manifest tool itself is never admitted.
    * Credential-backed tools (provider in `tool_provider_map`) are
      admitted only when at least one account exists for that provider.
    * Credential-less tools are admitted only when they appear in the
      server's verified manifest.
    * Servers that don't expose a manifest bypass the verified gate
      entirely (third-party MCP servers behave as before).
    """
    verified_names: set[str] = set()
    manifest_present = False
    try:
        # Ask the server for its manifest. Servers may return a JSON
        # string or an already-parsed list; accept both.
        raw = client.call_tool(self._MCP_VERIFIED_MANIFEST_TOOL, {})
        parsed: Any = raw
        if isinstance(raw, str):
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                parsed = None
        if isinstance(parsed, list):
            verified_names = {str(n) for n in parsed}
            manifest_present = True
    except Exception:
        # Server doesn't expose the manifest — no verified gate applies.
        pass
    tool_provider_map: dict[str, str] = {}
    live_providers: set[str] = set()
    try:
        from aden_tools.credentials.store_adapter import CredentialStoreAdapter
        adapter = CredentialStoreAdapter.default()
        tool_provider_map = adapter.get_tool_provider_map()
        live_providers = {
            a.get("provider", "")
            for a in adapter.get_all_account_info()
            if a.get("provider")
        }
    except Exception:
        # Missing credential store just disables the credentialed-tool gate.
        logger.debug("Credential snapshot unavailable for MCP gate", exc_info=True)
    def admit(tool_name: str) -> bool:
        if tool_name == self._MCP_VERIFIED_MANIFEST_TOOL:
            return False
        provider = tool_provider_map.get(tool_name)
        if provider:
            # Credentialed tool — needs an account.
            return provider in live_providers
        if not manifest_present:
            # Third-party MCP server: preserve legacy "admit everything".
            return True
        return tool_name in verified_names
    return admit
def _convert_mcp_tool_to_framework_tool(self, mcp_tool: Any) -> Tool:
    """Translate an MCPTool into a framework Tool.

    Framework-internal context params (CONTEXT_PARAMS) are stripped from
    the LLM-facing schema — the LLM can't know their values; the framework
    injects them at call time.

    Args:
        mcp_tool: MCPTool object

    Returns:
        Framework Tool object with a sanitized JSON schema
    """
    schema = mcp_tool.input_schema
    visible_props = {
        key: spec
        for key, spec in schema.get("properties", {}).items()
        if key not in self.CONTEXT_PARAMS
    }
    visible_required = [
        field for field in schema.get("required", []) if field not in self.CONTEXT_PARAMS
    ]
    return Tool(
        name=mcp_tool.name,
        description=mcp_tool.description,
        parameters={
            "type": "object",
            "properties": visible_props,
            "required": visible_required,
        },
    )
# ------------------------------------------------------------------
# Provider-based tool filtering
# ------------------------------------------------------------------
def build_provider_index(self) -> None:
    """Build provider -> tool-name mapping from CREDENTIAL_SPECS.

    Populates ``_provider_index`` so :meth:`get_by_provider` works.
    Safe to call even if ``aden_tools`` is not installed (silently no-ops).
    """
    try:
        from aden_tools.credentials import CREDENTIAL_SPECS
    except ImportError:
        logger.debug("aden_tools not available, skipping provider index")
        return
    self._provider_index.clear()
    for spec in CREDENTIAL_SPECS.values():
        provider = spec.aden_provider_name
        # Specs without a provider name don't participate in filtering.
        if provider:
            # setdefault creates the bucket on first sight of a provider.
            self._provider_index.setdefault(provider, set()).update(spec.tools)
def get_by_provider(self, provider: str) -> dict[str, Tool]:
    """Return registered tools that belong to *provider*.

    Lazily builds the provider index on first call.
    """
    if not self._provider_index:
        self.build_provider_index()
    members = self._provider_index.get(provider, set())
    return {
        tool_name: entry.tool
        for tool_name, entry in self._tools.items()
        if tool_name in members
    }
def get_tool_names_by_provider(self, provider: str) -> list[str]:
    """Return sorted registered tool names for *provider*."""
    if not self._provider_index:
        self.build_provider_index()
    members = self._provider_index.get(provider, set())
    matching = [tool_name for tool_name in self._tools if tool_name in members]
    matching.sort()
    return matching
def get_all_provider_tool_names(self) -> list[str]:
    """Return sorted names of all registered tools that belong to any provider."""
    if not self._provider_index:
        self.build_provider_index()
    # set().union(*...) handles the empty-index case (returns an empty set).
    known: set[str] = set().union(*self._provider_index.values())
    return sorted(tool_name for tool_name in self._tools if tool_name in known)
# ------------------------------------------------------------------
# MCP credential resync
# ------------------------------------------------------------------
def set_mcp_registry_agent_path(self, agent_path: Path | None) -> None:
    """Remember agent dir so registry MCP servers reload after credential resync."""
    self._mcp_registry_agent_path = Path(agent_path) if agent_path is not None else None
def reload_registry_mcp_servers_after_resync(self) -> None:
    """Re-run ``mcp_registry.json`` resolution and register servers (post-resync).

    No-op when no agent path was recorded via set_mcp_registry_agent_path.
    Resolution failures are logged and swallowed so a broken registry
    reload cannot take down the resync path.
    """
    if self._mcp_registry_agent_path is None:
        return
    from framework.loader.mcp_registry import MCPRegistry
    try:
        reg = MCPRegistry()
        reg.initialize()
        configs, selection_max_tools = reg.load_agent_selection(self._mcp_registry_agent_path)
    except Exception as exc:
        logger.warning(
            "Failed to reload MCP registry servers after resync for '%s': %s",
            self._mcp_registry_agent_path.name,
            exc,
        )
        return
    if not configs:
        return
    # First-wins preserves tools already re-registered from mcp_servers.json;
    # the cap applies only to newly added names.
    self.load_registry_servers(
        configs,
        log_summary=True,
        preserve_existing_tools=True,
        log_collisions=True,
        max_tools=selection_max_tools,
    )
def _snapshot_credentials(self) -> set[str]:
    """Return the set of credential filenames currently on disk.

    Returns filenames (not Path objects), matching the declared
    ``set[str]`` type — the previous version returned ``set[Path]``.
    Snapshot comparisons stay consistent since both sides of the
    comparison are produced by this method. Returns an empty set when
    the directory is missing or unreadable.
    """
    try:
        if not self._CREDENTIAL_DIR.is_dir():
            return set()
        return {entry.name for entry in self._CREDENTIAL_DIR.iterdir()}
    except OSError:
        return set()
def resync_mcp_servers_if_needed(self) -> bool:
    """Restart MCP servers if credential files changed since last load.

    Compares the current credential directory listing against the snapshot
    taken when MCP servers were first loaded. If new files appeared (e.g.
    user connected an OAuth account mid-session), disconnects all MCP
    clients and re-loads them so the new subprocess picks up the fresh
    credentials.

    Note: Individual credential TTL/refresh is handled by the MCP server
    process internally -- it resolves tokens from the credential store
    on every tool call, not at startup. This method only handles the case
    where entirely new credential files appear.

    Returns True if a resync was performed, False otherwise.
    """
    # No-op unless MCP servers were actually loaded from a config file.
    if not self._mcp_clients or self._mcp_config_path is None:
        return False
    current = self._snapshot_credentials()
    current_aden_key = os.environ.get("ADEN_API_KEY")
    files_changed = current != self._mcp_cred_snapshot
    aden_key_changed = current_aden_key != self._mcp_aden_key_snapshot
    if not files_changed and not aden_key_changed:
        return False
    reason = (
        "Credential files and ADEN_API_KEY changed"
        if files_changed and aden_key_changed
        else "ADEN_API_KEY changed"
        if aden_key_changed
        else "Credential files changed"
    )
    logger.info("%s — resyncing MCP servers", reason)
    # 1. Disconnect existing MCP clients
    self._cleanup_mcp_clients("during resync")
    # 2. Remove MCP-registered tools
    for name in self._mcp_tool_names:
        self._tools.pop(name, None)
    self._mcp_tool_names.clear()
    self._mcp_server_tools.clear()
    # 3. Re-load MCP servers (spawns fresh subprocesses with new credentials).
    # load_mcp_config also refreshes the credential/ADEN_API_KEY snapshots.
    self.load_mcp_config(self._mcp_config_path)
    if self._mcp_registry_agent_path is not None:
        self.reload_registry_mcp_servers_after_resync()
    logger.info("MCP server resync complete")
    self._log_registry_snapshot("after resync_mcp_servers_if_needed")
    return True
def cleanup(self) -> None:
    """Clean up all MCP client connections (safe to call multiple times)."""
    self._cleanup_mcp_clients()
def _cleanup_mcp_clients(self, context: str = "") -> None:
    """Disconnect or release all tracked MCP clients for this registry.

    Args:
        context: Optional log-message suffix (e.g. "during resync").
    """
    if context:
        context = f" {context}"
    for client in self._mcp_clients:
        client_id = id(client)
        server_name = self._mcp_client_servers.get(client_id, client.config.name)
        try:
            if client_id in self._mcp_managed_clients:
                # Manager-acquired clients are shared: release back to the
                # manager instead of disconnecting outright.
                from framework.loader.mcp_connection_manager import MCPConnectionManager
                MCPConnectionManager.get_instance().release(server_name)
            else:
                client.disconnect()
        except Exception as e:
            # Best-effort teardown: one bad client must not block the rest.
            logger.warning(f"Error disconnecting MCP client{context}: {e}")
    self._mcp_clients.clear()
    self._mcp_client_servers.clear()
    self._mcp_managed_clients.clear()
    self._mcp_tool_clients.clear()
def kill_mcp_for_tool(self, tool_name: str) -> bool:
    """Force-disconnect the MCP client that owns *tool_name*.

    Invoked by the timeout handler in ``execute_tool`` when a tool call
    hangs. Plain ``asyncio.wait_for`` cancellation cannot stop a sync
    executor running inside a thread pool (and therefore cannot stop the
    MCP subprocess), so we reach through to the owning client and tear it
    down. The next ``call_tool`` triggers an automatic reconnect.

    Returns True if a client was found and disconnect was attempted.
    """
    owner = self._mcp_tool_clients.get(tool_name)
    if owner is None:
        return False
    try:
        logger.warning(
            "Force-disconnecting MCP client for hung tool '%s' on server '%s'",
            tool_name,
            getattr(owner.config, "name", "?"),
        )
        owner.disconnect()
    except Exception as exc:
        logger.warning("Error force-disconnecting MCP client for '%s': %s", tool_name, exc)
    return True
def __del__(self):
    """Best-effort cleanup when the registry is garbage-collected.

    Guarded: __del__ may run during interpreter shutdown, when the
    imports performed inside the cleanup path can fail — a destructor
    must never propagate exceptions.
    """
    try:
        self.cleanup()
    except Exception:
        pass
def tool(
description: str | None = None,
name: str | None = None,
) -> Callable:
"""
Decorator to mark a function as a tool.
Usage:
@tool(description="Fetch lead from GTM table")
def gtm_fetch_lead(lead_id: str) -> dict:
return {"lead_data": {...}}
"""
def decorator(func: Callable) -> Callable:
func._tool_metadata = {
"name": name or func.__name__,
"description": description or func.__doc__,
}
return func
return decorator