Integration: add Databricks MCP tool integration

Implements the Databricks MCP tool integration for the Hive agent framework
This commit is contained in:
Navya Bijoy
2026-02-26 21:01:59 +05:30
parent b00203702e
commit ddd30a950d
10 changed files with 2054 additions and 1 deletion
+6
View File
@@ -59,6 +59,10 @@ sql = [
bigquery = [
"google-cloud-bigquery>=3.0.0",
]
databricks = [
"databricks-sdk>=0.30.0",
"databricks-mcp>=0.1.0",
]
all = [
"RestrictedPython>=7.0",
"pytesseract>=0.3.10",
@@ -66,6 +70,8 @@ all = [
"duckdb>=1.0.0",
"openpyxl>=3.1.0",
"google-cloud-bigquery>=3.0.0",
"databricks-sdk>=0.30.0",
"databricks-mcp>=0.1.0",
]
[tool.uv.sources]
@@ -56,6 +56,7 @@ To add a new credential:
from .apollo import APOLLO_CREDENTIALS
from .base import CredentialError, CredentialSpec
from .bigquery import BIGQUERY_CREDENTIALS
from .databricks import DATABRICKS_CREDENTIALS
from .brevo import BREVO_CREDENTIALS
from .browser import get_aden_auth_url, get_aden_setup_url, open_browser
from .calcom import CALCOM_CREDENTIALS
@@ -113,6 +114,7 @@ CREDENTIAL_SPECS = {
**STRIPE_CREDENTIALS,
**BREVO_CREDENTIALS,
**POSTGRES_CREDENTIALS,
**DATABRICKS_CREDENTIALS,
}
__all__ = [
@@ -160,4 +162,5 @@ __all__ = [
"STRIPE_CREDENTIALS",
"BREVO_CREDENTIALS",
"POSTGRES_CREDENTIALS",
"DATABRICKS_CREDENTIALS",
]
@@ -0,0 +1,90 @@
"""
Databricks tool credentials.

Contains credentials for Databricks workspace and SQL Warehouse access,
as well as managed MCP server connectivity.
"""

from .base import CredentialSpec

# All tool names that require Databricks credentials
_DATABRICKS_TOOLS = [
    # Custom SQL tools (databricks_tool.py)
    "run_databricks_sql",
    "describe_databricks_table",
    # Managed MCP tools (databricks_mcp_tool.py)
    "databricks_mcp_query_sql",
    "databricks_mcp_query_uc_function",
    "databricks_mcp_vector_search",
    "databricks_mcp_query_genie",
    "databricks_mcp_list_tools",
]

# Credential specs keyed by credential name; merged into the global
# CREDENTIAL_SPECS registry by the credentials package __init__.
DATABRICKS_CREDENTIALS = {
    # Workspace URL — needed by every Databricks tool.
    # NOTE(review): required=True with startup_required=False — presumably
    # validated lazily when a tool is invoked rather than at server startup;
    # confirm against CredentialSpec semantics.
    "databricks_host": CredentialSpec(
        env_var="DATABRICKS_HOST",
        tools=_DATABRICKS_TOOLS,
        required=True,
        startup_required=False,
        help_url="https://docs.databricks.com/en/workspace/workspace-details.html",
        description="Databricks workspace URL (e.g., https://dbc-a1b2c3d4-e5f6.cloud.databricks.com)",
        # Auth method support
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""To set up Databricks authentication:
1. Go to your Databricks workspace
2. Copy the workspace URL from your browser (e.g., https://dbc-a1b2c3d4-e5f6.cloud.databricks.com)
3. Set DATABRICKS_HOST=https://your-workspace-hostname
Note: Do not include a trailing slash.""",
        # Credential store mapping
        credential_id="databricks_host",
        credential_key="workspace_url",
    ),
    # Personal access token — needed by every Databricks tool.
    "databricks_token": CredentialSpec(
        env_var="DATABRICKS_TOKEN",
        tools=_DATABRICKS_TOOLS,
        required=True,
        startup_required=False,
        help_url="https://docs.databricks.com/en/dev-tools/auth/pat.html",
        description="Databricks personal access token for API authentication",
        # Auth method support
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""To create a Databricks personal access token:
1. Go to your Databricks workspace
2. Click your username in the top bar > Settings
3. Click 'Developer' in the sidebar
4. Under 'Access tokens', click 'Manage'
5. Click 'Generate new token'
6. Give it a description and set the lifetime
7. Copy the token and set DATABRICKS_TOKEN=dapi...
For service principals, use OAuth machine-to-machine tokens instead.""",
        # Credential store mapping
        credential_id="databricks_token",
        credential_key="access_token",
    ),
    # Default SQL Warehouse — optional; only the two SQL-executing tools use
    # it, and both accept an explicit warehouse_id argument instead.
    "databricks_warehouse": CredentialSpec(
        env_var="DATABRICKS_WAREHOUSE_ID",
        tools=["run_databricks_sql", "databricks_mcp_query_sql"],
        required=False,
        startup_required=False,
        help_url="https://docs.databricks.com/en/sql/admin/create-sql-warehouse.html",
        description="Default Databricks SQL Warehouse ID for query execution",
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""To find your SQL Warehouse ID:
1. Go to your Databricks workspace
2. Click 'SQL Warehouses' in the sidebar
3. Click on your warehouse
4. Copy the ID from the URL or the 'Connection details' tab
5. Set DATABRICKS_WAREHOUSE_ID=your_warehouse_id""",
        # Credential store mapping
        credential_id="databricks_warehouse",
        credential_key="warehouse_id",
    ),
}
+4
View File
@@ -29,6 +29,7 @@ from .brevo_tool import register_tools as register_brevo
from .calcom_tool import register_tools as register_calcom
from .calendar_tool import register_tools as register_calendar
from .csv_tool import register_tools as register_csv
from .databricks_tool import register_tools as register_databricks
from .discord_tool import register_tools as register_discord
# Security scanning tools
@@ -157,6 +158,9 @@ def register_all_tools(
# Postgres tool
register_postgres(mcp, credentials=credentials)
# Databricks tools
register_databricks(mcp, credentials=credentials)
# Return the list of all registered tool names
return list(mcp._tool_manager._tools.keys())
@@ -0,0 +1,128 @@
# Databricks Tool
Query Databricks SQL Warehouses and interact with Databricks managed MCP servers.
## Tools
### Custom SQL Tools (Read-Only)
| Tool | Description |
|------|-------------|
| `run_databricks_sql` | Execute read-only SQL queries against a Databricks SQL Warehouse |
| `describe_databricks_table` | Fetch table schema/metadata from Unity Catalog |
### Managed MCP Server Tools
| Tool | Description |
|------|-------------|
| `databricks_mcp_query_sql` | Execute SQL via the managed SQL MCP server |
| `databricks_mcp_query_uc_function` | Execute a Unity Catalog function |
| `databricks_mcp_vector_search` | Query a Vector Search index |
| `databricks_mcp_query_genie` | Query a Genie space with natural language |
| `databricks_mcp_list_tools` | Discover tools on any managed MCP server endpoint |
## Environment Variables
| Variable | Required | Description |
|----------|----------|-------------|
| `DATABRICKS_HOST` | Yes | Workspace URL (e.g., `https://dbc-xxx.cloud.databricks.com`) |
| `DATABRICKS_TOKEN` | Yes | Personal access token (`dapi...`) |
| `DATABRICKS_WAREHOUSE_ID` | No | Default SQL Warehouse ID |
## Usage Examples
### Execute a Read-Only SQL Query
```python
run_databricks_sql(
sql="SELECT name, COUNT(*) as cnt FROM main.default.users GROUP BY name",
warehouse_id="abc123def456",
max_rows=100
)
```
### Describe a Unity Catalog Table
```python
describe_databricks_table(
catalog="main",
schema="default",
table="users"
)
```
### Query via Managed MCP SQL Server
```python
databricks_mcp_query_sql(
sql="SELECT * FROM main.default.orders LIMIT 10"
)
```
### Execute a Unity Catalog Function
```python
databricks_mcp_query_uc_function(
catalog="main",
schema="analytics",
function_name="get_revenue_summary",
arguments={"start_date": "2024-01-01"}
)
```
### Search a Vector Index
```python
databricks_mcp_vector_search(
catalog="prod",
schema="knowledge_base",
index_name="docs_index",
query="How to configure authentication?",
num_results=5
)
```
### Query a Genie Space
```python
databricks_mcp_query_genie(
genie_space_id="abc123",
question="What was the total revenue last quarter?"
)
```
### Discover Available MCP Tools
```python
databricks_mcp_list_tools(
server_type="functions",
resource_path="system/ai"
)
```
## Safety Features
- **Read-only enforcement** on `run_databricks_sql`: INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, TRUNCATE, MERGE, and REPLACE are blocked
- **Row limits**: Configurable `max_rows` (default 1,000, capped at 10,000) to prevent large result sets
- **Credential isolation**: Uses CredentialStoreAdapter pattern; secrets never logged
## Error Handling
All tools return structured error dicts with `error` and optional `help` fields. Common errors include:
- **Authentication failure**: Invalid or expired token
- **Permission denied**: Insufficient privileges on the target resource
- **Not found**: Invalid catalog, schema, table, or warehouse ID
- **Missing dependency**: `databricks-sdk` or `databricks-mcp` not installed
## Installation
```bash
pip install 'databricks-sdk>=0.30.0' 'databricks-mcp>=0.1.0'
```
Or via the project's optional dependencies:
```bash
pip install '.[databricks]'
```
@@ -0,0 +1,12 @@
"""
Databricks Tool - Query Databricks SQL Warehouses and interact with managed MCP servers.
Provides MCP tools for:
- Executing read-only SQL queries via Databricks SQL Warehouses
- Describing tables via Unity Catalog
- Interacting with Databricks managed MCP servers (SQL, Vector Search, Genie, UC functions)
"""
from .databricks_tool import register_tools
__all__ = ["register_tools"]
@@ -0,0 +1,562 @@
"""
Databricks Managed MCP Server Tools.
Provides tools to interact with Databricks managed MCP server endpoints:
- SQL: Execute queries via the managed SQL MCP server
- Unity Catalog Functions: Execute predefined UC functions
- Vector Search: Query Vector Search indexes
- Genie: Query Genie spaces with natural language
- Discovery: List available tools on any managed MCP server
These tools use the official databricks-mcp library for authentication
and communication with Databricks managed MCP server endpoints.
"""
from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING, Any
from fastmcp import FastMCP
if TYPE_CHECKING:
from aden_tools.credentials import CredentialStoreAdapter
logger = logging.getLogger(__name__)
def _get_mcp_client(server_url: str, host: str | None, token: str | None) -> Any:
"""
Create a DatabricksMCPClient for the given server URL.
Args:
server_url: Full URL of the managed MCP server endpoint
host: Databricks workspace URL
token: Personal access token
Returns:
DatabricksMCPClient instance
Raises:
ImportError: If databricks-mcp or databricks-sdk is not installed
"""
try:
from databricks.sdk import WorkspaceClient
from databricks_mcp import DatabricksMCPClient
except ImportError:
raise ImportError(
"databricks-mcp and databricks-sdk are required for Databricks MCP tools. "
"Install them with: pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'"
) from None
kwargs: dict[str, str] = {}
if host:
kwargs["host"] = host
if token:
kwargs["token"] = token
workspace_client = WorkspaceClient(**kwargs)
return DatabricksMCPClient(server_url=server_url, workspace_client=workspace_client)
def register_mcp_tools(
    mcp: FastMCP,
    credentials: CredentialStoreAdapter | None = None,
) -> None:
    """Register Databricks managed MCP server tools with the MCP server."""
    # All tool functions below are closures over `credentials`, so credential
    # resolution happens per call, not once at registration time.

    def _get_credentials() -> dict[str, str | None]:
        """Get Databricks credentials from credential store or environment."""
        # Credential store takes precedence; environment variables are the
        # fallback when no store adapter was supplied.
        if credentials is not None:
            try:
                host = credentials.get("databricks_host")
            except KeyError:
                host = None
            try:
                token = credentials.get("databricks_token")
            except KeyError:
                token = None
            try:
                warehouse = credentials.get("databricks_warehouse")
            except KeyError:
                warehouse = None
            return {
                "host": host,
                "token": token,
                "warehouse_id": warehouse,
            }
        return {
            "host": os.getenv("DATABRICKS_HOST"),
            "token": os.getenv("DATABRICKS_TOKEN"),
            "warehouse_id": os.getenv("DATABRICKS_WAREHOUSE_ID"),
        }

    def _get_host() -> str | None:
        """Get the Databricks workspace host URL."""
        creds = _get_credentials()
        return creds.get("host")

    def _build_server_url(path: str) -> str | None:
        """Build a full managed MCP server URL from a path suffix.

        Returns None when no workspace host is configured, which every
        caller turns into a "host not configured" error dict.
        """
        host = _get_host()
        if not host:
            return None
        # Ensure host doesn't have trailing slash
        host = host.rstrip("/")
        return f"{host}{path}"

    @mcp.tool()
    def databricks_mcp_query_sql(
        sql: str,
        warehouse_id: str | None = None,
    ) -> dict:
        """
        Execute a SQL query via the Databricks managed SQL MCP server.

        Unlike run_databricks_sql, this tool uses the official Databricks managed
        MCP SQL server endpoint and supports both read and write operations as
        permitted by the workspace.

        Args:
            sql: The SQL query to execute.
            warehouse_id: SQL Warehouse ID. Falls back to DATABRICKS_WAREHOUSE_ID
                env var if not provided. Required for the SQL MCP server.

        Returns:
            Dict with query results:
            - success: True if query executed successfully
            - result: The query result text from the MCP server

            Or error dict with:
            - error: Error message
            - help: Optional help text

        Example:
            >>> databricks_mcp_query_sql("SELECT * FROM main.default.users LIMIT 10")
            {
                "success": True,
                "result": "..."
            }
        """
        if not sql or not sql.strip():
            return {"error": "sql is required"}
        try:
            creds = _get_credentials()
            server_url = _build_server_url("/api/2.0/mcp/sql")
            if not server_url:
                return {
                    "error": "Databricks host not configured",
                    "help": "Set DATABRICKS_HOST environment variable to your workspace URL.",
                }
            effective_warehouse = warehouse_id or creds.get("warehouse_id")
            mcp_client = _get_mcp_client(
                server_url=server_url,
                host=creds.get("host"),
                token=creds.get("token"),
            )
            # Build arguments for the SQL tool
            # NOTE(review): assumes the managed SQL server exposes a tool named
            # "execute_sql" taking a "statement" argument — confirm against the
            # databricks-mcp documentation.
            tool_args: dict[str, Any] = {"statement": sql}
            if effective_warehouse:
                tool_args["warehouse_id"] = effective_warehouse
            response = mcp_client.call_tool("execute_sql", tool_args)
            # Concatenate all text content parts into one result string.
            result_text = "".join([c.text for c in response.content])
            return {
                "success": True,
                "result": result_text,
            }
        except ImportError as e:
            return {
                "error": str(e),
                "help": "Install dependencies: "
                "pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'",
            }
        except Exception as e:
            return {"error": f"Databricks MCP SQL query failed: {e!s}"}

    @mcp.tool()
    def databricks_mcp_query_uc_function(
        catalog: str,
        schema: str,
        function_name: str,
        arguments: dict | None = None,
    ) -> dict:
        """
        Execute a Unity Catalog function via the Databricks managed MCP server.

        Use this to run predefined SQL functions registered in Unity Catalog.
        These functions encapsulate business logic and can be invoked as tools.

        Args:
            catalog: Unity Catalog catalog name (e.g., "main").
            schema: Schema name within the catalog (e.g., "default").
            function_name: Name of the UC function to execute.
            arguments: Optional dict of arguments to pass to the function.

        Returns:
            Dict with function result:
            - success: True if function executed successfully
            - result: The function result text from the MCP server

            Or error dict with:
            - error: Error message

        Example:
            >>> databricks_mcp_query_uc_function(
            ...     catalog="main",
            ...     schema="analytics",
            ...     function_name="get_revenue_summary",
            ...     arguments={"start_date": "2024-01-01", "end_date": "2024-12-31"}
            ... )
            {
                "success": True,
                "result": "Revenue summary: ..."
            }
        """
        if not catalog or not catalog.strip():
            return {"error": "catalog is required"}
        if not schema or not schema.strip():
            return {"error": "schema is required"}
        if not function_name or not function_name.strip():
            return {"error": "function_name is required"}
        try:
            creds = _get_credentials()
            path = f"/api/2.0/mcp/functions/{catalog}/{schema}/{function_name}"
            server_url = _build_server_url(path)
            if not server_url:
                return {
                    "error": "Databricks host not configured",
                    "help": "Set DATABRICKS_HOST environment variable.",
                }
            mcp_client = _get_mcp_client(
                server_url=server_url,
                host=creds.get("host"),
                token=creds.get("token"),
            )
            # Construct the tool name using the UC naming convention
            # (catalog__schema__function, double underscores).
            tool_name = f"{catalog}__{schema}__{function_name}"
            tool_args = arguments or {}
            response = mcp_client.call_tool(tool_name, tool_args)
            result_text = "".join([c.text for c in response.content])
            return {
                "success": True,
                "result": result_text,
            }
        except ImportError as e:
            return {
                "error": str(e),
                "help": "Install dependencies: "
                "pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'",
            }
        except Exception as e:
            return {"error": f"Databricks UC function call failed: {e!s}"}

    @mcp.tool()
    def databricks_mcp_vector_search(
        catalog: str,
        schema: str,
        index_name: str,
        query: str,
        num_results: int = 10,
    ) -> dict:
        """
        Query a Databricks Vector Search index via the managed MCP server.

        Use this to find semantically relevant documents from a Vector Search
        index that uses Databricks managed embeddings.

        Args:
            catalog: Unity Catalog catalog name containing the index.
            schema: Schema name within the catalog.
            index_name: Name of the Vector Search index.
            query: The search query text.
            num_results: Number of results to return (default: 10).

        Returns:
            Dict with search results:
            - success: True if search executed successfully
            - result: The search result text from the MCP server

            Or error dict with:
            - error: Error message

        Example:
            >>> databricks_mcp_vector_search(
            ...     catalog="prod",
            ...     schema="knowledge_base",
            ...     index_name="docs_index",
            ...     query="How to configure authentication?",
            ...     num_results=5
            ... )
            {
                "success": True,
                "result": "..."
            }
        """
        if not catalog or not catalog.strip():
            return {"error": "catalog is required"}
        if not schema or not schema.strip():
            return {"error": "schema is required"}
        if not index_name or not index_name.strip():
            return {"error": "index_name is required"}
        if not query or not query.strip():
            return {"error": "query is required"}
        try:
            creds = _get_credentials()
            path = f"/api/2.0/mcp/vector-search/{catalog}/{schema}/{index_name}"
            server_url = _build_server_url(path)
            if not server_url:
                return {
                    "error": "Databricks host not configured",
                    "help": "Set DATABRICKS_HOST environment variable.",
                }
            mcp_client = _get_mcp_client(
                server_url=server_url,
                host=creds.get("host"),
                token=creds.get("token"),
            )
            tool_args: dict[str, Any] = {
                "query": query,
                "num_results": num_results,
            }
            # Discover the actual tool name from the server
            tools = mcp_client.list_tools()
            if not tools:
                return {
                    "error": "No tools discovered on the Vector Search MCP server",
                    "help": f"Check that the index '{catalog}.{schema}.{index_name}' exists.",
                }
            # NOTE(review): assumes the first discovered tool is the search
            # entry point for this index — confirm against the server contract.
            tool_name = tools[0].name
            response = mcp_client.call_tool(tool_name, tool_args)
            result_text = "".join([c.text for c in response.content])
            return {
                "success": True,
                "result": result_text,
            }
        except ImportError as e:
            return {
                "error": str(e),
                "help": "Install dependencies: "
                "pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'",
            }
        except Exception as e:
            return {"error": f"Databricks Vector Search failed: {e!s}"}

    @mcp.tool()
    def databricks_mcp_query_genie(
        genie_space_id: str,
        question: str,
    ) -> dict:
        """
        Query a Databricks Genie space via the managed MCP server.

        Genie spaces allow natural language queries against structured data.
        Use this to analyze data by asking questions in plain English.
        Results are read-only.

        Note: Genie queries may take longer to execute as they involve
        natural language to SQL translation.

        Args:
            genie_space_id: The ID of the Genie space to query.
            question: Natural language question to ask the Genie space.

        Returns:
            Dict with Genie results:
            - success: True if query executed successfully
            - result: The Genie response text

            Or error dict with:
            - error: Error message

        Example:
            >>> databricks_mcp_query_genie(
            ...     genie_space_id="abc123",
            ...     question="What was the total revenue last quarter?"
            ... )
            {
                "success": True,
                "result": "The total revenue last quarter was $1.2M..."
            }
        """
        if not genie_space_id or not genie_space_id.strip():
            return {"error": "genie_space_id is required"}
        if not question or not question.strip():
            return {"error": "question is required"}
        try:
            creds = _get_credentials()
            path = f"/api/2.0/mcp/genie/{genie_space_id}"
            server_url = _build_server_url(path)
            if not server_url:
                return {
                    "error": "Databricks host not configured",
                    "help": "Set DATABRICKS_HOST environment variable.",
                }
            mcp_client = _get_mcp_client(
                server_url=server_url,
                host=creds.get("host"),
                token=creds.get("token"),
            )
            # Discover the actual tool name from the server
            tools = mcp_client.list_tools()
            if not tools:
                return {
                    "error": "No tools discovered on the Genie MCP server",
                    "help": f"Check that the Genie space '{genie_space_id}' exists "
                    "and you have access to it.",
                }
            # NOTE(review): assumes the first discovered tool is the Genie
            # query entry point and accepts a "question" argument — confirm.
            tool_name = tools[0].name
            response = mcp_client.call_tool(tool_name, {"question": question})
            result_text = "".join([c.text for c in response.content])
            return {
                "success": True,
                "result": result_text,
            }
        except ImportError as e:
            return {
                "error": str(e),
                "help": "Install dependencies: "
                "pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'",
            }
        except Exception as e:
            return {"error": f"Databricks Genie query failed: {e!s}"}

    @mcp.tool()
    def databricks_mcp_list_tools(
        server_url: str | None = None,
        server_type: str | None = None,
        resource_path: str | None = None,
    ) -> dict:
        """
        Discover available tools on a Databricks managed MCP server.

        Use this to explore what tools are available on a specific MCP server
        endpoint before calling them. Supports both direct URL and parameterized
        server type specification.

        Args:
            server_url: Full URL of the MCP server endpoint. If provided,
                server_type and resource_path are ignored.
            server_type: Type of managed server: "sql", "vector-search",
                "genie", or "functions". Used with resource_path.
            resource_path: Resource path for the server type. Examples:
                - For vector-search: "catalog/schema/index_name"
                - For genie: "genie_space_id"
                - For functions: "catalog/schema/function_name"
                - For sql: not needed

        Returns:
            Dict with discovered tools:
            - success: True if discovery succeeded
            - server_url: The MCP server URL queried
            - tools: List of tool definitions (name, description, parameters)

            Or error dict with:
            - error: Error message

        Example:
            >>> databricks_mcp_list_tools(server_type="functions", resource_path="system/ai")
            {
                "success": True,
                "server_url": "https://workspace.cloud.databricks.com/api/2.0/mcp/functions/system/ai",
                "tools": [
                    {
                        "name": "system__ai__python_exec",
                        "description": "Execute Python code",
                        "parameters": {...}
                    }
                ]
            }
        """
        try:
            creds = _get_credentials()
            # Resolve server URL
            # An explicit server_url wins; otherwise build one from
            # server_type (+ optional resource_path).
            effective_url = server_url
            if not effective_url:
                if not server_type:
                    return {
                        "error": "Either server_url or server_type is required",
                        "help": "Provide a full server_url or specify server_type "
                        "(sql, vector-search, genie, functions) with resource_path.",
                    }
                valid_types = {"sql", "vector-search", "genie", "functions"}
                if server_type not in valid_types:
                    return {
                        "error": f"Invalid server_type: {server_type}",
                        "help": f"Must be one of: {', '.join(sorted(valid_types))}",
                    }
                path = f"/api/2.0/mcp/{server_type}"
                if resource_path:
                    path = f"{path}/{resource_path}"
                effective_url = _build_server_url(path)
                if not effective_url:
                    return {
                        "error": "Databricks host not configured",
                        "help": "Set DATABRICKS_HOST environment variable.",
                    }
            mcp_client = _get_mcp_client(
                server_url=effective_url,
                host=creds.get("host"),
                token=creds.get("token"),
            )
            tools = mcp_client.list_tools()
            # Flatten tool objects into plain dicts; parameters are included
            # only when the server supplied an input schema.
            tool_list = []
            for t in tools:
                tool_info: dict[str, Any] = {
                    "name": t.name,
                    "description": t.description,
                }
                if t.inputSchema:
                    tool_info["parameters"] = t.inputSchema
                tool_list.append(tool_info)
            return {
                "success": True,
                "server_url": effective_url,
                "tools": tool_list,
            }
        except ImportError as e:
            return {
                "error": str(e),
                "help": "Install dependencies: "
                "pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'",
            }
        except Exception as e:
            return {"error": f"Failed to list MCP tools: {e!s}"}
@@ -0,0 +1,429 @@
"""
Databricks Tool - Execute SQL queries and explore tables in Databricks.
Provides two categories of tools:
1. Custom SQL tools with read-only safety guards (run_databricks_sql, describe_databricks_table)
2. Managed MCP server tools (databricks_mcp_*)
Supports:
- Personal access token authentication via DATABRICKS_HOST + DATABRICKS_TOKEN
- Databricks CLI profile authentication (for managed MCP tools)
Safety features:
- Read-only queries only (INSERT, UPDATE, DELETE, etc. are blocked)
- Configurable row limits to prevent large result sets
- SQL write-keyword detection with comment stripping
"""
from __future__ import annotations
import logging
import os
import re
from typing import TYPE_CHECKING, Any
from fastmcp import FastMCP
from .databricks_mcp_tool import register_mcp_tools
if TYPE_CHECKING:
from aden_tools.credentials import CredentialStoreAdapter
logger = logging.getLogger(__name__)
# SQL keywords that indicate write operations (case-insensitive)
WRITE_KEYWORDS = [
r"\bINSERT\b",
r"\bUPDATE\b",
r"\bDELETE\b",
r"\bDROP\b",
r"\bCREATE\b",
r"\bALTER\b",
r"\bTRUNCATE\b",
r"\bMERGE\b",
r"\bREPLACE\b",
]
# Compiled regex pattern for detecting write operations
WRITE_PATTERN = re.compile("|".join(WRITE_KEYWORDS), re.IGNORECASE)
def _is_read_only_query(sql: str) -> bool:
"""
Check if a SQL query is read-only.
Args:
sql: The SQL query string to check
Returns:
True if the query appears to be read-only, False otherwise
"""
# Remove comments (both -- and /* */ style)
sql_no_comments = re.sub(r"--.*$", "", sql, flags=re.MULTILINE)
sql_no_comments = re.sub(r"/\*.*?\*/", "", sql_no_comments, flags=re.DOTALL)
# Check for write keywords
return not bool(WRITE_PATTERN.search(sql_no_comments))
def _create_workspace_client(
host: str | None = None,
token: str | None = None,
) -> Any:
"""
Create a Databricks WorkspaceClient with appropriate credentials.
Args:
host: Databricks workspace URL
token: Personal access token
Returns:
WorkspaceClient instance
Raises:
ImportError: If databricks-sdk is not installed
Exception: If authentication fails
"""
try:
from databricks.sdk import WorkspaceClient
except ImportError:
raise ImportError(
"databricks-sdk is required for Databricks tools. "
"Install it with: pip install 'databricks-sdk>=0.30.0'"
) from None
kwargs: dict[str, str] = {}
if host:
kwargs["host"] = host
if token:
kwargs["token"] = token
return WorkspaceClient(**kwargs)
def register_tools(
mcp: FastMCP,
credentials: CredentialStoreAdapter | None = None,
) -> None:
"""Register Databricks tools with the MCP server."""
def _get_credentials() -> dict[str, str | None]:
"""Get Databricks credentials from credential store or environment."""
if credentials is not None:
try:
host = credentials.get("databricks_host")
except KeyError:
host = None
try:
token = credentials.get("databricks_token")
except KeyError:
token = None
try:
warehouse = credentials.get("databricks_warehouse")
except KeyError:
warehouse = None
return {
"host": host,
"token": token,
"warehouse_id": warehouse,
}
return {
"host": os.getenv("DATABRICKS_HOST"),
"token": os.getenv("DATABRICKS_TOKEN"),
"warehouse_id": os.getenv("DATABRICKS_WAREHOUSE_ID"),
}
def _get_client() -> Any:
"""
Get a Databricks WorkspaceClient with credentials resolution.
Returns:
WorkspaceClient instance
"""
creds = _get_credentials()
return _create_workspace_client(
host=creds["host"],
token=creds["token"],
)
@mcp.tool()
def run_databricks_sql(
sql: str,
warehouse_id: str | None = None,
max_rows: int = 1000,
) -> dict:
"""
Execute a read-only SQL query against a Databricks SQL Warehouse.
This tool executes SQL queries and returns the results as structured data.
Only SELECT queries are allowed - write operations (INSERT, UPDATE, DELETE,
DROP, CREATE, ALTER, TRUNCATE, MERGE) are blocked for safety.
Args:
sql: The SQL query to execute. Must be a read-only query.
warehouse_id: SQL Warehouse ID. Falls back to DATABRICKS_WAREHOUSE_ID
env var if not provided.
max_rows: Maximum number of rows to return (default: 1000).
Use this to prevent accidentally fetching large result sets.
Returns:
Dict with query results:
- success: True if query executed successfully
- rows: List of row dictionaries
- total_rows: Total number of rows returned
- rows_returned: Number of rows actually returned (may be limited)
- schema: List of column definitions (name, type)
- query_truncated: True if results were truncated due to max_rows
Or error dict with:
- error: Error message
- help: Optional help text
Example:
>>> run_databricks_sql(
... sql="SELECT name, COUNT(*) as cnt FROM catalog.schema.users GROUP BY name",
... max_rows=100
... )
{
"success": True,
"rows": [{"name": "Alice", "cnt": 42}, ...],
"total_rows": 100,
"rows_returned": 100,
"schema": [{"name": "name", "type": "STRING"}, ...],
"query_truncated": False
}
"""
# Validate SQL is read-only
if not _is_read_only_query(sql):
return {
"error": "Write operations are not allowed",
"help": "Only SELECT queries are permitted. "
"INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, TRUNCATE, and MERGE are blocked.",
}
# Validate max_rows
if max_rows < 1:
return {"error": "max_rows must be at least 1"}
if max_rows > 10000:
return {
"error": "max_rows cannot exceed 10000",
"help": "For larger result sets, consider using pagination or "
"exporting to cloud storage.",
}
try:
client = _get_client()
creds = _get_credentials()
effective_warehouse = warehouse_id or creds.get("warehouse_id")
if not effective_warehouse:
return {
"error": "No SQL Warehouse ID provided",
"help": "Provide warehouse_id parameter or set DATABRICKS_WAREHOUSE_ID "
"environment variable.",
}
# Execute query via Databricks SQL Statement API
response = client.statement_execution.execute_statement(
statement=sql,
warehouse_id=effective_warehouse,
wait_timeout="30s",
row_limit=max_rows,
)
# Check for execution errors
if response.status and response.status.error:
return {
"error": f"Databricks SQL error: {response.status.error.message}",
}
# Parse results
schema = []
if response.manifest and response.manifest.schema and response.manifest.schema.columns:
schema = [
{
"name": col.name,
"type": str(col.type_name) if col.type_name else "UNKNOWN",
}
for col in response.manifest.schema.columns
]
rows = []
total_rows = 0
if response.result and response.result.data_array:
total_rows = len(response.result.data_array)
for row_data in response.result.data_array[:max_rows]:
row = {}
for i, value in enumerate(row_data):
col_name = schema[i]["name"] if i < len(schema) else f"col_{i}"
row[col_name] = value
rows.append(row)
query_truncated = total_rows > max_rows
return {
"success": True,
"rows": rows,
"total_rows": total_rows,
"rows_returned": len(rows),
"schema": schema,
"query_truncated": query_truncated,
}
except ImportError as e:
return {
"error": str(e),
"help": "Install the dependency by running: pip install 'databricks-sdk>=0.30.0'",
}
except Exception as e:
error_msg = str(e)
# Provide helpful messages for common errors
if "TEMPORARILY_UNAVAILABLE" in error_msg or "401" in error_msg:
return {
"error": "Databricks authentication failed",
"help": "Check that DATABRICKS_HOST and DATABRICKS_TOKEN are set correctly. "
"Token may have expired — generate a new one from workspace settings.",
}
if "403" in error_msg or "PERMISSION_DENIED" in error_msg:
return {
"error": f"Databricks permission denied: {error_msg}",
"help": "Ensure your token has permission to access the SQL Warehouse "
"and the requested tables/catalogs.",
}
if "404" in error_msg or "NOT_FOUND" in error_msg:
return {
"error": f"Databricks resource not found: {error_msg}",
"help": "Check that the warehouse ID, catalog, schema, "
"and table names are correct.",
}
if "INVALID_PARAMETER" in error_msg:
return {
"error": f"Invalid parameter: {error_msg}",
"help": "Check that the warehouse_id is a valid SQL Warehouse ID.",
}
return {"error": f"Databricks SQL query failed: {error_msg}"}
@mcp.tool()
def describe_databricks_table(
    catalog: str,
    schema: str,
    table: str,
) -> dict:
    """
    Describe a table in Databricks Unity Catalog, returning its schema and metadata.

    Use this tool to explore table structure before writing queries.
    Returns column names, types, nullability, and table metadata.

    Args:
        catalog: Unity Catalog catalog name (e.g., "main").
        schema: Schema name within the catalog (e.g., "default").
        table: Table name to describe.

    Returns:
        Dict with table information:
        - success: True if operation succeeded
        - catalog: The catalog name
        - schema: The schema name
        - table: The table name
        - full_name: Fully qualified table name (catalog.schema.table)
        - table_type: Type of the table (MANAGED, EXTERNAL, VIEW, etc.)
        - columns: List of column definitions (name, type, nullable, comment)
        - comment: Table comment/description if available
        - storage_location: Physical storage location if applicable

        Or error dict with:
        - error: Error message
        - help: Optional help text

    Example:
        >>> describe_databricks_table("main", "default", "users")
        {
            "success": True,
            "catalog": "main",
            "schema": "default",
            "table": "users",
            "full_name": "main.default.users",
            "table_type": "MANAGED",
            "columns": [
                {"name": "id", "type": "LONG", "nullable": False, "comment": "User ID"},
                {"name": "email", "type": "STRING", "nullable": True, "comment": None}
            ],
            "comment": "User accounts table"
        }
    """
    if not catalog or not catalog.strip():
        return {"error": "catalog is required"}
    if not schema or not schema.strip():
        return {"error": "schema is required"}
    if not table or not table.strip():
        return {"error": "table name is required"}
    # Fix: the validation above tolerates surrounding whitespace, so strip the
    # parts before building the identifier — "main " would otherwise produce a
    # broken fully qualified name like "main .default.users".
    catalog, schema, table = catalog.strip(), schema.strip(), table.strip()
    full_name = f"{catalog}.{schema}.{table}"
    try:
        client = _get_client()
        # Retrieve table metadata via the Unity Catalog tables API.
        table_info = client.tables.get(full_name)
        columns = [
            {
                "name": col.name,
                "type": str(col.type_name) if col.type_name else "UNKNOWN",
                # The API may omit nullability; treat missing as nullable.
                "nullable": col.nullable if col.nullable is not None else True,
                "comment": col.comment,
            }
            for col in (table_info.columns or [])
        ]
        result = {
            "success": True,
            "catalog": catalog,
            "schema": schema,
            "table": table,
            "full_name": full_name,
            "table_type": str(table_info.table_type) if table_info.table_type else None,
            "columns": columns,
            "comment": table_info.comment,
        }
        # Only include when the table is backed by physical storage.
        if table_info.storage_location:
            result["storage_location"] = table_info.storage_location
        return result
    except ImportError as e:
        return {
            "error": str(e),
            "help": "Install the dependency by running: pip install 'databricks-sdk>=0.30.0'",
        }
    except Exception as e:
        error_msg = str(e)
        # Map common SDK/API failures onto actionable error messages.
        if "TEMPORARILY_UNAVAILABLE" in error_msg or "401" in error_msg:
            return {
                "error": "Databricks authentication failed",
                "help": "Check that DATABRICKS_HOST and DATABRICKS_TOKEN are set correctly.",
            }
        if "NOT_FOUND" in error_msg or "DOES_NOT_EXIST" in error_msg:
            return {
                "error": f"Table not found: {full_name}",
                "help": "Check that the catalog, schema, and table names are correct. "
                f"Full error: {error_msg}",
            }
        if "403" in error_msg or "PERMISSION_DENIED" in error_msg:
            return {
                "error": f"Permission denied for table: {full_name}",
                "help": "Ensure your token has the 'USE CATALOG', 'USE SCHEMA', "
                "and 'SELECT' privileges on the target table.",
            }
        return {"error": f"Failed to describe table: {error_msg}"}
# Register managed MCP server tools
register_mcp_tools(mcp, credentials)
+783
View File
@@ -0,0 +1,783 @@
"""
Tests for Databricks tools.
Tests cover:
- Custom SQL tool: read-only enforcement, row limiting, query execution, error handling
- Describe table: input validation, successful description, error handling
- Managed MCP tools: SQL, UC functions, Vector Search, Genie, tool discovery
- Credential resolution from CredentialStoreAdapter and environment
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from fastmcp import FastMCP
from aden_tools.credentials import CredentialStoreAdapter
from aden_tools.tools.databricks_tool import register_tools
@pytest.fixture
def mcp():
    """Create a FastMCP instance for testing."""
    # Function-scoped by default: every test receives a fresh server, so
    # tool registrations cannot leak between tests.
    return FastMCP("test-server")
@pytest.fixture
def mock_credentials():
    """Create mock credentials for testing."""
    # Placeholder values only — no real workspace is contacted; tests patch
    # the Databricks clients before any network call could happen.
    return CredentialStoreAdapter.for_testing(
        {
            "databricks_host": "https://test-workspace.cloud.databricks.com",
            "databricks_token": "dapi_test_token_12345",
            "databricks_warehouse": "test-warehouse-id",
        }
    )
@pytest.fixture
def registered_mcp(mcp, mock_credentials):
    """Register Databricks tools with mock credentials."""
    # Composes the two fixtures above: the returned server has the Databricks
    # tools registered against the fake credential store.
    register_tools(mcp, credentials=mock_credentials)
    return mcp
# ===========================================================================
# Custom SQL Tool — Read-Only Enforcement
# ===========================================================================
class TestReadOnlyEnforcement:
    """Verify that run_databricks_sql rejects SQL write statements."""

    @staticmethod
    def _run_sql(server, statement):
        """Invoke the registered run_databricks_sql tool with *statement*."""
        return server._tool_manager._tools["run_databricks_sql"].fn(sql=statement)

    def _assert_blocked(self, server, statement):
        """Assert that *statement* is rejected as a write operation."""
        outcome = self._run_sql(server, statement)
        assert "error" in outcome
        assert "Write operations are not allowed" in outcome["error"]

    def test_blocks_insert(self, registered_mcp):
        """INSERT statements should be blocked."""
        self._assert_blocked(registered_mcp, "INSERT INTO table VALUES (1, 2)")

    def test_blocks_update(self, registered_mcp):
        """UPDATE statements should be blocked."""
        self._assert_blocked(registered_mcp, "UPDATE table SET col = 1")

    def test_blocks_delete(self, registered_mcp):
        """DELETE statements should be blocked."""
        self._assert_blocked(registered_mcp, "DELETE FROM table WHERE id = 1")

    def test_blocks_drop(self, registered_mcp):
        """DROP statements should be blocked."""
        self._assert_blocked(registered_mcp, "DROP TABLE my_table")

    def test_blocks_create(self, registered_mcp):
        """CREATE statements should be blocked."""
        self._assert_blocked(registered_mcp, "CREATE TABLE my_table (id INT)")

    def test_blocks_alter(self, registered_mcp):
        """ALTER statements should be blocked."""
        self._assert_blocked(registered_mcp, "ALTER TABLE my_table ADD COLUMN new_col INT")

    def test_blocks_truncate(self, registered_mcp):
        """TRUNCATE statements should be blocked."""
        self._assert_blocked(registered_mcp, "TRUNCATE TABLE my_table")

    def test_blocks_merge(self, registered_mcp):
        """MERGE statements should be blocked."""
        self._assert_blocked(
            registered_mcp,
            "MERGE INTO target USING source ON condition WHEN MATCHED THEN UPDATE",
        )

    def test_blocks_case_insensitive(self, registered_mcp):
        """Write detection should be case-insensitive."""
        self._assert_blocked(registered_mcp, "insert into table values (1)")

    def test_allows_select(self, registered_mcp):
        """SELECT statements should be allowed (will fail on client, not validation)."""
        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as factory:
            factory.side_effect = Exception("Mock error")
            outcome = self._run_sql(registered_mcp, "SELECT * FROM table")
        assert "Write operations are not allowed" not in outcome.get("error", "")

    def test_allows_select_with_subquery(self, registered_mcp):
        """Complex SELECT with subqueries should be allowed."""
        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as factory:
            factory.side_effect = Exception("Mock error")
            outcome = self._run_sql(
                registered_mcp,
                """
                SELECT a.*, b.count
                FROM (SELECT id, name FROM users) a
                JOIN (SELECT user_id, COUNT(*) as count FROM orders GROUP BY user_id) b
                ON a.id = b.user_id
                """,
            )
        assert "Write operations are not allowed" not in outcome.get("error", "")

    def test_ignores_write_keywords_in_comments(self, registered_mcp):
        """Write keywords inside comments should not trigger blocking."""
        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as factory:
            factory.side_effect = Exception("Mock error")
            outcome = self._run_sql(registered_mcp, "SELECT * FROM t -- INSERT comment")
        assert "Write operations are not allowed" not in outcome.get("error", "")
# ===========================================================================
# Custom SQL Tool — Row Limits
# ===========================================================================
class TestRowLimits:
    """Validate max_rows bounds checking in run_databricks_sql."""

    @staticmethod
    def _call(server, **kwargs):
        """Invoke the registered run_databricks_sql tool with **kwargs."""
        return server._tool_manager._tools["run_databricks_sql"].fn(**kwargs)

    def test_rejects_zero_max_rows(self, registered_mcp):
        """max_rows of 0 should be rejected."""
        outcome = self._call(registered_mcp, sql="SELECT 1", max_rows=0)
        assert "error" in outcome
        assert "max_rows must be at least 1" in outcome["error"]

    def test_rejects_negative_max_rows(self, registered_mcp):
        """Negative max_rows should be rejected."""
        outcome = self._call(registered_mcp, sql="SELECT 1", max_rows=-1)
        assert "error" in outcome
        assert "max_rows must be at least 1" in outcome["error"]

    def test_rejects_excessive_max_rows(self, registered_mcp):
        """max_rows over 10000 should be rejected."""
        outcome = self._call(registered_mcp, sql="SELECT 1", max_rows=10001)
        assert "error" in outcome
        assert "max_rows cannot exceed 10000" in outcome["error"]

    def test_accepts_valid_max_rows(self, registered_mcp):
        """Valid max_rows values should be accepted."""
        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as factory:
            # Client creation fails, so a max_rows complaint could only come
            # from validation — assert it never does for in-range values.
            factory.side_effect = Exception("Mock error")
            for limit in (1, 100, 1000, 10000):
                outcome = self._call(registered_mcp, sql="SELECT 1", max_rows=limit)
                assert "max_rows" not in outcome.get("error", "")
# ===========================================================================
# Custom SQL Tool — Query Execution
# ===========================================================================
class TestQueryExecution:
    """Query execution against a fully mocked WorkspaceClient."""

    @staticmethod
    def _mock_response(columns, data_rows):
        """Build a fake statement-execution response.

        *columns* is a list of (name, type_name) pairs; *data_rows* is the
        raw data_array the Databricks API would return.
        """
        schema_cols = []
        for col_name, col_type in columns:
            col = MagicMock()
            col.name = col_name
            col.type_name = col_type
            schema_cols.append(col)
        response = MagicMock()
        response.status.error = None
        response.manifest.schema.columns = schema_cols
        response.result.data_array = data_rows
        return response

    def test_successful_query(self, registered_mcp):
        """Test successful query execution with mocked WorkspaceClient."""
        tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as factory:
            client = MagicMock()
            factory.return_value = client
            client.statement_execution.execute_statement.return_value = self._mock_response(
                [("id", "INT"), ("name", "STRING")],
                [["1", "Alice"], ["2", "Bob"]],
            )
            outcome = tool.fn(sql="SELECT id, name FROM users")
        assert outcome["success"] is True
        assert outcome["rows"] == [
            {"id": "1", "name": "Alice"},
            {"id": "2", "name": "Bob"},
        ]
        assert outcome["total_rows"] == 2
        assert outcome["rows_returned"] == 2
        assert outcome["query_truncated"] is False
        assert len(outcome["schema"]) == 2

    def test_query_truncation(self, registered_mcp):
        """Test that results are truncated when exceeding max_rows."""
        tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as factory:
            client = MagicMock()
            factory.return_value = client
            # Ten rows of data, but the call below asks for only five.
            client.statement_execution.execute_statement.return_value = self._mock_response(
                [("id", "INT")],
                [[str(i)] for i in range(10)],
            )
            outcome = tool.fn(sql="SELECT id FROM users", max_rows=5)
        assert outcome["success"] is True
        assert outcome["total_rows"] == 10
        assert outcome["rows_returned"] == 5
        assert outcome["query_truncated"] is True
        assert len(outcome["rows"]) == 5

    def test_missing_warehouse_id(self, mcp):
        """Test error when no warehouse ID is configured."""
        store = CredentialStoreAdapter.for_testing(
            {
                "databricks_host": "https://test.cloud.databricks.com",
                "databricks_token": "dapi_test",
            }
        )
        register_tools(mcp, credentials=store)
        tool = mcp._tool_manager._tools["run_databricks_sql"]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as factory:
            factory.return_value = MagicMock()
            outcome = tool.fn(sql="SELECT 1")
        assert "error" in outcome
        assert "No SQL Warehouse ID provided" in outcome["error"]
# ===========================================================================
# Describe Table
# ===========================================================================
class TestDescribeTable:
    """Input validation and happy-path behavior of describe_databricks_table."""

    @staticmethod
    def _describe(server, **kwargs):
        """Invoke the registered describe_databricks_table tool."""
        return server._tool_manager._tools["describe_databricks_table"].fn(**kwargs)

    def test_empty_catalog(self, registered_mcp):
        """Empty catalog should be rejected."""
        outcome = self._describe(registered_mcp, catalog="", schema="default", table="users")
        assert "error" in outcome
        assert "catalog is required" in outcome["error"]

    def test_empty_schema(self, registered_mcp):
        """Empty schema should be rejected."""
        outcome = self._describe(registered_mcp, catalog="main", schema="", table="users")
        assert "error" in outcome
        assert "schema is required" in outcome["error"]

    def test_empty_table(self, registered_mcp):
        """Empty table name should be rejected."""
        outcome = self._describe(registered_mcp, catalog="main", schema="default", table="")
        assert "error" in outcome
        assert "table name is required" in outcome["error"]

    def test_successful_describe(self, registered_mcp):
        """Test successful table description with mocked client."""

        def make_column(name, type_name, nullable, comment):
            # Build one fake Unity Catalog column descriptor.
            col = MagicMock()
            col.name = name
            col.type_name = type_name
            col.nullable = nullable
            col.comment = comment
            return col

        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as factory:
            client = MagicMock()
            factory.return_value = client
            info = MagicMock()
            info.columns = [
                make_column("id", "LONG", False, "User ID"),
                make_column("email", "STRING", True, None),
            ]
            info.table_type = "MANAGED"
            info.comment = "User accounts table"
            info.storage_location = "s3://bucket/path"
            client.tables.get.return_value = info
            outcome = self._describe(
                registered_mcp, catalog="main", schema="default", table="users"
            )
        assert outcome["success"] is True
        assert outcome["catalog"] == "main"
        assert outcome["schema"] == "default"
        assert outcome["table"] == "users"
        assert outcome["full_name"] == "main.default.users"
        assert outcome["table_type"] == "MANAGED"
        assert outcome["comment"] == "User accounts table"
        assert outcome["storage_location"] == "s3://bucket/path"
        assert len(outcome["columns"]) == 2
        assert outcome["columns"][0]["name"] == "id"
        assert outcome["columns"][0]["nullable"] is False
        assert outcome["columns"][1]["name"] == "email"
        assert outcome["columns"][1]["nullable"] is True
# ===========================================================================
# Managed MCP Tools
# ===========================================================================
class TestMCPQuerySQL:
    """Tests for the databricks_mcp_query_sql managed-MCP tool."""

    def test_empty_sql(self, registered_mcp):
        """Empty SQL should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_sql"]
        outcome = tool.fn(sql="")
        assert "error" in outcome
        assert "sql is required" in outcome["error"]

    def test_successful_mcp_sql(self, registered_mcp):
        """Test successful MCP SQL query with mocked client."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_sql"]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as client_factory:
            remote = MagicMock()
            client_factory.return_value = remote
            chunk = MagicMock()
            chunk.text = "Query results: 42 rows returned"
            reply = MagicMock()
            reply.content = [chunk]
            remote.call_tool.return_value = reply
            outcome = tool.fn(sql="SELECT * FROM main.default.users")
        assert outcome["success"] is True
        assert "Query results" in outcome["result"]
class TestMCPUCFunction:
    """Tests for the databricks_mcp_query_uc_function tool."""

    def test_empty_catalog(self, registered_mcp):
        """Empty catalog should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_uc_function"]
        outcome = tool.fn(catalog="", schema="default", function_name="my_func")
        assert "error" in outcome
        assert "catalog is required" in outcome["error"]

    def test_empty_function_name(self, registered_mcp):
        """Empty function name should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_uc_function"]
        outcome = tool.fn(catalog="main", schema="default", function_name="")
        assert "error" in outcome
        assert "function_name is required" in outcome["error"]

    def test_successful_uc_function(self, registered_mcp):
        """Test successful UC function call with mocked client."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_uc_function"]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as client_factory:
            remote = MagicMock()
            client_factory.return_value = remote
            chunk = MagicMock()
            chunk.text = "Revenue: $1.2M"
            reply = MagicMock()
            reply.content = [chunk]
            remote.call_tool.return_value = reply
            outcome = tool.fn(
                catalog="main",
                schema="analytics",
                function_name="get_revenue",
                arguments={"year": "2024"},
            )
        assert outcome["success"] is True
        assert "Revenue" in outcome["result"]
class TestMCPVectorSearch:
    """Tests for the databricks_mcp_vector_search tool."""

    def test_empty_query(self, registered_mcp):
        """Empty query should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_vector_search"]
        outcome = tool.fn(catalog="prod", schema="kb", index_name="idx", query="")
        assert "error" in outcome
        assert "query is required" in outcome["error"]

    def test_empty_index_name(self, registered_mcp):
        """Empty index name should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_vector_search"]
        outcome = tool.fn(catalog="prod", schema="kb", index_name="", query="test")
        assert "error" in outcome
        assert "index_name is required" in outcome["error"]

    def test_successful_vector_search(self, registered_mcp):
        """Test successful vector search with mocked client."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_vector_search"]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as client_factory:
            remote = MagicMock()
            client_factory.return_value = remote
            discovered = MagicMock()
            discovered.name = "search_index"
            remote.list_tools.return_value = [discovered]
            chunk = MagicMock()
            chunk.text = "Found 3 relevant documents"
            reply = MagicMock()
            reply.content = [chunk]
            remote.call_tool.return_value = reply
            outcome = tool.fn(
                catalog="prod",
                schema="kb",
                index_name="docs_index",
                query="How to configure auth?",
                num_results=5,
            )
        assert outcome["success"] is True
        assert "relevant documents" in outcome["result"]
class TestMCPGenie:
    """Tests for the databricks_mcp_query_genie tool."""

    def test_empty_genie_space_id(self, registered_mcp):
        """Empty genie space ID should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_genie"]
        outcome = tool.fn(genie_space_id="", question="What is revenue?")
        assert "error" in outcome
        assert "genie_space_id is required" in outcome["error"]

    def test_empty_question(self, registered_mcp):
        """Empty question should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_genie"]
        outcome = tool.fn(genie_space_id="abc123", question="")
        assert "error" in outcome
        assert "question is required" in outcome["error"]

    def test_successful_genie_query(self, registered_mcp):
        """Test successful Genie query with mocked client."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_genie"]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as client_factory:
            remote = MagicMock()
            client_factory.return_value = remote
            discovered = MagicMock()
            discovered.name = "ask_genie"
            remote.list_tools.return_value = [discovered]
            chunk = MagicMock()
            chunk.text = "Total revenue last quarter was $1.2M"
            reply = MagicMock()
            reply.content = [chunk]
            remote.call_tool.return_value = reply
            outcome = tool.fn(
                genie_space_id="abc123",
                question="What was revenue last quarter?",
            )
        assert outcome["success"] is True
        assert "$1.2M" in outcome["result"]
class TestMCPListTools:
    """Tests for the databricks_mcp_list_tools discovery tool."""

    def test_missing_server_url_and_type(self, registered_mcp):
        """Should return error when neither server_url nor server_type is given."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_list_tools"]
        outcome = tool.fn()
        assert "error" in outcome
        assert "server_url or server_type is required" in outcome["error"]

    def test_invalid_server_type(self, registered_mcp):
        """Should return error for invalid server type."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_list_tools"]
        outcome = tool.fn(server_type="invalid")
        assert "error" in outcome
        assert "Invalid server_type" in outcome["error"]

    def test_successful_list_tools(self, registered_mcp):
        """Test successful tool listing with mocked client."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_list_tools"]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as client_factory:
            remote = MagicMock()
            client_factory.return_value = remote
            first = MagicMock()
            first.name = "execute_sql"
            first.description = "Execute SQL queries"
            first.inputSchema = {"type": "object", "properties": {}}
            second = MagicMock()
            second.name = "list_tables"
            second.description = "List available tables"
            second.inputSchema = None
            remote.list_tools.return_value = [first, second]
            outcome = tool.fn(server_type="sql")
        assert outcome["success"] is True
        assert len(outcome["tools"]) == 2
        assert outcome["tools"][0]["name"] == "execute_sql"
        assert outcome["tools"][1]["name"] == "list_tables"
        assert "https://test-workspace.cloud.databricks.com" in outcome["server_url"]

    def test_with_direct_server_url(self, registered_mcp):
        """Test tool listing with a direct server URL."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_list_tools"]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as client_factory:
            remote = MagicMock()
            client_factory.return_value = remote
            remote.list_tools.return_value = []
            target = "https://my-workspace.cloud.databricks.com/api/2.0/mcp/sql"
            outcome = tool.fn(server_url=target)
        assert outcome["success"] is True
        assert outcome["server_url"] == target
# ===========================================================================
# Error Handling
# ===========================================================================
class TestErrorHandling:
    """Error paths should surface user-friendly messages with help text."""

    @staticmethod
    def _fail_with(server, tool_name, message, **kwargs):
        """Make WorkspaceClient creation raise *message*, then invoke the tool."""
        tool = server._tool_manager._tools[tool_name]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as factory:
            factory.side_effect = Exception(message)
            return tool.fn(**kwargs)

    def test_authentication_error(self, registered_mcp):
        """Authentication errors should provide helpful messages."""
        outcome = self._fail_with(
            registered_mcp, "run_databricks_sql", "401 Unauthorized", sql="SELECT 1"
        )
        assert "error" in outcome
        assert "authentication failed" in outcome["error"].lower()
        assert "help" in outcome

    def test_permission_error(self, registered_mcp):
        """Permission errors should provide helpful messages."""
        outcome = self._fail_with(
            registered_mcp,
            "run_databricks_sql",
            "PERMISSION_DENIED: access denied",
            sql="SELECT 1",
        )
        assert "error" in outcome
        assert "permission denied" in outcome["error"].lower()
        assert "help" in outcome

    def test_not_found_error(self, registered_mcp):
        """Not found errors should provide helpful messages."""
        outcome = self._fail_with(
            registered_mcp,
            "run_databricks_sql",
            "NOT_FOUND: warehouse xyz not found",
            sql="SELECT 1",
        )
        assert "error" in outcome
        assert "not found" in outcome["error"].lower()
        assert "help" in outcome

    def test_table_not_found_error(self, registered_mcp):
        """Table not found errors should provide helpful messages."""
        outcome = self._fail_with(
            registered_mcp,
            "describe_databricks_table",
            "DOES_NOT_EXIST: table not found",
            catalog="main",
            schema="default",
            table="nonexistent",
        )
        assert "error" in outcome
        assert "not found" in outcome["error"].lower()

    def test_mcp_import_error(self, registered_mcp):
        """Import errors for managed MCP tools should be helpful."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_sql"]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as client_factory:
            client_factory.side_effect = ImportError(
                "databricks-mcp and databricks-sdk are required"
            )
            outcome = tool.fn(sql="SELECT 1")
        assert "error" in outcome
        assert "databricks" in outcome["error"].lower()
# ===========================================================================
# Credential Resolution
# ===========================================================================
class TestCredentialResolution:
    """Credentials may come from the store adapter or the environment."""

    # Every tool name the Databricks integration is expected to register.
    _ALL_TOOLS = (
        "run_databricks_sql",
        "describe_databricks_table",
        "databricks_mcp_query_sql",
        "databricks_mcp_query_uc_function",
        "databricks_mcp_vector_search",
        "databricks_mcp_query_genie",
        "databricks_mcp_list_tools",
    )

    def test_uses_credential_store(self, mcp):
        """Should use credentials from CredentialStoreAdapter."""
        store = CredentialStoreAdapter.for_testing(
            {
                "databricks_host": "https://custom.cloud.databricks.com",
                "databricks_token": "dapi_custom_token",
                "databricks_warehouse": "custom-warehouse",
            }
        )
        register_tools(mcp, credentials=store)
        assert store.get("databricks_host") == "https://custom.cloud.databricks.com"
        assert store.get("databricks_token") == "dapi_custom_token"
        assert store.get("databricks_warehouse") == "custom-warehouse"

    def test_falls_back_to_env_vars(self, mcp):
        """Should fall back to environment variables when no credential store."""
        register_tools(mcp, credentials=None)
        # Registration still succeeds; tools resolve os.getenv lazily at call time.
        for tool_name in self._ALL_TOOLS:
            assert tool_name in mcp._tool_manager._tools

    def test_all_tools_registered(self, registered_mcp):
        """All 7 Databricks tools should be registered."""
        for tool_name in self._ALL_TOOLS:
            assert tool_name in registered_mcp._tool_manager._tools, (
                f"Tool '{tool_name}' not registered"
            )
# ===========================================================================
# Import Error Handling
# ===========================================================================
class TestImportError:
    """Behavior when the databricks-sdk package is not installed."""

    def test_sdk_import_error_message(self, registered_mcp):
        """Should provide helpful message when databricks-sdk not installed."""
        tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as factory:
            factory.side_effect = ImportError(
                "databricks-sdk is required for Databricks tools. "
                "Install it with: pip install 'databricks-sdk>=0.30.0'"
            )
            outcome = tool.fn(sql="SELECT 1")
        assert "error" in outcome
        assert "databricks-sdk" in outcome["error"]
        assert "pip install" in outcome["error"]
Generated
+37 -1
View File
@@ -536,6 +536,32 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1c/7c/996760c30f1302704af57c66ff2d723f7d656d0d0b93563b5528a51484bb/cyclopts-4.5.1-py3-none-any.whl", hash = "sha256:0642c93601e554ca6b7b9abd81093847ea4448b2616280f2a0952416574e8c7a", size = 199807, upload-time = "2026-01-25T15:23:55.219Z" },
]
[[package]]
name = "databricks-mcp"
version = "0.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "databricks-sdk" },
{ name = "mcp" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/49/10/9f12e8d3322327a8edac8b906a0f1cece3dfbc58efdd2198538b4ba56957/databricks_mcp-0.1.0-py3-none-any.whl", hash = "sha256:d02ed6cdd24a862a365b4fae2892223d8209d14e5be500ebc14d086f646408cc", size = 2307, upload-time = "2025-06-05T20:39:11.724Z" },
]
[[package]]
name = "databricks-sdk"
version = "0.94.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-auth" },
{ name = "protobuf" },
{ name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3e/39/18b3a4a69851f583d7970ae71f19975e06ef892ea17393bcafb773088f96/databricks_sdk-0.94.0.tar.gz", hash = "sha256:1b29a9d63e2dc9808ed79253ba3668cd6d130dabb880a576d647fa3b1b1a7b30", size = 861465, upload-time = "2026-02-26T08:17:12.063Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/aa/da/072fe4df10aec583eb72327856b4743c7f5fb50a16772d5f86fcc20e2457/databricks_sdk-0.94.0-py3-none-any.whl", hash = "sha256:ae980186ee2d99f3639d14a94457776dca22980a04f65f89497a79889b2b9149", size = 811079, upload-time = "2026-02-26T08:17:10.242Z" },
]
[[package]]
name = "diff-match-patch"
version = "20241021"
@@ -3490,6 +3516,8 @@ dependencies = [
[package.optional-dependencies]
all = [
{ name = "databricks-mcp" },
{ name = "databricks-sdk" },
{ name = "duckdb" },
{ name = "google-cloud-bigquery" },
{ name = "openpyxl" },
@@ -3500,6 +3528,10 @@ all = [
bigquery = [
{ name = "google-cloud-bigquery" },
]
databricks = [
{ name = "databricks-mcp" },
{ name = "databricks-sdk" },
]
dev = [
{ name = "pytest" },
{ name = "pytest-asyncio" },
@@ -3529,6 +3561,10 @@ dev = [
requires-dist = [
{ name = "arxiv", specifier = ">=2.1.0" },
{ name = "beautifulsoup4", specifier = ">=4.12.0" },
{ name = "databricks-mcp", marker = "extra == 'all'", specifier = ">=0.1.0" },
{ name = "databricks-mcp", marker = "extra == 'databricks'", specifier = ">=0.1.0" },
{ name = "databricks-sdk", marker = "extra == 'all'", specifier = ">=0.30.0" },
{ name = "databricks-sdk", marker = "extra == 'databricks'", specifier = ">=0.30.0" },
{ name = "diff-match-patch", specifier = ">=20230430" },
{ name = "dnspython", specifier = ">=2.4.0" },
{ name = "duckdb", marker = "extra == 'all'", specifier = ">=1.0.0" },
@@ -3561,7 +3597,7 @@ requires-dist = [
{ name = "restrictedpython", marker = "extra == 'sandbox'", specifier = ">=7.0" },
{ name = "stripe", specifier = ">=14.3.0" },
]
provides-extras = ["dev", "sandbox", "ocr", "excel", "sql", "bigquery", "all"]
provides-extras = ["dev", "sandbox", "ocr", "excel", "sql", "bigquery", "databricks", "all"]
[package.metadata.requires-dev]
dev = [