Files
hive/core/framework/runner/mcp_client.py
T
2026-01-20 13:28:42 -08:00

354 lines
11 KiB
Python

"""MCP Client for connecting to Model Context Protocol servers.
This module provides a client for connecting to MCP servers and invoking their tools.
Supports both STDIO and HTTP transports using the official MCP Python SDK.
"""
import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
import httpx
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
@dataclass
class MCPServerConfig:
"""Configuration for an MCP server connection."""
name: str
transport: Literal["stdio", "http"]
# For STDIO transport
command: str | None = None
args: list[str] = field(default_factory=list)
env: dict[str, str] = field(default_factory=dict)
cwd: str | None = None
# For HTTP transport
url: str | None = None
headers: dict[str, str] = field(default_factory=dict)
# Optional metadata
description: str = ""
@dataclass
class MCPTool:
    """Description of one tool exposed by an MCP server.

    Attributes:
        name: Tool identifier, used as the key when invoking the tool.
        description: Human-readable summary supplied by the server.
        input_schema: Argument schema as reported by the server
            (``inputSchema`` in the wire format).
        server_name: Name of the server configuration that provides the tool.
    """

    name: str
    description: str
    input_schema: dict[str, Any]
    server_name: str
class MCPClient:
    """
    Client for communicating with MCP servers.

    Supports two transports:

    * ``"stdio"`` uses the official MCP Python SDK. ``stdio_client`` is
      entered anew for every operation (tool discovery and each tool call),
      so each operation runs in its own short-lived SDK session.
    * ``"http"`` posts JSON-RPC 2.0 requests to ``<url>/mcp/v1`` through a
      persistent ``httpx.Client``.

    The public API is synchronous; async SDK calls are driven by
    ``_run_async``. The client can be used as a context manager, which
    connects on entry and disconnects on exit.
    """

    def __init__(self, config: MCPServerConfig):
        """
        Initialize the MCP client. No connection is made until ``connect()``
        (or the first ``list_tools()`` / ``call_tool()``) is called.

        Args:
            config: Server configuration
        """
        self.config = config
        # Not used by the current implementation: every STDIO operation opens
        # its own session. Kept for backward compatibility.
        self._session = None
        self._read_stream = None
        self._write_stream = None
        # Set by _connect_stdio(); initialized here so the attribute always exists.
        self._server_params = None
        self._http_client: httpx.Client | None = None
        self._tools: dict[str, MCPTool] = {}
        self._connected = False

    def _run_async(self, coro):
        """
        Run a coroutine to completion from synchronous code.

        If no event loop is running in the current thread, the coroutine is
        run with ``asyncio.run``. If a loop *is* running (this sync method was
        called from async code), ``asyncio.run`` is not allowed on this
        thread, so the coroutine runs on a fresh loop in a worker thread. The
        caller blocks until it finishes, which also blocks the caller's loop.

        Args:
            coro: Coroutine to run

        Returns:
            Result of the coroutine

        Raises:
            Whatever the coroutine raises, propagated unchanged.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread: run directly.
            return asyncio.run(coro)

        # Only get_running_loop() is inside the try above, so a RuntimeError
        # raised by the coroutine itself reaches the caller instead of being
        # mistaken for "no running loop" (which previously triggered a second,
        # invalid asyncio.run on the already-consumed coroutine).
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            # Future.result() re-raises any exception from the worker thread.
            return executor.submit(asyncio.run, coro).result()

    def connect(self) -> None:
        """
        Connect to the MCP server and discover its tools.

        Returns immediately if already connected. If connecting or tool
        discovery fails, any HTTP client created during the attempt is
        closed before the exception propagates, so retries do not leak
        clients.

        Raises:
            ValueError: Unsupported transport, or missing ``command`` / ``url``.
            RuntimeError: STDIO parameter setup or HTTP tool listing failed.
            Exception: Errors from STDIO tool discovery propagate unchanged.
        """
        if self._connected:
            return

        try:
            if self.config.transport == "stdio":
                self._connect_stdio()
            elif self.config.transport == "http":
                self._connect_http()
            else:
                raise ValueError(f"Unsupported transport: {self.config.transport}")
            # Discover tools
            self._discover_tools()
        except Exception:
            self._close_http_client()
            raise

        self._connected = True

    def _connect_stdio(self) -> None:
        """
        Prepare the MCP SDK ``StdioServerParameters`` for the STDIO transport.

        No process is started here; the parameters are stored and handed to
        ``stdio_client`` by each later operation.

        Raises:
            ValueError: If ``config.command`` is empty.
            RuntimeError: If the SDK cannot be imported or rejects the parameters.
        """
        if not self.config.command:
            raise ValueError("command is required for STDIO transport")

        try:
            from mcp import StdioServerParameters

            self._server_params = StdioServerParameters(
                command=self.config.command,
                args=self.config.args,
                # An empty env mapping is passed to the SDK as None.
                env=self.config.env or None,
                cwd=self.config.cwd,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to connect to MCP server: {e}") from e

        logger.info(f"Connected to MCP server '{self.config.name}' via STDIO")

    def _connect_http(self) -> None:
        """
        Create the HTTP client and probe the server's ``/health`` endpoint.

        The health probe is best-effort: a failure is logged as a warning and
        connecting continues, because the server may not expose that endpoint.

        Raises:
            ValueError: If ``config.url`` is empty.
        """
        if not self.config.url:
            raise ValueError("url is required for HTTP transport")

        # Close any client left over from an earlier, failed attempt.
        self._close_http_client()
        self._http_client = httpx.Client(
            base_url=self.config.url,
            headers=self.config.headers,
            timeout=30.0,
        )

        try:
            response = self._http_client.get("/health")
            response.raise_for_status()
        except Exception as e:
            # Continue anyway, server might not have health endpoint
            logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
        else:
            logger.info(f"Connected to MCP server '{self.config.name}' via HTTP at {self.config.url}")

    def _close_http_client(self) -> None:
        """Close and drop the HTTP client, if one is open."""
        if self._http_client is not None:
            self._http_client.close()
            self._http_client = None

    def _make_tool(self, tool_data: dict[str, Any]) -> MCPTool:
        """Build an ``MCPTool`` from one tool entry in wire (dict) format."""
        return MCPTool(
            name=tool_data["name"],
            # Either field may arrive as None/null; normalize so MCPTool holds
            # a str and a dict as its annotations declare.
            description=tool_data.get("description") or "",
            input_schema=tool_data.get("inputSchema") or {},
            server_name=self.config.name,
        )

    def _discover_tools(self) -> None:
        """
        Fetch the server's tool list and replace the cached ``_tools``.

        The new mapping is built completely before it is assigned, so a
        failure part-way through leaves the previous cache untouched.

        Raises:
            Exception: Any listing/parsing error, logged and re-raised.
        """
        try:
            if self.config.transport == "stdio":
                tools_list = self._run_async(self._list_tools_stdio_async())
            else:
                tools_list = self._list_tools_http()

            discovered: dict[str, MCPTool] = {}
            for tool_data in tools_list:
                tool = self._make_tool(tool_data)
                discovered[tool.name] = tool
        except Exception as e:
            logger.error(f"Failed to discover tools from '{self.config.name}': {e}")
            raise

        self._tools = discovered
        logger.info(f"Discovered {len(self._tools)} tools from '{self.config.name}': {list(self._tools.keys())}")

    async def _list_tools_stdio_async(self) -> list[dict]:
        """List tools over STDIO with the MCP SDK, returned in wire (dict) format."""
        from mcp import ClientSession
        from mcp.client.stdio import stdio_client

        async with stdio_client(self._server_params) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()
                response = await session.list_tools()
                return [
                    {
                        "name": tool.name,
                        "description": tool.description,
                        "inputSchema": tool.inputSchema,
                    }
                    for tool in response.tools
                ]

    def _post_jsonrpc(
        self,
        request_id: int,
        method: str,
        params: dict[str, Any],
        error_prefix: str,
    ) -> Any:
        """
        POST one JSON-RPC 2.0 request to ``/mcp/v1`` and return its ``result``.

        A missing or null ``result`` is returned as ``{}``.

        Raises:
            RuntimeError: If no HTTP client exists, or if the response carries a
                JSON-RPC ``error`` (message ``"<error_prefix>: <error>"``).
            Exception: HTTP status, transport and JSON-decoding errors propagate
                unchanged; callers wrap them.
        """
        if self._http_client is None:
            raise RuntimeError("HTTP client not initialized")

        response = self._http_client.post(
            "/mcp/v1",
            json={
                "jsonrpc": "2.0",
                "id": request_id,
                "method": method,
                "params": params,
            },
        )
        response.raise_for_status()
        data = response.json()
        if "error" in data:
            raise RuntimeError(f"{error_prefix}: {data['error']}")
        return data.get("result") or {}

    def _list_tools_http(self) -> list[dict]:
        """
        List tools via the ``tools/list`` JSON-RPC method.

        Raises:
            RuntimeError: If the HTTP client is missing or the request fails.
        """
        if not self._http_client:
            raise RuntimeError("HTTP client not initialized")

        try:
            result = self._post_jsonrpc(1, "tools/list", {}, "MCP error")
            return result.get("tools") or []
        except Exception as e:
            raise RuntimeError(f"Failed to list tools via HTTP: {e}") from e

    def list_tools(self) -> list[MCPTool]:
        """
        Get the tools discovered on the server, connecting first if needed.

        Returns:
            List of MCPTool objects
        """
        if not self._connected:
            self.connect()
        return list(self._tools.values())

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """
        Invoke a tool on the MCP server, connecting first if needed.

        Args:
            tool_name: Name of a previously discovered tool
            arguments: Tool arguments

        Returns:
            For STDIO: the first content item's ``text`` (or ``data``), the raw
            content list if that item has neither, or None when the result has
            no content. For HTTP: the ``content`` list from the JSON-RPC result.

        Raises:
            ValueError: If ``tool_name`` was not discovered on the server.
        """
        if not self._connected:
            self.connect()

        if tool_name not in self._tools:
            raise ValueError(f"Unknown tool: {tool_name}")

        if self.config.transport == "stdio":
            return self._run_async(self._call_tool_stdio_async(tool_name, arguments))
        return self._call_tool_http(tool_name, arguments)

    async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call a tool over STDIO with the MCP SDK and unwrap its first content item."""
        from mcp import ClientSession
        from mcp.client.stdio import stdio_client

        async with stdio_client(self._server_params) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()
                result = await session.call_tool(tool_name, arguments=arguments)

        # A tool-level failure is reported inside the result rather than raised;
        # surface it in the logs instead of returning it silently.
        if getattr(result, "isError", False):
            logger.warning(f"Tool '{tool_name}' on '{self.config.name}' reported an error result")

        if not result.content:
            return None

        # MCP returns content as a list of items; unwrap the first one.
        first_item = result.content[0]
        if hasattr(first_item, "text"):
            return first_item.text
        if hasattr(first_item, "data"):
            return first_item.data
        return result.content

    def _call_tool_http(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """
        Call a tool via the ``tools/call`` JSON-RPC method.

        Raises:
            RuntimeError: If the HTTP client is missing or the request fails.
        """
        if not self._http_client:
            raise RuntimeError("HTTP client not initialized")

        try:
            result = self._post_jsonrpc(
                2,
                "tools/call",
                {"name": tool_name, "arguments": arguments},
                "Tool execution error",
            )
            return result.get("content", [])
        except Exception as e:
            raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e

    def disconnect(self) -> None:
        """Disconnect from the MCP server, closing the HTTP client if one is open."""
        self._close_http_client()
        self._connected = False
        logger.info(f"Disconnected from MCP server '{self.config.name}'")

    def __enter__(self):
        """Context manager entry: connect and return this client."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: disconnect; exceptions are not suppressed."""
        self.disconnect()