Files
hive/core/framework/runner/tool_registry.py
T
2026-01-26 07:32:03 -08:00

388 lines
13 KiB
Python

"""Tool discovery and registration for agent runner."""
import importlib.util
import inspect
import json
import logging
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from framework.llm.provider import Tool, ToolResult, ToolUse
logger = logging.getLogger(__name__)
@dataclass
class RegisteredTool:
"""A tool with its executor function."""
tool: Tool
executor: Callable[[dict], Any]
class ToolRegistry:
"""
Manages tool discovery and registration.
Tool Discovery Order:
1. Built-in tools (if any)
2. tools.py in agent folder
3. MCP servers
4. Manually registered tools
"""
def __init__(self):
self._tools: dict[str, RegisteredTool] = {}
self._mcp_clients: list[Any] = [] # List of MCPClient instances
self._session_context: dict[str, Any] = {} # Auto-injected context for tools
def register(
self,
name: str,
tool: Tool,
executor: Callable[[dict], Any],
) -> None:
"""
Register a single tool with its executor.
Args:
name: Tool name (must match tool.name)
tool: Tool definition
executor: Function that takes tool input dict and returns result
"""
self._tools[name] = RegisteredTool(tool=tool, executor=executor)
def register_function(
self,
func: Callable,
name: str | None = None,
description: str | None = None,
) -> None:
"""
Register a function as a tool, auto-generating the Tool definition.
Args:
func: Function to register
name: Tool name (defaults to function name)
description: Tool description (defaults to docstring)
"""
tool_name = name or func.__name__
tool_desc = description or func.__doc__ or f"Execute {tool_name}"
# Generate parameters from function signature
sig = inspect.signature(func)
properties = {}
required = []
for param_name, param in sig.parameters.items():
if param_name in ("self", "cls"):
continue
param_type = "string" # Default
if param.annotation != inspect.Parameter.empty:
if param.annotation is int:
param_type = "integer"
elif param.annotation is float:
param_type = "number"
elif param.annotation is bool:
param_type = "boolean"
elif param.annotation is dict:
param_type = "object"
elif param.annotation is list:
param_type = "array"
properties[param_name] = {"type": param_type}
if param.default == inspect.Parameter.empty:
required.append(param_name)
tool = Tool(
name=tool_name,
description=tool_desc,
parameters={
"type": "object",
"properties": properties,
"required": required,
},
)
def executor(inputs: dict) -> Any:
return func(**inputs)
self.register(tool_name, tool, executor)
def discover_from_module(self, module_path: Path) -> int:
"""
Load tools from a Python module file.
Looks for:
- TOOLS: dict[str, Tool] - tool definitions
- tool_executor(tool_use: ToolUse) -> ToolResult - unified executor
- Functions decorated with @tool
Args:
module_path: Path to tools.py file
Returns:
Number of tools discovered
"""
if not module_path.exists():
return 0
# Load the module dynamically
spec = importlib.util.spec_from_file_location("agent_tools", module_path)
if spec is None or spec.loader is None:
return 0
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
count = 0
# Check for TOOLS dict
if hasattr(module, "TOOLS"):
tools_dict = module.TOOLS
executor_func = getattr(module, "tool_executor", None)
for name, tool in tools_dict.items():
if executor_func:
# Use unified executor
def make_executor(tool_name: str):
def executor(inputs: dict) -> Any:
tool_use = ToolUse(
id=f"call_{tool_name}",
name=tool_name,
input=inputs,
)
result = executor_func(tool_use)
if isinstance(result, ToolResult):
return json.loads(result.content) if result.content else {}
return result
return executor
self.register(name, tool, make_executor(name))
else:
# Register tool without executor (will use mock)
self.register(name, tool, lambda inputs: {"mock": True, "inputs": inputs})
count += 1
# Check for @tool decorated functions
for name in dir(module):
obj = getattr(module, name)
if callable(obj) and hasattr(obj, "_tool_metadata"):
metadata = obj._tool_metadata
self.register_function(
obj,
name=metadata.get("name", name),
description=metadata.get("description"),
)
count += 1
return count
def get_tools(self) -> dict[str, Tool]:
"""Get all registered Tool objects."""
return {name: rt.tool for name, rt in self._tools.items()}
def get_executor(self) -> Callable[[ToolUse], ToolResult]:
"""
Get unified tool executor function.
Returns a function that dispatches to the appropriate tool executor.
"""
def executor(tool_use: ToolUse) -> ToolResult:
if tool_use.name not in self._tools:
return ToolResult(
tool_use_id=tool_use.id,
content=json.dumps({"error": f"Unknown tool: {tool_use.name}"}),
is_error=True,
)
registered = self._tools[tool_use.name]
try:
result = registered.executor(tool_use.input)
if isinstance(result, ToolResult):
return result
return ToolResult(
tool_use_id=tool_use.id,
content=json.dumps(result) if not isinstance(result, str) else result,
is_error=False,
)
except Exception as e:
return ToolResult(
tool_use_id=tool_use.id,
content=json.dumps({"error": str(e)}),
is_error=True,
)
return executor
def get_registered_names(self) -> list[str]:
"""Get list of registered tool names."""
return list(self._tools.keys())
def has_tool(self, name: str) -> bool:
"""Check if a tool is registered."""
return name in self._tools
def set_session_context(self, **context) -> None:
"""
Set session context to auto-inject into tool calls.
Args:
**context: Key-value pairs to inject (e.g., workspace_id, agent_id, session_id)
"""
self._session_context.update(context)
def register_mcp_server(
self,
server_config: dict[str, Any],
) -> int:
"""
Register an MCP server and discover its tools.
Args:
server_config: MCP server configuration dict with keys:
- name: Server name (required)
- transport: "stdio" or "http" (required)
- command: Command to run (for stdio)
- args: Command arguments (for stdio)
- env: Environment variables (for stdio)
- cwd: Working directory (for stdio)
- url: Server URL (for http)
- headers: HTTP headers (for http)
- description: Server description (optional)
Returns:
Number of tools registered from this server
"""
try:
from framework.runner.mcp_client import MCPClient, MCPServerConfig
# Build config object
config = MCPServerConfig(
name=server_config["name"],
transport=server_config["transport"],
command=server_config.get("command"),
args=server_config.get("args", []),
env=server_config.get("env", {}),
cwd=server_config.get("cwd"),
url=server_config.get("url"),
headers=server_config.get("headers", {}),
description=server_config.get("description", ""),
)
# Create and connect client
client = MCPClient(config)
client.connect()
# Store client for cleanup
self._mcp_clients.append(client)
# Register each tool
count = 0
for mcp_tool in client.list_tools():
# Convert MCP tool to framework Tool
tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
# Create executor that calls the MCP server
def make_mcp_executor(client_ref: MCPClient, tool_name: str, registry_ref):
def executor(inputs: dict) -> Any:
try:
# Inject session context for tools that need it
merged_inputs = {**registry_ref._session_context, **inputs}
result = client_ref.call_tool(tool_name, merged_inputs)
# MCP tools return content array, extract the result
if isinstance(result, list) and len(result) > 0:
if isinstance(result[0], dict) and "text" in result[0]:
return result[0]["text"]
return result[0]
return result
except Exception as e:
logger.error(f"MCP tool '{tool_name}' execution failed: {e}")
return {"error": str(e)}
return executor
self.register(
mcp_tool.name,
tool,
make_mcp_executor(client, mcp_tool.name, self),
)
count += 1
logger.info(f"Registered {count} tools from MCP server '{config.name}'")
return count
except Exception as e:
logger.error(f"Failed to register MCP server: {e}")
return 0
def _convert_mcp_tool_to_framework_tool(self, mcp_tool: Any) -> Tool:
"""
Convert an MCP tool to a framework Tool.
Args:
mcp_tool: MCPTool object
Returns:
Framework Tool object
"""
# Extract parameters from MCP input schema
input_schema = mcp_tool.input_schema
properties = input_schema.get("properties", {})
required = input_schema.get("required", [])
# Convert to framework Tool format
tool = Tool(
name=mcp_tool.name,
description=mcp_tool.description,
parameters={
"type": "object",
"properties": properties,
"required": required,
},
)
return tool
def cleanup(self) -> None:
"""Clean up all MCP client connections."""
for client in self._mcp_clients:
try:
client.disconnect()
except Exception as e:
logger.warning(f"Error disconnecting MCP client: {e}")
self._mcp_clients.clear()
def __del__(self):
"""Destructor to ensure cleanup."""
self.cleanup()
def tool(
description: str | None = None,
name: str | None = None,
) -> Callable:
"""
Decorator to mark a function as a tool.
Usage:
@tool(description="Fetch lead from GTM table")
def gtm_fetch_lead(lead_id: str) -> dict:
return {"lead_data": {...}}
"""
def decorator(func: Callable) -> Callable:
func._tool_metadata = {
"name": name or func.__name__,
"description": description or func.__doc__,
}
return func
return decorator