Merge pull request #6534 from VasuBansal7576/codex/mcp-connection-manager-6348-draft

feat: add shared MCP connection manager
This commit is contained in:
Timothy @aden
2026-03-18 16:52:44 -07:00
committed by GitHub
4 changed files with 541 additions and 10 deletions
@@ -0,0 +1,246 @@
"""Shared MCP client connection management."""
import threading
from typing import Any
import httpx
from framework.runner.mcp_client import MCPClient, MCPServerConfig
class MCPConnectionManager:
    """Process-wide MCP client pool keyed by server name.

    One shared MCPClient is handed out per server name. acquire() bumps a
    refcount, release() disconnects only when the last holder lets go, and
    per-server threading.Event objects ("transitions") serialize
    connect/reconnect/disconnect so concurrent callers wait instead of racing.
    """

    # Lazily-created singleton plus the class-level lock guarding its creation.
    _instance = None
    _lock = threading.Lock()

    def __init__(self) -> None:
        # server name -> live pooled client
        self._pool: dict[str, MCPClient] = {}
        # server name -> number of outstanding acquire() holders
        self._refcounts: dict[str, int] = {}
        # server name -> most recent config (consumed by health_check/reconnect)
        self._configs: dict[str, MCPServerConfig] = {}
        # Guards all three maps above plus _transitions.
        self._pool_lock = threading.Lock()
        # Transition events keep callers from racing a connect/reconnect/disconnect.
        self._transitions: dict[str, threading.Event] = {}

    @classmethod
    def get_instance(cls) -> "MCPConnectionManager":
        """Return the process-level singleton instance."""
        # Double-checked locking: lock-free fast path once created.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @staticmethod
    def _is_connected(client: MCPClient | None) -> bool:
        # Peeks at the client's private `_connected` flag; best-effort liveness
        # hint only — NOTE(review): relies on MCPClient internals.
        return bool(client and getattr(client, "_connected", False))

    def acquire(self, config: MCPServerConfig) -> MCPClient:
        """Get or create a shared connection and increment its refcount."""
        server_name = config.name
        while True:
            should_connect = False
            transition_event: threading.Event | None = None
            with self._pool_lock:
                client = self._pool.get(server_name)
                # Fast path: already pooled, connected, and no transition in flight.
                if self._is_connected(client) and server_name not in self._transitions:
                    self._refcounts[server_name] = self._refcounts.get(server_name, 0) + 1
                    self._configs[server_name] = config
                    return client
                transition_event = self._transitions.get(server_name)
                if transition_event is None:
                    # We claim the connect transition; other callers wait on it.
                    transition_event = threading.Event()
                    self._transitions[server_name] = transition_event
                    self._configs[server_name] = config
                    should_connect = True
            if not should_connect:
                # Another thread owns the transition; wait for it and retry.
                transition_event.wait()
                continue
            # Connect outside the lock so a slow handshake doesn't block the pool.
            client = MCPClient(config)
            try:
                client.connect()
            except Exception:
                with self._pool_lock:
                    current = self._transitions.get(server_name)
                    if current is transition_event:
                        self._transitions.pop(server_name, None)
                        # Drop the stashed config only when nothing else references
                        # this server (no pooled client, no holders).
                        if (
                            server_name not in self._pool
                            and self._refcounts.get(server_name, 0) <= 0
                        ):
                            self._configs.pop(server_name, None)
                transition_event.set()
                raise
            with self._pool_lock:
                current = self._transitions.get(server_name)
                if current is transition_event:
                    # Publish the client and take the caller's first reference.
                    self._pool[server_name] = client
                    self._refcounts[server_name] = self._refcounts.get(server_name, 0) + 1
                    self._configs[server_name] = config
                    self._transitions.pop(server_name, None)
                    transition_event.set()
                    return client
            # Defensive: our transition was superseded — discard this client and
            # loop to re-evaluate pool state.
            client.disconnect()

    def release(self, server_name: str) -> None:
        """Decrement refcount and disconnect when the last user releases."""
        while True:
            disconnect_client: MCPClient | None = None
            transition_event: threading.Event | None = None
            should_disconnect = False
            with self._pool_lock:
                transition_event = self._transitions.get(server_name)
                if transition_event is None:
                    refcount = self._refcounts.get(server_name, 0)
                    if refcount <= 0:
                        # Nothing tracked for this server; release is a no-op.
                        return
                    if refcount > 1:
                        # Other holders remain; just drop one reference.
                        self._refcounts[server_name] = refcount - 1
                        return
                    # Last reference: remove from the pool and disconnect below,
                    # claiming a transition so new acquirers wait.
                    disconnect_client = self._pool.pop(server_name, None)
                    self._refcounts.pop(server_name, None)
                    transition_event = threading.Event()
                    self._transitions[server_name] = transition_event
                    should_disconnect = True
            if not should_disconnect:
                # A connect/reconnect/disconnect is in flight; wait and retry.
                transition_event.wait()
                continue
            try:
                if disconnect_client is not None:
                    disconnect_client.disconnect()
            finally:
                # Always clear our transition and wake waiters, even if
                # disconnect() raised.
                with self._pool_lock:
                    current = self._transitions.get(server_name)
                    if current is transition_event:
                        self._transitions.pop(server_name, None)
                transition_event.set()
            return

    def health_check(self, server_name: str) -> bool:
        """Return True when the pooled connection appears healthy."""
        # Snapshot client/config once no transition is in flight.
        while True:
            with self._pool_lock:
                transition_event = self._transitions.get(server_name)
                if transition_event is None:
                    client = self._pool.get(server_name)
                    config = self._configs.get(server_name)
                    break
            transition_event.wait()
        if client is None or config is None:
            return False
        try:
            if config.transport == "stdio":
                # For stdio servers, a successful tools/list round-trip is the probe.
                client.list_tools()
                return True
            if not config.url:
                return False
            client_kwargs: dict[str, Any] = {
                "base_url": config.url,
                "headers": config.headers,
                "timeout": 5.0,
            }
            if config.transport == "unix":
                if not config.socket_path:
                    return False
                # Route HTTP over the configured unix domain socket.
                client_kwargs["transport"] = httpx.HTTPTransport(uds=config.socket_path)
            # HTTP/unix servers are probed via GET /health out of band.
            with httpx.Client(**client_kwargs) as http_client:
                response = http_client.get("/health")
                response.raise_for_status()
                return True
        except Exception:
            # Any probe failure — transport error or non-2xx — means unhealthy.
            return False

    def reconnect(self, server_name: str) -> MCPClient:
        """Force a disconnect and replace the pooled client with a fresh one."""
        while True:
            transition_event: threading.Event | None = None
            old_client: MCPClient | None = None
            with self._pool_lock:
                transition_event = self._transitions.get(server_name)
                if transition_event is None:
                    config = self._configs.get(server_name)
                    if config is None:
                        # Never acquired (or fully cleaned up): nothing to rebuild.
                        raise KeyError(f"Unknown MCP server: {server_name}")
                    old_client = self._pool.get(server_name)
                    refcount = self._refcounts.get(server_name, 0)
                    # Claim the transition so concurrent callers wait on us.
                    transition_event = threading.Event()
                    self._transitions[server_name] = transition_event
                    break
            transition_event.wait()
        # Tear down and rebuild outside the lock.
        if old_client is not None:
            old_client.disconnect()
        new_client = MCPClient(config)
        try:
            new_client.connect()
        except Exception:
            with self._pool_lock:
                current = self._transitions.get(server_name)
                if current is transition_event:
                    self._pool.pop(server_name, None)
                    self._transitions.pop(server_name, None)
            transition_event.set()
            raise
        with self._pool_lock:
            current = self._transitions.get(server_name)
            if current is transition_event:
                self._pool[server_name] = new_client
                # Preserve existing holders; guarantee at least one reference.
                self._refcounts[server_name] = max(refcount, 1)
                self._transitions.pop(server_name, None)
                transition_event.set()
                return new_client
        # Defensive: transition superseded — drop this client and go through
        # the normal acquire path instead.
        new_client.disconnect()
        return self.acquire(config)

    def cleanup_all(self) -> None:
        """Disconnect all pooled clients and clear manager state."""
        while True:
            with self._pool_lock:
                if self._transitions:
                    # Let in-flight connects/disconnects finish before clearing.
                    pending = list(self._transitions.values())
                else:
                    # Claim one transition per pooled server, then empty the maps.
                    cleanup_events = {name: threading.Event() for name in self._pool}
                    clients = list(self._pool.items())
                    self._transitions.update(cleanup_events)
                    self._pool.clear()
                    self._refcounts.clear()
                    self._configs.clear()
                    break
            for event in pending:
                event.wait()
        # Disconnect outside the lock; failures are best-effort during teardown.
        for _server_name, client in clients:
            try:
                client.disconnect()
            except Exception:
                pass
        with self._pool_lock:
            for server_name, event in cleanup_events.items():
                current = self._transitions.get(server_name)
                if current is event:
                    self._transitions.pop(server_name, None)
                event.set()
+33 -10
View File
@@ -54,6 +54,8 @@ class ToolRegistry:
def __init__(self):
self._tools: dict[str, RegisteredTool] = {}
self._mcp_clients: list[Any] = [] # List of MCPClient instances
self._mcp_client_servers: dict[int, str] = {} # client id -> server name
self._mcp_managed_clients: set[int] = set() # client ids acquired from the manager
self._session_context: dict[str, Any] = {} # Auto-injected context for tools
self._provider_index: dict[str, set[str]] = {} # provider -> tool names
# MCP resync tracking
@@ -480,6 +482,7 @@ class ToolRegistry:
def register_mcp_server(
self,
server_config: dict[str, Any],
use_connection_manager: bool = False,
) -> int:
"""
Register an MCP server and discover its tools.
@@ -495,12 +498,14 @@ class ToolRegistry:
- url: Server URL (for http)
- headers: HTTP headers (for http)
- description: Server description (optional)
use_connection_manager: When True, reuse a shared client keyed by server name
Returns:
Number of tools registered from this server
"""
try:
from framework.runner.mcp_client import MCPClient, MCPServerConfig
from framework.runner.mcp_connection_manager import MCPConnectionManager
# Build config object
config = MCPServerConfig(
@@ -516,11 +521,18 @@ class ToolRegistry:
)
# Create and connect client
client = MCPClient(config)
client.connect()
if use_connection_manager:
client = MCPConnectionManager.get_instance().acquire(config)
else:
client = MCPClient(config)
client.connect()
# Store client for cleanup
self._mcp_clients.append(client)
client_id = id(client)
self._mcp_client_servers[client_id] = config.name
if use_connection_manager:
self._mcp_managed_clients.add(client_id)
# Register each tool
server_name = server_config["name"]
@@ -720,12 +732,7 @@ class ToolRegistry:
logger.info("%s — resyncing MCP servers", reason)
# 1. Disconnect existing MCP clients
for client in self._mcp_clients:
try:
client.disconnect()
except Exception as e:
logger.warning(f"Error disconnecting MCP client during resync: {e}")
self._mcp_clients.clear()
self._cleanup_mcp_clients("during resync")
# 2. Remove MCP-registered tools
for name in self._mcp_tool_names:
@@ -740,12 +747,28 @@ class ToolRegistry:
def cleanup(self) -> None:
"""Clean up all MCP client connections."""
self._cleanup_mcp_clients()
def _cleanup_mcp_clients(self, context: str = "") -> None:
"""Disconnect or release all tracked MCP clients for this registry."""
if context:
context = f" {context}"
for client in self._mcp_clients:
client_id = id(client)
server_name = self._mcp_client_servers.get(client_id, client.config.name)
try:
client.disconnect()
if client_id in self._mcp_managed_clients:
from framework.runner.mcp_connection_manager import MCPConnectionManager
MCPConnectionManager.get_instance().release(server_name)
else:
client.disconnect()
except Exception as e:
logger.warning(f"Error disconnecting MCP client: {e}")
logger.warning(f"Error disconnecting MCP client{context}: {e}")
self._mcp_clients.clear()
self._mcp_client_servers.clear()
self._mcp_managed_clients.clear()
def __del__(self):
"""Destructor to ensure cleanup."""
+172
View File
@@ -0,0 +1,172 @@
"""Tests for the shared MCP connection manager."""
import threading
import httpx
import pytest
from framework.runner.mcp_client import MCPServerConfig, MCPTool
from framework.runner.mcp_connection_manager import MCPConnectionManager
class FakeMCPClient:
    """In-memory stand-in for MCPClient that records every call it receives."""

    # Every constructed fake is appended here so tests can inspect them.
    instances: list["FakeMCPClient"] = []

    def __init__(self, config: MCPServerConfig):
        self.config = config
        self._connected = False
        # Call counters inspected by the connection-manager tests.
        self.connect_calls = 0
        self.disconnect_calls = 0
        self.list_tools_calls = 0
        # When set, list_tools raises this instead of returning tools.
        self.list_tools_error: Exception | None = None
        type(self).instances.append(self)

    def connect(self) -> None:
        self._connected = True
        self.connect_calls += 1

    def disconnect(self) -> None:
        self._connected = False
        self.disconnect_calls += 1

    def list_tools(self) -> list[MCPTool]:
        self.list_tools_calls += 1
        pending_error = self.list_tools_error
        if pending_error is not None:
            raise pending_error
        return [MCPTool("ping", "Ping", {"type": "object"}, self.config.name)]
@pytest.fixture
def manager(monkeypatch):
    """Yield a fresh MCPConnectionManager singleton backed by FakeMCPClient."""
    monkeypatch.setattr("framework.runner.mcp_connection_manager.MCPClient", FakeMCPClient)
    monkeypatch.setattr(MCPConnectionManager, "_instance", None)
    FakeMCPClient.instances.clear()
    shared = MCPConnectionManager.get_instance()
    yield shared
    # Teardown: drop pooled fakes and reset singleton state for the next test.
    shared.cleanup_all()
    monkeypatch.setattr(MCPConnectionManager, "_instance", None)
    FakeMCPClient.instances.clear()
def test_acquire_returns_same_client_for_same_server_name(manager):
    """Repeated acquires for one server name share a single pooled client."""
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")
    first = manager.acquire(config)
    second = manager.acquire(config)
    assert first is second
    assert manager._refcounts["shared"] == 2  # noqa: SLF001 - state assertion for unit test
    assert len(FakeMCPClient.instances) == 1
def test_release_with_refcount_above_one_keeps_connection_open(manager):
    """Releasing one of two holders must not tear down the shared client."""
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")
    shared_client = manager.acquire(config)
    manager.acquire(config)
    manager.release("shared")
    assert shared_client.disconnect_calls == 0
    assert manager._pool["shared"] is shared_client  # noqa: SLF001 - state assertion for unit test
    assert manager._refcounts["shared"] == 1  # noqa: SLF001 - state assertion for unit test
def test_release_last_reference_disconnects_and_removes_from_pool(manager):
    """Dropping the final reference disconnects and forgets the client."""
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")
    pooled = manager.acquire(config)
    manager.release("shared")
    assert pooled.disconnect_calls == 1
    assert "shared" not in manager._pool  # noqa: SLF001 - state assertion for unit test
    assert "shared" not in manager._refcounts  # noqa: SLF001 - state assertion for unit test
def test_concurrent_acquire_and_release_keeps_state_consistent(manager):
    """Many threads acquiring and releasing at once still share one client."""
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")
    worker_count = 8
    acquire_barrier = threading.Barrier(worker_count + 1)
    release_barrier = threading.Barrier(worker_count)
    acquired_clients: list[FakeMCPClient] = []
    acquired_lock = threading.Lock()

    def worker() -> None:
        acquire_barrier.wait()
        handle = manager.acquire(config)
        with acquired_lock:
            acquired_clients.append(handle)
        # Hold until every worker has acquired, then release together.
        release_barrier.wait()
        manager.release("shared")

    threads = [threading.Thread(target=worker) for _ in range(worker_count)]
    for thread in threads:
        thread.start()
    acquire_barrier.wait()
    for thread in threads:
        thread.join()
    # Exactly one client was ever created and it was fully torn down.
    assert len({id(handle) for handle in acquired_clients}) == 1
    assert len(FakeMCPClient.instances) == 1
    assert FakeMCPClient.instances[0].disconnect_calls == 1
    assert manager._pool == {}  # noqa: SLF001 - state assertion for unit test
    assert manager._refcounts == {}  # noqa: SLF001 - state assertion for unit test
def test_cleanup_all_disconnects_every_pooled_client(manager):
    """cleanup_all tears down every pooled connection and empties all state."""
    for server in ("one", "two"):
        manager.acquire(MCPServerConfig(name=server, transport="stdio", command="echo"))
    manager.cleanup_all()
    assert len(FakeMCPClient.instances) == 2
    assert all(fake.disconnect_calls == 1 for fake in FakeMCPClient.instances)
    assert manager._pool == {}  # noqa: SLF001 - state assertion for unit test
    assert manager._refcounts == {}  # noqa: SLF001 - state assertion for unit test
    assert manager._configs == {}  # noqa: SLF001 - state assertion for unit test
def test_reconnect_replaces_client_even_with_existing_refcount(manager):
    """reconnect swaps in a new client while keeping the refcount intact."""
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")
    stale = manager.acquire(config)
    manager.acquire(config)
    fresh = manager.reconnect("shared")
    assert fresh is not stale
    assert stale.disconnect_calls == 1
    assert manager._pool["shared"] is fresh  # noqa: SLF001 - state assertion for unit test
    assert manager._refcounts["shared"] == 2  # noqa: SLF001 - state assertion for unit test
def test_health_check_returns_false_when_server_is_unreachable(manager, monkeypatch):
    """An HTTP health probe that cannot connect reports the server unhealthy."""
    config = MCPServerConfig(name="shared", transport="http", url="http://localhost:9")
    manager.acquire(config)

    class FailingHttpClient:
        """Context-manager stub whose GET always raises a connect error."""

        def __init__(self, **_kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def get(self, _path: str):
            raise httpx.ConnectError("unreachable")

    monkeypatch.setattr("framework.runner.mcp_connection_manager.httpx.Client", FailingHttpClient)
    assert manager.health_check("shared") is False
def test_health_check_for_stdio_returns_false_on_tools_list_error(manager):
    """A stdio health check fails when the tools/list probe raises."""
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")
    fake = manager.acquire(config)
    fake.list_tools_error = RuntimeError("broken")
    assert manager.health_check("shared") is False
+90
View File
@@ -8,6 +8,7 @@ could cause a json.JSONDecodeError and crash execution.
import textwrap
from pathlib import Path
from types import SimpleNamespace
from framework.runner.tool_registry import ToolRegistry
@@ -91,3 +92,92 @@ def test_discover_from_module_handles_empty_content(tmp_path):
result = registered.executor({})
assert isinstance(result, dict)
assert result == {}
class _RegistryFakeClient:
def __init__(self, config):
self.config = config
self.connect_calls = 0
self.disconnect_calls = 0
def connect(self) -> None:
self.connect_calls += 1
def disconnect(self) -> None:
self.disconnect_calls += 1
def list_tools(self):
return [
SimpleNamespace(
name="pooled_tool",
description="Tool from MCP",
input_schema={"type": "object", "properties": {}, "required": []},
)
]
def call_tool(self, tool_name, arguments):
return [{"text": f"{tool_name}:{arguments}"}]
def test_register_mcp_server_uses_connection_manager_when_enabled(monkeypatch):
    """With use_connection_manager=True the registry acquires/releases via the pool."""
    registry = ToolRegistry()
    client = _RegistryFakeClient(SimpleNamespace(name="shared"))
    manager_calls: list[tuple[str, str]] = []

    class FakeManager:
        """Stub connection manager that logs acquire/release invocations."""

        def acquire(self, config):
            manager_calls.append(("acquire", config.name))
            client.config = config
            return client

        def release(self, server_name: str) -> None:
            manager_calls.append(("release", server_name))

    monkeypatch.setattr(
        "framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
        lambda: FakeManager(),
    )
    count = registry.register_mcp_server(
        {"name": "shared", "transport": "stdio", "command": "echo"},
        use_connection_manager=True,
    )
    assert count == 1
    assert manager_calls == [("acquire", "shared")]
    registry.cleanup()
    # Pooled clients are released back to the manager, never disconnected directly.
    assert manager_calls == [("acquire", "shared"), ("release", "shared")]
    assert client.disconnect_calls == 0
def test_register_mcp_server_defaults_to_direct_client_behavior(monkeypatch):
    """Without the flag the registry builds, connects, and disconnects its own client."""
    registry = ToolRegistry()
    created_clients: list[_RegistryFakeClient] = []

    def fake_client_factory(config):
        direct = _RegistryFakeClient(config)
        created_clients.append(direct)
        return direct

    def fail_if_manager_used():
        raise AssertionError("connection manager should not be used by default")

    monkeypatch.setattr("framework.runner.mcp_client.MCPClient", fake_client_factory)
    monkeypatch.setattr(
        "framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
        fail_if_manager_used,
    )
    count = registry.register_mcp_server(
        {"name": "direct", "transport": "stdio", "command": "echo"},
    )
    assert count == 1
    assert len(created_clients) == 1
    assert created_clients[0].connect_calls == 1
    registry.cleanup()
    assert created_clients[0].disconnect_calls == 1