Files
hive/core/tests/test_tool_registry.py
T
2026-03-18 20:20:32 -07:00

209 lines
6.4 KiB
Python

"""Tests for ToolRegistry JSON handling when tools return invalid JSON.
These tests exercise the discover_from_module() path, where tools are
registered via a TOOLS dict and a unified tool_executor that returns
ToolResult instances. Historically, invalid JSON in ToolResult.content
could cause a json.JSONDecodeError and crash execution.
"""
import textwrap
from pathlib import Path
from types import SimpleNamespace
from framework.runner.tool_registry import ToolRegistry
def _write_tool_module(tmp_path: Path, content: str) -> Path:
"""Helper to write a temporary tools module."""
module_path = tmp_path / "agent_tools.py"
module_path.write_text(textwrap.dedent(content))
return module_path
def test_discover_from_module_handles_invalid_json(tmp_path):
"""ToolRegistry should not crash when tool_executor returns invalid JSON."""
module_src = """
from framework.llm.provider import Tool, ToolUse, ToolResult
TOOLS = {
"bad_tool": Tool(
name="bad_tool",
description="Returns malformed JSON",
parameters={"type": "object", "properties": {}},
),
}
def tool_executor(tool_use: ToolUse) -> ToolResult:
# Intentionally malformed JSON
return ToolResult(
tool_use_id=tool_use.id,
content="not {valid json",
is_error=False,
)
"""
module_path = _write_tool_module(tmp_path, module_src)
registry = ToolRegistry()
count = registry.discover_from_module(module_path)
assert count == 1
# Access the registered executor for "bad_tool"
assert "bad_tool" in registry._tools # noqa: SLF001 - testing internal registry
registered = registry._tools["bad_tool"]
# Should not raise, and should return a structured error dict
result = registered.executor({})
assert isinstance(result, dict)
assert "error" in result
assert "raw_content" in result
assert result["raw_content"] == "not {valid json"
def test_discover_from_module_handles_empty_content(tmp_path):
"""ToolRegistry should handle empty ToolResult.content gracefully."""
module_src = """
from framework.llm.provider import Tool, ToolUse, ToolResult
TOOLS = {
"empty_tool": Tool(
name="empty_tool",
description="Returns empty content",
parameters={"type": "object", "properties": {}},
),
}
def tool_executor(tool_use: ToolUse) -> ToolResult:
return ToolResult(
tool_use_id=tool_use.id,
content="",
is_error=False,
)
"""
module_path = _write_tool_module(tmp_path, module_src)
registry = ToolRegistry()
count = registry.discover_from_module(module_path)
assert count == 1
assert "empty_tool" in registry._tools # noqa: SLF001 - testing internal registry
registered = registry._tools["empty_tool"]
# Empty content should return an empty dict rather than crashing
result = registered.executor({})
assert isinstance(result, dict)
assert result == {}
class _RegistryFakeClient:
def __init__(self, config):
self.config = config
self.connect_calls = 0
self.disconnect_calls = 0
def connect(self) -> None:
self.connect_calls += 1
def disconnect(self) -> None:
self.disconnect_calls += 1
def list_tools(self):
return [
SimpleNamespace(
name="pooled_tool",
description="Tool from MCP",
input_schema={"type": "object", "properties": {}, "required": []},
)
]
def call_tool(self, tool_name, arguments):
return [{"text": f"{tool_name}:{arguments}"}]
def test_register_mcp_server_uses_connection_manager_when_enabled(monkeypatch):
registry = ToolRegistry()
client = _RegistryFakeClient(SimpleNamespace(name="shared"))
manager_calls: list[tuple[str, str]] = []
class FakeManager:
def acquire(self, config):
manager_calls.append(("acquire", config.name))
client.config = config
return client
def release(self, server_name: str) -> None:
manager_calls.append(("release", server_name))
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
lambda: FakeManager(),
)
count = registry.register_mcp_server(
{"name": "shared", "transport": "stdio", "command": "echo"},
use_connection_manager=True,
)
assert count == 1
assert manager_calls == [("acquire", "shared")]
registry.cleanup()
assert manager_calls == [("acquire", "shared"), ("release", "shared")]
assert client.disconnect_calls == 0
def test_register_mcp_server_defaults_to_connection_manager(monkeypatch):
"""Default behavior uses the connection manager (reuse enabled by default)."""
registry = ToolRegistry()
created_clients: list[_RegistryFakeClient] = []
def fake_client_factory(config):
client = _RegistryFakeClient(config)
created_clients.append(client)
return client
class FakeManager:
def acquire(self, config):
return fake_client_factory(config)
def release(self, server_name):
pass
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
lambda: FakeManager(),
)
count = registry.register_mcp_server(
{"name": "direct", "transport": "stdio", "command": "echo"},
)
assert count == 1
assert len(created_clients) == 1
def test_register_mcp_server_direct_client_when_manager_disabled(monkeypatch):
"""When use_connection_manager=False, a direct MCPClient is created."""
registry = ToolRegistry()
created_clients: list[_RegistryFakeClient] = []
def fake_client_factory(config):
client = _RegistryFakeClient(config)
created_clients.append(client)
return client
monkeypatch.setattr("framework.runner.mcp_client.MCPClient", fake_client_factory)
count = registry.register_mcp_server(
{"name": "direct", "transport": "stdio", "command": "echo"},
use_connection_manager=False,
)
assert count == 1
assert len(created_clients) == 1
assert created_clients[0].connect_calls == 1
registry.cleanup()
assert created_clients[0].disconnect_calls == 1