"""Tool discovery and registration for agent runner."""

import importlib.util
import inspect
import json
import logging
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any, get_origin

from framework.llm.provider import Tool, ToolResult, ToolUse

logger = logging.getLogger(__name__)

# (Python type, JSON-schema type name) pairs used when deriving tool parameter
# schemas from function signatures. Matched by identity, so ``bool`` is never
# mistaken for ``int``. Annotations not listed here fall back to "string".
_JSON_SCHEMA_TYPES: tuple[tuple[type, str], ...] = (
    (int, "integer"),
    (float, "number"),
    (bool, "boolean"),
    (dict, "object"),
    (list, "array"),
)


def _json_schema_type(annotation: Any) -> str:
    """
    Map a parameter annotation to a JSON-schema type name.

    Supports bare builtins (``int``), parameterized generics such as
    ``list[str]`` or ``dict[str, Any]`` (via ``typing.get_origin``), and the
    string annotations produced by ``from __future__ import annotations``
    (e.g. ``"list[str]"``). Missing or unrecognized annotations map to
    ``"string"``.
    """
    if annotation is inspect.Parameter.empty:
        return "string"
    if isinstance(annotation, str):
        # Postponed annotation: compare only the bare type name before any "[".
        base_name = annotation.split("[", 1)[0].strip()
        for py_type, json_type in _JSON_SCHEMA_TYPES:
            if base_name == py_type.__name__:
                return json_type
        return "string"
    # ``list[str]`` -> ``list``; plain classes have no origin and are used as-is.
    base = get_origin(annotation) or annotation
    for py_type, json_type in _JSON_SCHEMA_TYPES:
        if base is py_type:
            return json_type
    return "string"


def _decode_tool_result_content(content: Any) -> Any:
    """
    Decode the ``content`` of a ToolResult returned by a module-level executor.

    Empty content becomes ``{}``. JSON text is parsed. Content that is not
    valid JSON (e.g. a plain-text answer) is returned unchanged instead of
    raising, so it is not misreported as a tool failure.
    """
    if not content:
        return {}
    try:
        return json.loads(content)
    except (json.JSONDecodeError, TypeError):
        return content


@dataclass
class RegisteredTool:
    """A tool with its executor function."""

    tool: Tool
    executor: Callable[[dict], Any]


class ToolRegistry:
    """
    Manages tool discovery and registration.

    Tool Discovery Order:
    1. Built-in tools (if any)
    2. tools.py in agent folder
    3. MCP servers
    4. Manually registered tools
    """

    def __init__(self):
        self._tools: dict[str, RegisteredTool] = {}
        self._mcp_clients: list[Any] = []  # List of MCPClient instances
        self._session_context: dict[str, Any] = {}  # Auto-injected context for tools

    def register(
        self,
        name: str,
        tool: Tool,
        executor: Callable[[dict], Any],
    ) -> None:
        """
        Register a single tool with its executor.

        A tool previously registered under the same name is replaced.

        Args:
            name: Tool name (must match tool.name)
            tool: Tool definition
            executor: Function that takes tool input dict and returns result
        """
        self._tools[name] = RegisteredTool(tool=tool, executor=executor)

    def register_function(
        self,
        func: Callable,
        name: str | None = None,
        description: str | None = None,
    ) -> None:
        """
        Register a function as a tool, auto-generating the Tool definition.

        The JSON-schema ``parameters`` are derived from the function signature.
        Each named parameter becomes a property whose type comes from its
        annotation (defaulting to ``"string"``). Parameters without a default
        value are marked required. ``self``/``cls`` and variadic
        ``*args``/``**kwargs`` parameters are not exposed as properties.

        Args:
            func: Function to register
            name: Tool name (defaults to function name)
            description: Tool description (defaults to docstring)
        """
        tool_name = name or func.__name__
        tool_desc = description or func.__doc__ or f"Execute {tool_name}"

        properties: dict[str, dict[str, str]] = {}
        required: list[str] = []
        for param_name, param in inspect.signature(func).parameters.items():
            if param_name in ("self", "cls"):
                continue
            # *args / **kwargs cannot be expressed as named schema properties;
            # exposing them would ask the model for literal "args"/"kwargs" keys.
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                continue
            properties[param_name] = {"type": _json_schema_type(param.annotation)}
            # Identity check: ``==`` can misbehave (or raise) for defaults
            # with a custom __eq__.
            if param.default is inspect.Parameter.empty:
                required.append(param_name)

        tool = Tool(
            name=tool_name,
            description=tool_desc,
            parameters={
                "type": "object",
                "properties": properties,
                "required": required,
            },
        )

        def executor(inputs: dict) -> Any:
            return func(**inputs)

        self.register(tool_name, tool, executor)

    def discover_from_module(self, module_path: Path) -> int:
        """
        Load tools from a Python module file.

        Looks for:
        - TOOLS: dict[str, Tool] - tool definitions
        - tool_executor(tool_use: ToolUse) -> ToolResult - unified executor
        - Functions decorated with @tool

        If TOOLS is defined but tool_executor is not, each TOOLS entry is
        registered with a mock executor that echoes its inputs. When
        tool_executor returns a ToolResult, its content is JSON-decoded if
        possible and otherwise returned as-is.

        Args:
            module_path: Path to tools.py file

        Returns:
            Number of tools discovered
        """
        if not module_path.exists():
            return 0

        # Load the module dynamically
        spec = importlib.util.spec_from_file_location("agent_tools", module_path)
        if spec is None or spec.loader is None:
            return 0

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        count = 0

        # Check for TOOLS dict
        if hasattr(module, "TOOLS"):
            executor_func = getattr(module, "tool_executor", None)

            def make_executor(tool_name: str) -> Callable[[dict], Any]:
                # Factory so each closure captures its own tool_name rather
                # than the loop variable (late-binding closure pitfall).
                def executor(inputs: dict) -> Any:
                    tool_use = ToolUse(
                        id=f"call_{tool_name}",
                        name=tool_name,
                        input=inputs,
                    )
                    result = executor_func(tool_use)
                    if isinstance(result, ToolResult):
                        return _decode_tool_result_content(result.content)
                    return result

                return executor

            for name, tool_def in module.TOOLS.items():
                if executor_func:
                    # Use unified executor
                    self.register(name, tool_def, make_executor(name))
                else:
                    # Register tool without executor (will use mock)
                    self.register(
                        name,
                        tool_def,
                        lambda inputs: {"mock": True, "inputs": inputs},
                    )
                count += 1

        # Check for @tool decorated functions
        for attr_name in dir(module):
            obj = getattr(module, attr_name)
            if callable(obj) and hasattr(obj, "_tool_metadata"):
                metadata = obj._tool_metadata
                self.register_function(
                    obj,
                    name=metadata.get("name", attr_name),
                    description=metadata.get("description"),
                )
                count += 1

        return count

    def get_tools(self) -> dict[str, Tool]:
        """Get all registered Tool objects."""
        return {name: rt.tool for name, rt in self._tools.items()}

    def get_executor(self) -> Callable[[ToolUse], ToolResult]:
        """
        Get unified tool executor function.

        Returns a function that dispatches to the appropriate tool executor.
        Unknown tools and exceptions raised by an executor are reported as
        ToolResults with ``is_error=True`` and a JSON ``{"error": ...}``
        payload. String results are passed through, ToolResults are returned
        unchanged, and anything else is JSON-encoded.
        """

        def executor(tool_use: ToolUse) -> ToolResult:
            if tool_use.name not in self._tools:
                return ToolResult(
                    tool_use_id=tool_use.id,
                    content=json.dumps({"error": f"Unknown tool: {tool_use.name}"}),
                    is_error=True,
                )

            registered = self._tools[tool_use.name]
            try:
                result = registered.executor(tool_use.input)
                if isinstance(result, ToolResult):
                    return result
                return ToolResult(
                    tool_use_id=tool_use.id,
                    content=json.dumps(result) if not isinstance(result, str) else result,
                    is_error=False,
                )
            except Exception as e:
                return ToolResult(
                    tool_use_id=tool_use.id,
                    content=json.dumps({"error": str(e)}),
                    is_error=True,
                )

        return executor

    def get_registered_names(self) -> list[str]:
        """Get list of registered tool names."""
        return list(self._tools.keys())

    def has_tool(self, name: str) -> bool:
        """Check if a tool is registered."""
        return name in self._tools

    def set_session_context(self, **context) -> None:
        """
        Set session context to auto-inject into tool calls.

        The context is merged into the inputs of MCP tool calls at call time.
        Explicit call inputs win on key conflicts. Repeated calls update
        (not replace) the existing context.

        Args:
            **context: Key-value pairs to inject (e.g., workspace_id, agent_id, session_id)
        """
        self._session_context.update(context)

    def register_mcp_server(
        self,
        server_config: dict[str, Any],
    ) -> int:
        """
        Register an MCP server and discover its tools.

        Any failure (bad config, connection error, tool listing error) is
        logged and results in a return value of 0 rather than an exception.

        Args:
            server_config: MCP server configuration dict with keys:
                - name: Server name (required)
                - transport: "stdio" or "http" (required)
                - command: Command to run (for stdio)
                - args: Command arguments (for stdio)
                - env: Environment variables (for stdio)
                - cwd: Working directory (for stdio)
                - url: Server URL (for http)
                - headers: HTTP headers (for http)
                - description: Server description (optional)

        Returns:
            Number of tools registered from this server
        """
        try:
            from framework.runner.mcp_client import MCPClient, MCPServerConfig

            # Build config object
            config = MCPServerConfig(
                name=server_config["name"],
                transport=server_config["transport"],
                command=server_config.get("command"),
                args=server_config.get("args", []),
                env=server_config.get("env", {}),
                cwd=server_config.get("cwd"),
                url=server_config.get("url"),
                headers=server_config.get("headers", {}),
                description=server_config.get("description", ""),
            )

            # Create and connect client
            client = MCPClient(config)
            client.connect()

            # Track the client right away so cleanup() disconnects it even if
            # tool listing below fails.
            self._mcp_clients.append(client)

            count = 0
            for mcp_tool in client.list_tools():
                tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
                self.register(
                    mcp_tool.name,
                    tool,
                    self._make_mcp_executor(client, mcp_tool.name),
                )
                count += 1

            logger.info("Registered %d tools from MCP server '%s'", count, config.name)
            return count

        except Exception as e:
            logger.error("Failed to register MCP server: %s", e)
            return 0

    def _make_mcp_executor(self, client: Any, tool_name: str) -> Callable[[dict], Any]:
        """
        Build an executor that forwards calls for ``tool_name`` to an MCP client.

        The registry's session context is read at call time and merged under the
        caller's inputs. Failures are logged and re-raised so that
        ``get_executor`` reports them with ``is_error=True``.
        """

        def executor(inputs: dict) -> Any:
            # Inject session context for tools that need it
            merged_inputs = {**self._session_context, **inputs}
            try:
                result = client.call_tool(tool_name, merged_inputs)
            except Exception as e:
                logger.error("MCP tool '%s' execution failed: %s", tool_name, e)
                # Propagate instead of returning {"error": ...} as a normal
                # value, which would be reported as a successful call.
                raise

            # MCP tools return a content array: unwrap the first item, and its
            # "text" field when present. NOTE(review): any further items are
            # dropped — confirm that is intended for multi-part results.
            if isinstance(result, list) and result:
                first = result[0]
                if isinstance(first, dict) and "text" in first:
                    return first["text"]
                return first
            return result

        return executor

    def _convert_mcp_tool_to_framework_tool(self, mcp_tool: Any) -> Tool:
        """
        Convert an MCP tool to a framework Tool.

        Args:
            mcp_tool: MCPTool object

        Returns:
            Framework Tool object. A missing (None) input schema yields an
            object schema with no properties.
        """
        # Extract parameters from MCP input schema
        input_schema = mcp_tool.input_schema or {}
        properties = input_schema.get("properties", {})
        required = input_schema.get("required", [])

        return Tool(
            name=mcp_tool.name,
            description=mcp_tool.description,
            parameters={
                "type": "object",
                "properties": properties,
                "required": required,
            },
        )

    def cleanup(self) -> None:
        """Clean up all MCP client connections (errors are logged, not raised)."""
        for client in self._mcp_clients:
            try:
                client.disconnect()
            except Exception as e:
                logger.warning("Error disconnecting MCP client: %s", e)
        self._mcp_clients.clear()

    def __del__(self):
        """Destructor to ensure cleanup."""
        # __init__ may not have completed (or run at all, e.g. via __new__);
        # in that case there are no clients to disconnect.
        if getattr(self, "_mcp_clients", None):
            self.cleanup()


def tool(
    description: str | None = None,
    name: str | None = None,
) -> Callable:
    """
    Decorator to mark a function as a tool.

    Attaches a ``_tool_metadata`` dict (name and description, defaulting to
    the function's name and docstring) and returns the function unchanged.
    ``ToolRegistry.discover_from_module`` picks up functions carrying it.

    Usage:
        @tool(description="Fetch lead from GTM table")
        def gtm_fetch_lead(lead_id: str) -> dict:
            return {"lead_data": {...}}
    """

    def decorator(func: Callable) -> Callable:
        func._tool_metadata = {
            "name": name or func.__name__,
            "description": description or func.__doc__,
        }
        return func

    return decorator