354 lines
11 KiB
Python
354 lines
11 KiB
Python
"""MCP Client for connecting to Model Context Protocol servers.
|
|
|
|
This module provides a client for connecting to MCP servers and invoking their tools.
|
|
Supports both STDIO and HTTP transports using the official MCP Python SDK.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Literal
|
|
|
|
import httpx
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
|
|
class MCPServerConfig:
|
|
"""Configuration for an MCP server connection."""
|
|
|
|
name: str
|
|
transport: Literal["stdio", "http"]
|
|
|
|
# For STDIO transport
|
|
command: str | None = None
|
|
args: list[str] = field(default_factory=list)
|
|
env: dict[str, str] = field(default_factory=dict)
|
|
cwd: str | None = None
|
|
|
|
# For HTTP transport
|
|
url: str | None = None
|
|
headers: dict[str, str] = field(default_factory=dict)
|
|
|
|
# Optional metadata
|
|
description: str = ""
|
|
|
|
|
|
@dataclass
|
|
class MCPTool:
|
|
"""A tool available from an MCP server."""
|
|
|
|
name: str
|
|
description: str
|
|
input_schema: dict[str, Any]
|
|
server_name: str
|
|
|
|
|
|
class MCPClient:
|
|
"""
|
|
Client for communicating with MCP servers.
|
|
|
|
Supports both STDIO and HTTP transports using the official MCP SDK.
|
|
Manages the connection lifecycle and provides methods to list and invoke tools.
|
|
"""
|
|
|
|
def __init__(self, config: MCPServerConfig):
|
|
"""
|
|
Initialize the MCP client.
|
|
|
|
Args:
|
|
config: Server configuration
|
|
"""
|
|
self.config = config
|
|
self._session = None
|
|
self._read_stream = None
|
|
self._write_stream = None
|
|
self._http_client: httpx.Client | None = None
|
|
self._tools: dict[str, MCPTool] = {}
|
|
self._connected = False
|
|
|
|
def _run_async(self, coro):
|
|
"""
|
|
Run an async coroutine, handling both sync and async contexts.
|
|
|
|
Args:
|
|
coro: Coroutine to run
|
|
|
|
Returns:
|
|
Result of the coroutine
|
|
"""
|
|
try:
|
|
# Try to get the current event loop
|
|
asyncio.get_running_loop()
|
|
# If we're here, we're in an async context
|
|
# Create a new thread to run the coroutine
|
|
import threading
|
|
|
|
result = None
|
|
exception = None
|
|
|
|
def run_in_thread():
|
|
nonlocal result, exception
|
|
try:
|
|
new_loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(new_loop)
|
|
try:
|
|
result = new_loop.run_until_complete(coro)
|
|
finally:
|
|
new_loop.close()
|
|
except Exception as e:
|
|
exception = e
|
|
|
|
thread = threading.Thread(target=run_in_thread)
|
|
thread.start()
|
|
thread.join()
|
|
|
|
if exception:
|
|
raise exception
|
|
return result
|
|
except RuntimeError:
|
|
# No event loop running, we can use asyncio.run
|
|
return asyncio.run(coro)
|
|
|
|
def connect(self) -> None:
|
|
"""Connect to the MCP server."""
|
|
if self._connected:
|
|
return
|
|
|
|
if self.config.transport == "stdio":
|
|
self._connect_stdio()
|
|
elif self.config.transport == "http":
|
|
self._connect_http()
|
|
else:
|
|
raise ValueError(f"Unsupported transport: {self.config.transport}")
|
|
|
|
# Discover tools
|
|
self._discover_tools()
|
|
self._connected = True
|
|
|
|
def _connect_stdio(self) -> None:
|
|
"""Connect to MCP server via STDIO transport using MCP SDK."""
|
|
if not self.config.command:
|
|
raise ValueError("command is required for STDIO transport")
|
|
|
|
try:
|
|
# Import MCP SDK
|
|
from mcp import StdioServerParameters
|
|
|
|
# Create server parameters
|
|
server_params = StdioServerParameters(
|
|
command=self.config.command,
|
|
args=self.config.args,
|
|
env=self.config.env or None,
|
|
cwd=self.config.cwd,
|
|
)
|
|
|
|
# Store for later use in async context
|
|
self._server_params = server_params
|
|
|
|
logger.info(f"Connected to MCP server '{self.config.name}' via STDIO")
|
|
except Exception as e:
|
|
raise RuntimeError(f"Failed to connect to MCP server: {e}")
|
|
|
|
def _connect_http(self) -> None:
|
|
"""Connect to MCP server via HTTP transport."""
|
|
if not self.config.url:
|
|
raise ValueError("url is required for HTTP transport")
|
|
|
|
self._http_client = httpx.Client(
|
|
base_url=self.config.url,
|
|
headers=self.config.headers,
|
|
timeout=30.0,
|
|
)
|
|
|
|
# Test connection
|
|
try:
|
|
response = self._http_client.get("/health")
|
|
response.raise_for_status()
|
|
logger.info(f"Connected to MCP server '{self.config.name}' via HTTP at {self.config.url}")
|
|
except Exception as e:
|
|
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
|
|
# Continue anyway, server might not have health endpoint
|
|
|
|
def _discover_tools(self) -> None:
|
|
"""Discover available tools from the MCP server."""
|
|
try:
|
|
if self.config.transport == "stdio":
|
|
tools_list = self._run_async(self._list_tools_stdio_async())
|
|
else:
|
|
tools_list = self._list_tools_http()
|
|
|
|
self._tools = {}
|
|
for tool_data in tools_list:
|
|
tool = MCPTool(
|
|
name=tool_data["name"],
|
|
description=tool_data.get("description", ""),
|
|
input_schema=tool_data.get("inputSchema", {}),
|
|
server_name=self.config.name,
|
|
)
|
|
self._tools[tool.name] = tool
|
|
|
|
logger.info(f"Discovered {len(self._tools)} tools from '{self.config.name}': {list(self._tools.keys())}")
|
|
except Exception as e:
|
|
logger.error(f"Failed to discover tools from '{self.config.name}': {e}")
|
|
raise
|
|
|
|
async def _list_tools_stdio_async(self) -> list[dict]:
|
|
"""List tools via STDIO protocol using MCP SDK."""
|
|
from mcp import ClientSession
|
|
from mcp.client.stdio import stdio_client
|
|
|
|
async with stdio_client(self._server_params) as (read, write):
|
|
async with ClientSession(read, write) as session:
|
|
# Initialize the session
|
|
await session.initialize()
|
|
|
|
# List tools
|
|
response = await session.list_tools()
|
|
|
|
# Convert tools to dict format
|
|
tools_list = []
|
|
for tool in response.tools:
|
|
tools_list.append({
|
|
"name": tool.name,
|
|
"description": tool.description,
|
|
"inputSchema": tool.inputSchema,
|
|
})
|
|
|
|
return tools_list
|
|
|
|
def _list_tools_http(self) -> list[dict]:
|
|
"""List tools via HTTP protocol."""
|
|
if not self._http_client:
|
|
raise RuntimeError("HTTP client not initialized")
|
|
|
|
try:
|
|
# Use MCP over HTTP protocol
|
|
response = self._http_client.post(
|
|
"/mcp/v1",
|
|
json={
|
|
"jsonrpc": "2.0",
|
|
"id": 1,
|
|
"method": "tools/list",
|
|
"params": {},
|
|
},
|
|
)
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
|
|
if "error" in data:
|
|
raise RuntimeError(f"MCP error: {data['error']}")
|
|
|
|
return data.get("result", {}).get("tools", [])
|
|
except Exception as e:
|
|
raise RuntimeError(f"Failed to list tools via HTTP: {e}")
|
|
|
|
def list_tools(self) -> list[MCPTool]:
|
|
"""
|
|
Get list of available tools.
|
|
|
|
Returns:
|
|
List of MCPTool objects
|
|
"""
|
|
if not self._connected:
|
|
self.connect()
|
|
|
|
return list(self._tools.values())
|
|
|
|
def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
|
"""
|
|
Invoke a tool on the MCP server.
|
|
|
|
Args:
|
|
tool_name: Name of the tool to invoke
|
|
arguments: Tool arguments
|
|
|
|
Returns:
|
|
Tool result
|
|
"""
|
|
if not self._connected:
|
|
self.connect()
|
|
|
|
if tool_name not in self._tools:
|
|
raise ValueError(f"Unknown tool: {tool_name}")
|
|
|
|
if self.config.transport == "stdio":
|
|
return self._run_async(self._call_tool_stdio_async(tool_name, arguments))
|
|
else:
|
|
return self._call_tool_http(tool_name, arguments)
|
|
|
|
async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
|
"""Call tool via STDIO protocol using MCP SDK."""
|
|
from mcp import ClientSession
|
|
from mcp.client.stdio import stdio_client
|
|
|
|
async with stdio_client(self._server_params) as (read, write):
|
|
async with ClientSession(read, write) as session:
|
|
# Initialize the session
|
|
await session.initialize()
|
|
|
|
# Call tool
|
|
result = await session.call_tool(tool_name, arguments=arguments)
|
|
|
|
# Extract content
|
|
if result.content:
|
|
# MCP returns content as a list of content items
|
|
if len(result.content) > 0:
|
|
content_item = result.content[0]
|
|
# Check if it's a text content item
|
|
if hasattr(content_item, 'text'):
|
|
return content_item.text
|
|
elif hasattr(content_item, 'data'):
|
|
return content_item.data
|
|
return result.content
|
|
|
|
return None
|
|
|
|
def _call_tool_http(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
|
"""Call tool via HTTP protocol."""
|
|
if not self._http_client:
|
|
raise RuntimeError("HTTP client not initialized")
|
|
|
|
try:
|
|
response = self._http_client.post(
|
|
"/mcp/v1",
|
|
json={
|
|
"jsonrpc": "2.0",
|
|
"id": 2,
|
|
"method": "tools/call",
|
|
"params": {
|
|
"name": tool_name,
|
|
"arguments": arguments,
|
|
},
|
|
},
|
|
)
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
|
|
if "error" in data:
|
|
raise RuntimeError(f"Tool execution error: {data['error']}")
|
|
|
|
return data.get("result", {}).get("content", [])
|
|
except Exception as e:
|
|
raise RuntimeError(f"Failed to call tool via HTTP: {e}")
|
|
|
|
def disconnect(self) -> None:
|
|
"""Disconnect from the MCP server."""
|
|
if self._http_client:
|
|
self._http_client.close()
|
|
self._http_client = None
|
|
|
|
self._connected = False
|
|
logger.info(f"Disconnected from MCP server '{self.config.name}'")
|
|
|
|
def __enter__(self):
|
|
"""Context manager entry."""
|
|
self.connect()
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
"""Context manager exit."""
|
|
self.disconnect()
|