Integration: add Databricks MCP tool integration
Implements the Databricks MCP tool integration for the Hive agent framework
@@ -59,6 +59,10 @@ sql = [
bigquery = [
    "google-cloud-bigquery>=3.0.0",
]
databricks = [
    "databricks-sdk>=0.30.0",
    "databricks-mcp>=0.1.0",
]
all = [
    "RestrictedPython>=7.0",
    "pytesseract>=0.3.10",
@@ -66,6 +70,8 @@ all = [
    "duckdb>=1.0.0",
    "openpyxl>=3.1.0",
    "google-cloud-bigquery>=3.0.0",
    "databricks-sdk>=0.30.0",
    "databricks-mcp>=0.1.0",
]

[tool.uv.sources]
@@ -56,6 +56,7 @@ To add a new credential:
from .apollo import APOLLO_CREDENTIALS
from .base import CredentialError, CredentialSpec
from .bigquery import BIGQUERY_CREDENTIALS
from .databricks import DATABRICKS_CREDENTIALS
from .brevo import BREVO_CREDENTIALS
from .browser import get_aden_auth_url, get_aden_setup_url, open_browser
from .calcom import CALCOM_CREDENTIALS
@@ -113,6 +114,7 @@ CREDENTIAL_SPECS = {
    **STRIPE_CREDENTIALS,
    **BREVO_CREDENTIALS,
    **POSTGRES_CREDENTIALS,
    **DATABRICKS_CREDENTIALS,
}

__all__ = [
@@ -160,4 +162,5 @@ __all__ = [
    "STRIPE_CREDENTIALS",
    "BREVO_CREDENTIALS",
    "POSTGRES_CREDENTIALS",
    "DATABRICKS_CREDENTIALS",
]
@@ -0,0 +1,90 @@
"""
Databricks tool credentials.

Contains credentials for Databricks workspace and SQL Warehouse access,
as well as managed MCP server connectivity.
"""

from .base import CredentialSpec

# All tool names that require Databricks credentials
_DATABRICKS_TOOLS = [
    # Custom SQL tools (databricks_tool.py)
    "run_databricks_sql",
    "describe_databricks_table",
    # Managed MCP tools (databricks_mcp_tool.py)
    "databricks_mcp_query_sql",
    "databricks_mcp_query_uc_function",
    "databricks_mcp_vector_search",
    "databricks_mcp_query_genie",
    "databricks_mcp_list_tools",
]

DATABRICKS_CREDENTIALS = {
    "databricks_host": CredentialSpec(
        env_var="DATABRICKS_HOST",
        tools=_DATABRICKS_TOOLS,
        required=True,
        startup_required=False,
        help_url="https://docs.databricks.com/en/workspace/workspace-details.html",
        description="Databricks workspace URL (e.g., https://dbc-a1b2c3d4-e5f6.cloud.databricks.com)",
        # Auth method support
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""To set up Databricks authentication:

1. Go to your Databricks workspace
2. Copy the workspace URL from your browser (e.g., https://dbc-a1b2c3d4-e5f6.cloud.databricks.com)
3. Set DATABRICKS_HOST=https://your-workspace-hostname

Note: Do not include a trailing slash.""",
        # Credential store mapping
        credential_id="databricks_host",
        credential_key="workspace_url",
    ),
    "databricks_token": CredentialSpec(
        env_var="DATABRICKS_TOKEN",
        tools=_DATABRICKS_TOOLS,
        required=True,
        startup_required=False,
        help_url="https://docs.databricks.com/en/dev-tools/auth/pat.html",
        description="Databricks personal access token for API authentication",
        # Auth method support
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""To create a Databricks personal access token:

1. Go to your Databricks workspace
2. Click your username in the top bar > Settings
3. Click 'Developer' in the sidebar
4. Under 'Access tokens', click 'Manage'
5. Click 'Generate new token'
6. Give it a description and set the lifetime
7. Copy the token and set DATABRICKS_TOKEN=dapi...

For service principals, use OAuth machine-to-machine tokens instead.""",
        # Credential store mapping
        credential_id="databricks_token",
        credential_key="access_token",
    ),
    "databricks_warehouse": CredentialSpec(
        env_var="DATABRICKS_WAREHOUSE_ID",
        tools=["run_databricks_sql", "databricks_mcp_query_sql"],
        required=False,
        startup_required=False,
        help_url="https://docs.databricks.com/en/sql/admin/create-sql-warehouse.html",
        description="Default Databricks SQL Warehouse ID for query execution",
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""To find your SQL Warehouse ID:

1. Go to your Databricks workspace
2. Click 'SQL Warehouses' in the sidebar
3. Click on your warehouse
4. Copy the ID from the URL or the 'Connection details' tab
5. Set DATABRICKS_WAREHOUSE_ID=your_warehouse_id""",
        # Credential store mapping
        credential_id="databricks_warehouse",
        credential_key="warehouse_id",
    ),
}
@@ -29,6 +29,7 @@ from .brevo_tool import register_tools as register_brevo
from .calcom_tool import register_tools as register_calcom
from .calendar_tool import register_tools as register_calendar
from .csv_tool import register_tools as register_csv
from .databricks_tool import register_tools as register_databricks
from .discord_tool import register_tools as register_discord

# Security scanning tools
@@ -157,6 +158,9 @@ def register_all_tools(
    # Postgres tool
    register_postgres(mcp, credentials=credentials)

    # Databricks tools
    register_databricks(mcp, credentials=credentials)

    # Return the list of all registered tool names
    return list(mcp._tool_manager._tools.keys())
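A rough usage sketch of the wiring above. The import path and exact signature of `register_all_tools` are assumptions for illustration; the tool names and the `credentials` fallback behaviour come from this commit:

```python
from fastmcp import FastMCP
from aden_tools.tools import register_all_tools  # assumed import path

mcp = FastMCP("hive-tools")
# With credentials=None, the Databricks tools fall back to the
# DATABRICKS_HOST / DATABRICKS_TOKEN / DATABRICKS_WAREHOUSE_ID env vars.
tool_names = register_all_tools(mcp, credentials=None)
assert "run_databricks_sql" in tool_names
assert "databricks_mcp_query_sql" in tool_names
```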
@@ -0,0 +1,128 @@
# Databricks Tool

Query Databricks SQL Warehouses and interact with Databricks managed MCP servers.

## Tools

### Custom SQL Tools (Read-Only)

| Tool | Description |
|------|-------------|
| `run_databricks_sql` | Execute read-only SQL queries against a Databricks SQL Warehouse |
| `describe_databricks_table` | Fetch table schema/metadata from Unity Catalog |

### Managed MCP Server Tools

| Tool | Description |
|------|-------------|
| `databricks_mcp_query_sql` | Execute SQL via the managed SQL MCP server |
| `databricks_mcp_query_uc_function` | Execute a Unity Catalog function |
| `databricks_mcp_vector_search` | Query a Vector Search index |
| `databricks_mcp_query_genie` | Query a Genie space with natural language |
| `databricks_mcp_list_tools` | Discover tools on any managed MCP server endpoint |

## Environment Variables

| Variable | Required | Description |
|----------|----------|-------------|
| `DATABRICKS_HOST` | Yes | Workspace URL (e.g., `https://dbc-xxx.cloud.databricks.com`) |
| `DATABRICKS_TOKEN` | Yes | Personal access token (`dapi...`) |
| `DATABRICKS_WAREHOUSE_ID` | No | Default SQL Warehouse ID |
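For example (placeholder values):

```bash
export DATABRICKS_HOST="https://dbc-a1b2c3d4-e5f6.cloud.databricks.com"
export DATABRICKS_TOKEN="dapi..."                # personal access token
export DATABRICKS_WAREHOUSE_ID="abc123def456"    # optional default warehouse
```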
## Usage Examples

### Execute a Read-Only SQL Query

```python
run_databricks_sql(
    sql="SELECT name, COUNT(*) as cnt FROM main.default.users GROUP BY name",
    warehouse_id="abc123def456",
    max_rows=100
)
```

### Describe a Unity Catalog Table

```python
describe_databricks_table(
    catalog="main",
    schema="default",
    table="users"
)
```

### Query via Managed MCP SQL Server

```python
databricks_mcp_query_sql(
    sql="SELECT * FROM main.default.orders LIMIT 10"
)
```

### Execute a Unity Catalog Function

```python
databricks_mcp_query_uc_function(
    catalog="main",
    schema="analytics",
    function_name="get_revenue_summary",
    arguments={"start_date": "2024-01-01"}
)
```

### Search a Vector Index

```python
databricks_mcp_vector_search(
    catalog="prod",
    schema="knowledge_base",
    index_name="docs_index",
    query="How to configure authentication?",
    num_results=5
)
```

### Query a Genie Space

```python
databricks_mcp_query_genie(
    genie_space_id="abc123",
    question="What was the total revenue last quarter?"
)
```

### Discover Available MCP Tools

```python
databricks_mcp_list_tools(
    server_type="functions",
    resource_path="system/ai"
)
```

## Safety Features

- **Read-only enforcement** on `run_databricks_sql`: INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, TRUNCATE, MERGE, and REPLACE are blocked
- **Row limits**: Configurable `max_rows` (1–10,000) to prevent large result sets
- **Credential isolation**: Uses the CredentialStoreAdapter pattern; secrets are never logged
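A condensed sketch of the read-only guard (the full version is `_is_read_only_query` in `databricks_tool.py`, added in this commit): SQL comments are stripped first so keywords inside comments do not trigger a false block.

```python
import re

# Abridged from databricks_tool.py in this commit.
WRITE_PATTERN = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|TRUNCATE|MERGE|REPLACE)\b",
    re.IGNORECASE,
)

def is_read_only(sql: str) -> bool:
    sql = re.sub(r"--.*$", "", sql, flags=re.MULTILINE)   # strip -- line comments
    sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)  # strip /* */ block comments
    return not WRITE_PATTERN.search(sql)

assert is_read_only("SELECT * FROM t -- INSERT is only mentioned in a comment")
assert not is_read_only("DROP TABLE t")
```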
## Error Handling

All tools return structured error dicts with `error` and optional `help` fields. Common errors include:

- **Authentication failure**: Invalid or expired token
- **Permission denied**: Insufficient privileges on the target resource
- **Not found**: Invalid catalog, schema, table, or warehouse ID
- **Missing dependency**: `databricks-sdk` or `databricks-mcp` not installed
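For example, calling `run_databricks_sql` without a warehouse configured returns a dict like:

```python
{
    "error": "No SQL Warehouse ID provided",
    "help": "Provide warehouse_id parameter or set DATABRICKS_WAREHOUSE_ID environment variable.",
}
```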
## Installation

```bash
pip install 'databricks-sdk>=0.30.0' 'databricks-mcp>=0.1.0'
```

Or via the project's optional dependencies:

```bash
pip install '.[databricks]'
```
@@ -0,0 +1,12 @@
"""
Databricks Tool - Query Databricks SQL Warehouses and interact with managed MCP servers.

Provides MCP tools for:
- Executing read-only SQL queries via Databricks SQL Warehouses
- Describing tables via Unity Catalog
- Interacting with Databricks managed MCP servers (SQL, Vector Search, Genie, UC functions)
"""

from .databricks_tool import register_tools

__all__ = ["register_tools"]
@@ -0,0 +1,562 @@
"""
Databricks Managed MCP Server Tools.

Provides tools to interact with Databricks managed MCP server endpoints:
- SQL: Execute queries via the managed SQL MCP server
- Unity Catalog Functions: Execute predefined UC functions
- Vector Search: Query Vector Search indexes
- Genie: Query Genie spaces with natural language
- Discovery: List available tools on any managed MCP server

These tools use the official databricks-mcp library for authentication
and communication with Databricks managed MCP server endpoints.
"""

from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, Any

from fastmcp import FastMCP

if TYPE_CHECKING:
    from aden_tools.credentials import CredentialStoreAdapter

logger = logging.getLogger(__name__)


def _get_mcp_client(server_url: str, host: str | None, token: str | None) -> Any:
    """
    Create a DatabricksMCPClient for the given server URL.

    Args:
        server_url: Full URL of the managed MCP server endpoint
        host: Databricks workspace URL
        token: Personal access token

    Returns:
        DatabricksMCPClient instance

    Raises:
        ImportError: If databricks-mcp or databricks-sdk is not installed
    """
    try:
        from databricks.sdk import WorkspaceClient
        from databricks_mcp import DatabricksMCPClient
    except ImportError:
        raise ImportError(
            "databricks-mcp and databricks-sdk are required for Databricks MCP tools. "
            "Install them with: pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'"
        ) from None

    kwargs: dict[str, str] = {}
    if host:
        kwargs["host"] = host
    if token:
        kwargs["token"] = token

    workspace_client = WorkspaceClient(**kwargs)
    return DatabricksMCPClient(server_url=server_url, workspace_client=workspace_client)


def register_mcp_tools(
    mcp: FastMCP,
    credentials: CredentialStoreAdapter | None = None,
) -> None:
    """Register Databricks managed MCP server tools with the MCP server."""

    def _get_credentials() -> dict[str, str | None]:
        """Get Databricks credentials from credential store or environment."""
        if credentials is not None:
            try:
                host = credentials.get("databricks_host")
            except KeyError:
                host = None
            try:
                token = credentials.get("databricks_token")
            except KeyError:
                token = None
            try:
                warehouse = credentials.get("databricks_warehouse")
            except KeyError:
                warehouse = None
            return {
                "host": host,
                "token": token,
                "warehouse_id": warehouse,
            }
        return {
            "host": os.getenv("DATABRICKS_HOST"),
            "token": os.getenv("DATABRICKS_TOKEN"),
            "warehouse_id": os.getenv("DATABRICKS_WAREHOUSE_ID"),
        }

    def _get_host() -> str | None:
        """Get the Databricks workspace host URL."""
        creds = _get_credentials()
        return creds.get("host")

    def _build_server_url(path: str) -> str | None:
        """Build a full managed MCP server URL from a path suffix."""
        host = _get_host()
        if not host:
            return None
        # Ensure host doesn't have trailing slash
        host = host.rstrip("/")
        return f"{host}{path}"
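    # Example (hypothetical values): with DATABRICKS_HOST set to
    # "https://dbc-a1b2c3d4-e5f6.cloud.databricks.com",
    # _build_server_url("/api/2.0/mcp/sql") returns
    # "https://dbc-a1b2c3d4-e5f6.cloud.databricks.com/api/2.0/mcp/sql",
    # the endpoint used by databricks_mcp_query_sql below.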
@mcp.tool()
|
||||
def databricks_mcp_query_sql(
|
||||
sql: str,
|
||||
warehouse_id: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Execute a SQL query via the Databricks managed SQL MCP server.
|
||||
|
||||
Unlike run_databricks_sql, this tool uses the official Databricks managed
|
||||
MCP SQL server endpoint and supports both read and write operations as
|
||||
permitted by the workspace.
|
||||
|
||||
Args:
|
||||
sql: The SQL query to execute.
|
||||
warehouse_id: SQL Warehouse ID. Falls back to DATABRICKS_WAREHOUSE_ID
|
||||
env var if not provided. Required for the SQL MCP server.
|
||||
|
||||
Returns:
|
||||
Dict with query results:
|
||||
- success: True if query executed successfully
|
||||
- result: The query result text from the MCP server
|
||||
|
||||
Or error dict with:
|
||||
- error: Error message
|
||||
- help: Optional help text
|
||||
|
||||
Example:
|
||||
>>> databricks_mcp_query_sql("SELECT * FROM main.default.users LIMIT 10")
|
||||
{
|
||||
"success": True,
|
||||
"result": "..."
|
||||
}
|
||||
"""
|
||||
if not sql or not sql.strip():
|
||||
return {"error": "sql is required"}
|
||||
|
||||
try:
|
||||
creds = _get_credentials()
|
||||
server_url = _build_server_url("/api/2.0/mcp/sql")
|
||||
|
||||
if not server_url:
|
||||
return {
|
||||
"error": "Databricks host not configured",
|
||||
"help": "Set DATABRICKS_HOST environment variable to your workspace URL.",
|
||||
}
|
||||
|
||||
effective_warehouse = warehouse_id or creds.get("warehouse_id")
|
||||
mcp_client = _get_mcp_client(
|
||||
server_url=server_url,
|
||||
host=creds.get("host"),
|
||||
token=creds.get("token"),
|
||||
)
|
||||
|
||||
# Build arguments for the SQL tool
|
||||
tool_args: dict[str, Any] = {"statement": sql}
|
||||
if effective_warehouse:
|
||||
tool_args["warehouse_id"] = effective_warehouse
|
||||
|
||||
response = mcp_client.call_tool("execute_sql", tool_args)
|
||||
result_text = "".join([c.text for c in response.content])
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": result_text,
|
||||
}
|
||||
|
||||
except ImportError as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"help": "Install dependencies: "
|
||||
"pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'",
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"Databricks MCP SQL query failed: {e!s}"}
|
||||
|
||||
@mcp.tool()
|
||||
def databricks_mcp_query_uc_function(
|
||||
catalog: str,
|
||||
schema: str,
|
||||
function_name: str,
|
||||
arguments: dict | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Execute a Unity Catalog function via the Databricks managed MCP server.
|
||||
|
||||
Use this to run predefined SQL functions registered in Unity Catalog.
|
||||
These functions encapsulate business logic and can be invoked as tools.
|
||||
|
||||
Args:
|
||||
catalog: Unity Catalog catalog name (e.g., "main").
|
||||
schema: Schema name within the catalog (e.g., "default").
|
||||
function_name: Name of the UC function to execute.
|
||||
arguments: Optional dict of arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
Dict with function result:
|
||||
- success: True if function executed successfully
|
||||
- result: The function result text from the MCP server
|
||||
|
||||
Or error dict with:
|
||||
- error: Error message
|
||||
|
||||
Example:
|
||||
>>> databricks_mcp_query_uc_function(
|
||||
... catalog="main",
|
||||
... schema="analytics",
|
||||
... function_name="get_revenue_summary",
|
||||
... arguments={"start_date": "2024-01-01", "end_date": "2024-12-31"}
|
||||
... )
|
||||
{
|
||||
"success": True,
|
||||
"result": "Revenue summary: ..."
|
||||
}
|
||||
"""
|
||||
if not catalog or not catalog.strip():
|
||||
return {"error": "catalog is required"}
|
||||
if not schema or not schema.strip():
|
||||
return {"error": "schema is required"}
|
||||
if not function_name or not function_name.strip():
|
||||
return {"error": "function_name is required"}
|
||||
|
||||
try:
|
||||
creds = _get_credentials()
|
||||
path = f"/api/2.0/mcp/functions/{catalog}/{schema}/{function_name}"
|
||||
server_url = _build_server_url(path)
|
||||
|
||||
if not server_url:
|
||||
return {
|
||||
"error": "Databricks host not configured",
|
||||
"help": "Set DATABRICKS_HOST environment variable.",
|
||||
}
|
||||
|
||||
mcp_client = _get_mcp_client(
|
||||
server_url=server_url,
|
||||
host=creds.get("host"),
|
||||
token=creds.get("token"),
|
||||
)
|
||||
|
||||
# Construct the tool name using the UC naming convention
|
||||
tool_name = f"{catalog}__{schema}__{function_name}"
|
||||
tool_args = arguments or {}
|
||||
|
||||
response = mcp_client.call_tool(tool_name, tool_args)
|
||||
result_text = "".join([c.text for c in response.content])
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": result_text,
|
||||
}
|
||||
|
||||
except ImportError as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"help": "Install dependencies: "
|
||||
"pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'",
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"Databricks UC function call failed: {e!s}"}
|
||||
|
||||
@mcp.tool()
|
||||
def databricks_mcp_vector_search(
|
||||
catalog: str,
|
||||
schema: str,
|
||||
index_name: str,
|
||||
query: str,
|
||||
num_results: int = 10,
|
||||
) -> dict:
|
||||
"""
|
||||
Query a Databricks Vector Search index via the managed MCP server.
|
||||
|
||||
Use this to find semantically relevant documents from a Vector Search
|
||||
index that uses Databricks managed embeddings.
|
||||
|
||||
Args:
|
||||
catalog: Unity Catalog catalog name containing the index.
|
||||
schema: Schema name within the catalog.
|
||||
index_name: Name of the Vector Search index.
|
||||
query: The search query text.
|
||||
num_results: Number of results to return (default: 10).
|
||||
|
||||
Returns:
|
||||
Dict with search results:
|
||||
- success: True if search executed successfully
|
||||
- result: The search result text from the MCP server
|
||||
|
||||
Or error dict with:
|
||||
- error: Error message
|
||||
|
||||
Example:
|
||||
>>> databricks_mcp_vector_search(
|
||||
... catalog="prod",
|
||||
... schema="knowledge_base",
|
||||
... index_name="docs_index",
|
||||
... query="How to configure authentication?",
|
||||
... num_results=5
|
||||
... )
|
||||
{
|
||||
"success": True,
|
||||
"result": "..."
|
||||
}
|
||||
"""
|
||||
if not catalog or not catalog.strip():
|
||||
return {"error": "catalog is required"}
|
||||
if not schema or not schema.strip():
|
||||
return {"error": "schema is required"}
|
||||
if not index_name or not index_name.strip():
|
||||
return {"error": "index_name is required"}
|
||||
if not query or not query.strip():
|
||||
return {"error": "query is required"}
|
||||
|
||||
try:
|
||||
creds = _get_credentials()
|
||||
path = f"/api/2.0/mcp/vector-search/{catalog}/{schema}/{index_name}"
|
||||
server_url = _build_server_url(path)
|
||||
|
||||
if not server_url:
|
||||
return {
|
||||
"error": "Databricks host not configured",
|
||||
"help": "Set DATABRICKS_HOST environment variable.",
|
||||
}
|
||||
|
||||
mcp_client = _get_mcp_client(
|
||||
server_url=server_url,
|
||||
host=creds.get("host"),
|
||||
token=creds.get("token"),
|
||||
)
|
||||
|
||||
tool_args: dict[str, Any] = {
|
||||
"query": query,
|
||||
"num_results": num_results,
|
||||
}
|
||||
|
||||
# Discover the actual tool name from the server
|
||||
tools = mcp_client.list_tools()
|
||||
if not tools:
|
||||
return {
|
||||
"error": "No tools discovered on the Vector Search MCP server",
|
||||
"help": f"Check that the index '{catalog}.{schema}.{index_name}' exists.",
|
||||
}
|
||||
|
||||
tool_name = tools[0].name
|
||||
response = mcp_client.call_tool(tool_name, tool_args)
|
||||
result_text = "".join([c.text for c in response.content])
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": result_text,
|
||||
}
|
||||
|
||||
except ImportError as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"help": "Install dependencies: "
|
||||
"pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'",
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"Databricks Vector Search failed: {e!s}"}
|
||||
|
||||
@mcp.tool()
|
||||
def databricks_mcp_query_genie(
|
||||
genie_space_id: str,
|
||||
question: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Query a Databricks Genie space via the managed MCP server.
|
||||
|
||||
Genie spaces allow natural language queries against structured data.
|
||||
Use this to analyze data by asking questions in plain English.
|
||||
Results are read-only.
|
||||
|
||||
Note: Genie queries may take longer to execute as they involve
|
||||
natural language to SQL translation.
|
||||
|
||||
Args:
|
||||
genie_space_id: The ID of the Genie space to query.
|
||||
question: Natural language question to ask the Genie space.
|
||||
|
||||
Returns:
|
||||
Dict with Genie results:
|
||||
- success: True if query executed successfully
|
||||
- result: The Genie response text
|
||||
|
||||
Or error dict with:
|
||||
- error: Error message
|
||||
|
||||
Example:
|
||||
>>> databricks_mcp_query_genie(
|
||||
... genie_space_id="abc123",
|
||||
... question="What was the total revenue last quarter?"
|
||||
... )
|
||||
{
|
||||
"success": True,
|
||||
"result": "The total revenue last quarter was $1.2M..."
|
||||
}
|
||||
"""
|
||||
if not genie_space_id or not genie_space_id.strip():
|
||||
return {"error": "genie_space_id is required"}
|
||||
if not question or not question.strip():
|
||||
return {"error": "question is required"}
|
||||
|
||||
try:
|
||||
creds = _get_credentials()
|
||||
path = f"/api/2.0/mcp/genie/{genie_space_id}"
|
||||
server_url = _build_server_url(path)
|
||||
|
||||
if not server_url:
|
||||
return {
|
||||
"error": "Databricks host not configured",
|
||||
"help": "Set DATABRICKS_HOST environment variable.",
|
||||
}
|
||||
|
||||
mcp_client = _get_mcp_client(
|
||||
server_url=server_url,
|
||||
host=creds.get("host"),
|
||||
token=creds.get("token"),
|
||||
)
|
||||
|
||||
# Discover the actual tool name from the server
|
||||
tools = mcp_client.list_tools()
|
||||
if not tools:
|
||||
return {
|
||||
"error": "No tools discovered on the Genie MCP server",
|
||||
"help": f"Check that the Genie space '{genie_space_id}' exists "
|
||||
"and you have access to it.",
|
||||
}
|
||||
|
||||
tool_name = tools[0].name
|
||||
response = mcp_client.call_tool(tool_name, {"question": question})
|
||||
result_text = "".join([c.text for c in response.content])
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": result_text,
|
||||
}
|
||||
|
||||
except ImportError as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"help": "Install dependencies: "
|
||||
"pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'",
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"Databricks Genie query failed: {e!s}"}
|
||||
|
||||
@mcp.tool()
|
||||
def databricks_mcp_list_tools(
|
||||
server_url: str | None = None,
|
||||
server_type: str | None = None,
|
||||
resource_path: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Discover available tools on a Databricks managed MCP server.
|
||||
|
||||
Use this to explore what tools are available on a specific MCP server
|
||||
endpoint before calling them. Supports both direct URL and parameterized
|
||||
server type specification.
|
||||
|
||||
Args:
|
||||
server_url: Full URL of the MCP server endpoint. If provided,
|
||||
server_type and resource_path are ignored.
|
||||
server_type: Type of managed server: "sql", "vector-search",
|
||||
"genie", or "functions". Used with resource_path.
|
||||
resource_path: Resource path for the server type. Examples:
|
||||
- For vector-search: "catalog/schema/index_name"
|
||||
- For genie: "genie_space_id"
|
||||
- For functions: "catalog/schema/function_name"
|
||||
- For sql: not needed
|
||||
|
||||
Returns:
|
||||
Dict with discovered tools:
|
||||
- success: True if discovery succeeded
|
||||
- server_url: The MCP server URL queried
|
||||
- tools: List of tool definitions (name, description, parameters)
|
||||
|
||||
Or error dict with:
|
||||
- error: Error message
|
||||
|
||||
Example:
|
||||
>>> databricks_mcp_list_tools(server_type="functions", resource_path="system/ai")
|
||||
{
|
||||
"success": True,
|
||||
"server_url": "https://workspace.cloud.databricks.com/api/2.0/mcp/functions/system/ai",
|
||||
"tools": [
|
||||
{
|
||||
"name": "system__ai__python_exec",
|
||||
"description": "Execute Python code",
|
||||
"parameters": {...}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
creds = _get_credentials()
|
||||
|
||||
# Resolve server URL
|
||||
effective_url = server_url
|
||||
if not effective_url:
|
||||
if not server_type:
|
||||
return {
|
||||
"error": "Either server_url or server_type is required",
|
||||
"help": "Provide a full server_url or specify server_type "
|
||||
"(sql, vector-search, genie, functions) with resource_path.",
|
||||
}
|
||||
|
||||
valid_types = {"sql", "vector-search", "genie", "functions"}
|
||||
if server_type not in valid_types:
|
||||
return {
|
||||
"error": f"Invalid server_type: {server_type}",
|
||||
"help": f"Must be one of: {', '.join(sorted(valid_types))}",
|
||||
}
|
||||
|
||||
path = f"/api/2.0/mcp/{server_type}"
|
||||
if resource_path:
|
||||
path = f"{path}/{resource_path}"
|
||||
|
||||
effective_url = _build_server_url(path)
|
||||
|
||||
if not effective_url:
|
||||
return {
|
||||
"error": "Databricks host not configured",
|
||||
"help": "Set DATABRICKS_HOST environment variable.",
|
||||
}
|
||||
|
||||
mcp_client = _get_mcp_client(
|
||||
server_url=effective_url,
|
||||
host=creds.get("host"),
|
||||
token=creds.get("token"),
|
||||
)
|
||||
|
||||
tools = mcp_client.list_tools()
|
||||
tool_list = []
|
||||
for t in tools:
|
||||
tool_info: dict[str, Any] = {
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
}
|
||||
if t.inputSchema:
|
||||
tool_info["parameters"] = t.inputSchema
|
||||
tool_list.append(tool_info)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"server_url": effective_url,
|
||||
"tools": tool_list,
|
||||
}
|
||||
|
||||
except ImportError as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"help": "Install dependencies: "
|
||||
"pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'",
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to list MCP tools: {e!s}"}
|
||||
@@ -0,0 +1,429 @@
"""
Databricks Tool - Execute SQL queries and explore tables in Databricks.

Provides two categories of tools:
1. Custom SQL tools with read-only safety guards (run_databricks_sql, describe_databricks_table)
2. Managed MCP server tools (databricks_mcp_*)

Supports:
- Personal access token authentication via DATABRICKS_HOST + DATABRICKS_TOKEN
- Databricks CLI profile authentication (for managed MCP tools)

Safety features:
- Read-only queries only (INSERT, UPDATE, DELETE, etc. are blocked)
- Configurable row limits to prevent large result sets
- SQL write-keyword detection with comment stripping
"""

from __future__ import annotations

import logging
import os
import re
from typing import TYPE_CHECKING, Any

from fastmcp import FastMCP

from .databricks_mcp_tool import register_mcp_tools

if TYPE_CHECKING:
    from aden_tools.credentials import CredentialStoreAdapter

logger = logging.getLogger(__name__)

# SQL keywords that indicate write operations (case-insensitive)
WRITE_KEYWORDS = [
    r"\bINSERT\b",
    r"\bUPDATE\b",
    r"\bDELETE\b",
    r"\bDROP\b",
    r"\bCREATE\b",
    r"\bALTER\b",
    r"\bTRUNCATE\b",
    r"\bMERGE\b",
    r"\bREPLACE\b",
]

# Compiled regex pattern for detecting write operations
WRITE_PATTERN = re.compile("|".join(WRITE_KEYWORDS), re.IGNORECASE)


def _is_read_only_query(sql: str) -> bool:
    """
    Check if a SQL query is read-only.

    Args:
        sql: The SQL query string to check

    Returns:
        True if the query appears to be read-only, False otherwise
    """
    # Remove comments (both -- and /* */ style)
    sql_no_comments = re.sub(r"--.*$", "", sql, flags=re.MULTILINE)
    sql_no_comments = re.sub(r"/\*.*?\*/", "", sql_no_comments, flags=re.DOTALL)

    # Check for write keywords
    return not bool(WRITE_PATTERN.search(sql_no_comments))
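# Quick illustration of the guard above (hypothetical inputs):
#   _is_read_only_query("SELECT * FROM t")                 -> True
#   _is_read_only_query("SELECT * FROM t -- INSERT note")  -> True  (comment stripped first)
#   _is_read_only_query("DROP TABLE t")                    -> False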
def _create_workspace_client(
|
||||
host: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Create a Databricks WorkspaceClient with appropriate credentials.
|
||||
|
||||
Args:
|
||||
host: Databricks workspace URL
|
||||
token: Personal access token
|
||||
|
||||
Returns:
|
||||
WorkspaceClient instance
|
||||
|
||||
Raises:
|
||||
ImportError: If databricks-sdk is not installed
|
||||
Exception: If authentication fails
|
||||
"""
|
||||
try:
|
||||
from databricks.sdk import WorkspaceClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"databricks-sdk is required for Databricks tools. "
|
||||
"Install it with: pip install 'databricks-sdk>=0.30.0'"
|
||||
) from None
|
||||
|
||||
kwargs: dict[str, str] = {}
|
||||
if host:
|
||||
kwargs["host"] = host
|
||||
if token:
|
||||
kwargs["token"] = token
|
||||
|
||||
return WorkspaceClient(**kwargs)
|
||||
|
||||
|
||||
def register_tools(
|
||||
mcp: FastMCP,
|
||||
credentials: CredentialStoreAdapter | None = None,
|
||||
) -> None:
|
||||
"""Register Databricks tools with the MCP server."""
|
||||
|
||||
def _get_credentials() -> dict[str, str | None]:
|
||||
"""Get Databricks credentials from credential store or environment."""
|
||||
if credentials is not None:
|
||||
try:
|
||||
host = credentials.get("databricks_host")
|
||||
except KeyError:
|
||||
host = None
|
||||
try:
|
||||
token = credentials.get("databricks_token")
|
||||
except KeyError:
|
||||
token = None
|
||||
try:
|
||||
warehouse = credentials.get("databricks_warehouse")
|
||||
except KeyError:
|
||||
warehouse = None
|
||||
return {
|
||||
"host": host,
|
||||
"token": token,
|
||||
"warehouse_id": warehouse,
|
||||
}
|
||||
return {
|
||||
"host": os.getenv("DATABRICKS_HOST"),
|
||||
"token": os.getenv("DATABRICKS_TOKEN"),
|
||||
"warehouse_id": os.getenv("DATABRICKS_WAREHOUSE_ID"),
|
||||
}
|
||||
|
||||
def _get_client() -> Any:
|
||||
"""
|
||||
Get a Databricks WorkspaceClient with credentials resolution.
|
||||
|
||||
Returns:
|
||||
WorkspaceClient instance
|
||||
"""
|
||||
creds = _get_credentials()
|
||||
return _create_workspace_client(
|
||||
host=creds["host"],
|
||||
token=creds["token"],
|
||||
)
|
||||
|
||||
@mcp.tool()
|
||||
def run_databricks_sql(
|
||||
sql: str,
|
||||
warehouse_id: str | None = None,
|
||||
max_rows: int = 1000,
|
||||
) -> dict:
|
||||
"""
|
||||
Execute a read-only SQL query against a Databricks SQL Warehouse.
|
||||
|
||||
This tool executes SQL queries and returns the results as structured data.
|
||||
Only SELECT queries are allowed - write operations (INSERT, UPDATE, DELETE,
|
||||
DROP, CREATE, ALTER, TRUNCATE, MERGE) are blocked for safety.
|
||||
|
||||
Args:
|
||||
sql: The SQL query to execute. Must be a read-only query.
|
||||
warehouse_id: SQL Warehouse ID. Falls back to DATABRICKS_WAREHOUSE_ID
|
||||
env var if not provided.
|
||||
max_rows: Maximum number of rows to return (default: 1000).
|
||||
Use this to prevent accidentally fetching large result sets.
|
||||
|
||||
Returns:
|
||||
Dict with query results:
|
||||
- success: True if query executed successfully
|
||||
- rows: List of row dictionaries
|
||||
- total_rows: Total number of rows returned
|
||||
- rows_returned: Number of rows actually returned (may be limited)
|
||||
- schema: List of column definitions (name, type)
|
||||
- query_truncated: True if results were truncated due to max_rows
|
||||
|
||||
Or error dict with:
|
||||
- error: Error message
|
||||
- help: Optional help text
|
||||
|
||||
Example:
|
||||
>>> run_databricks_sql(
|
||||
... sql="SELECT name, COUNT(*) as cnt FROM catalog.schema.users GROUP BY name",
|
||||
... max_rows=100
|
||||
... )
|
||||
{
|
||||
"success": True,
|
||||
"rows": [{"name": "Alice", "cnt": 42}, ...],
|
||||
"total_rows": 100,
|
||||
"rows_returned": 100,
|
||||
"schema": [{"name": "name", "type": "STRING"}, ...],
|
||||
"query_truncated": False
|
||||
}
|
||||
"""
|
||||
# Validate SQL is read-only
|
||||
if not _is_read_only_query(sql):
|
||||
return {
|
||||
"error": "Write operations are not allowed",
|
||||
"help": "Only SELECT queries are permitted. "
|
||||
"INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, TRUNCATE, and MERGE are blocked.",
|
||||
}
|
||||
|
||||
# Validate max_rows
|
||||
if max_rows < 1:
|
||||
return {"error": "max_rows must be at least 1"}
|
||||
if max_rows > 10000:
|
||||
return {
|
||||
"error": "max_rows cannot exceed 10000",
|
||||
"help": "For larger result sets, consider using pagination or "
|
||||
"exporting to cloud storage.",
|
||||
}
|
||||
|
||||
try:
|
||||
client = _get_client()
|
||||
creds = _get_credentials()
|
||||
effective_warehouse = warehouse_id or creds.get("warehouse_id")
|
||||
|
||||
if not effective_warehouse:
|
||||
return {
|
||||
"error": "No SQL Warehouse ID provided",
|
||||
"help": "Provide warehouse_id parameter or set DATABRICKS_WAREHOUSE_ID "
|
||||
"environment variable.",
|
||||
}
|
||||
|
||||
# Execute query via Databricks SQL Statement API
|
||||
response = client.statement_execution.execute_statement(
|
||||
statement=sql,
|
||||
warehouse_id=effective_warehouse,
|
||||
wait_timeout="30s",
|
||||
row_limit=max_rows,
|
||||
)
|
||||
|
||||
# Check for execution errors
|
||||
if response.status and response.status.error:
|
||||
return {
|
||||
"error": f"Databricks SQL error: {response.status.error.message}",
|
||||
}
|
||||
|
||||
# Parse results
|
||||
schema = []
|
||||
if response.manifest and response.manifest.schema and response.manifest.schema.columns:
|
||||
schema = [
|
||||
{
|
||||
"name": col.name,
|
||||
"type": str(col.type_name) if col.type_name else "UNKNOWN",
|
||||
}
|
||||
for col in response.manifest.schema.columns
|
||||
]
|
||||
|
||||
rows = []
|
||||
total_rows = 0
|
||||
if response.result and response.result.data_array:
|
||||
total_rows = len(response.result.data_array)
|
||||
for row_data in response.result.data_array[:max_rows]:
|
||||
row = {}
|
||||
for i, value in enumerate(row_data):
|
||||
col_name = schema[i]["name"] if i < len(schema) else f"col_{i}"
|
||||
row[col_name] = value
|
||||
rows.append(row)
|
||||
|
||||
query_truncated = total_rows > max_rows
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"rows": rows,
|
||||
"total_rows": total_rows,
|
||||
"rows_returned": len(rows),
|
||||
"schema": schema,
|
||||
"query_truncated": query_truncated,
|
||||
}
|
||||
|
||||
except ImportError as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"help": "Install the dependency by running: pip install 'databricks-sdk>=0.30.0'",
|
||||
}
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
# Provide helpful messages for common errors
|
||||
if "TEMPORARILY_UNAVAILABLE" in error_msg or "401" in error_msg:
|
||||
return {
|
||||
"error": "Databricks authentication failed",
|
||||
"help": "Check that DATABRICKS_HOST and DATABRICKS_TOKEN are set correctly. "
|
||||
"Token may have expired — generate a new one from workspace settings.",
|
||||
}
|
||||
if "403" in error_msg or "PERMISSION_DENIED" in error_msg:
|
||||
return {
|
||||
"error": f"Databricks permission denied: {error_msg}",
|
||||
"help": "Ensure your token has permission to access the SQL Warehouse "
|
||||
"and the requested tables/catalogs.",
|
||||
}
|
||||
if "404" in error_msg or "NOT_FOUND" in error_msg:
|
||||
return {
|
||||
"error": f"Databricks resource not found: {error_msg}",
|
||||
"help": "Check that the warehouse ID, catalog, schema, "
|
||||
"and table names are correct.",
|
||||
}
|
||||
if "INVALID_PARAMETER" in error_msg:
|
||||
return {
|
||||
"error": f"Invalid parameter: {error_msg}",
|
||||
"help": "Check that the warehouse_id is a valid SQL Warehouse ID.",
|
||||
}
|
||||
|
||||
return {"error": f"Databricks SQL query failed: {error_msg}"}
|
||||
|
||||
@mcp.tool()
|
||||
def describe_databricks_table(
|
||||
catalog: str,
|
||||
schema: str,
|
||||
table: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Describe a table in Databricks Unity Catalog, returning its schema and metadata.
|
||||
|
||||
Use this tool to explore table structure before writing queries.
|
||||
Returns column names, types, nullability, and table metadata.
|
||||
|
||||
Args:
|
||||
catalog: Unity Catalog catalog name (e.g., "main").
|
||||
schema: Schema name within the catalog (e.g., "default").
|
||||
table: Table name to describe.
|
||||
|
||||
Returns:
|
||||
Dict with table information:
|
||||
- success: True if operation succeeded
|
||||
- catalog: The catalog name
|
||||
- schema: The schema name
|
||||
- table: The table name
|
||||
- full_name: Fully qualified table name (catalog.schema.table)
|
||||
- table_type: Type of the table (MANAGED, EXTERNAL, VIEW, etc.)
|
||||
- columns: List of column definitions (name, type, nullable, comment)
|
||||
- comment: Table comment/description if available
|
||||
- storage_location: Physical storage location if applicable
|
||||
|
||||
Or error dict with:
|
||||
- error: Error message
|
||||
- help: Optional help text
|
||||
|
||||
Example:
|
||||
>>> describe_databricks_table("main", "default", "users")
|
||||
{
|
||||
"success": True,
|
||||
"catalog": "main",
|
||||
"schema": "default",
|
||||
"table": "users",
|
||||
"full_name": "main.default.users",
|
||||
"table_type": "MANAGED",
|
||||
"columns": [
|
||||
{"name": "id", "type": "LONG", "nullable": False, "comment": "User ID"},
|
||||
{"name": "email", "type": "STRING", "nullable": True, "comment": None}
|
||||
],
|
||||
"comment": "User accounts table"
|
||||
}
|
||||
"""
|
||||
if not catalog or not catalog.strip():
|
||||
return {"error": "catalog is required"}
|
||||
if not schema or not schema.strip():
|
||||
return {"error": "schema is required"}
|
||||
if not table or not table.strip():
|
||||
return {"error": "table name is required"}
|
||||
|
||||
full_name = f"{catalog}.{schema}.{table}"
|
||||
|
||||
try:
|
||||
client = _get_client()
|
||||
|
||||
# Retrieve table metadata via Unity Catalog API
|
||||
table_info = client.tables.get(full_name)
|
||||
|
||||
columns = []
|
||||
if table_info.columns:
|
||||
for col in table_info.columns:
|
||||
columns.append(
|
||||
{
|
||||
"name": col.name,
|
||||
"type": str(col.type_name) if col.type_name else "UNKNOWN",
|
||||
"nullable": col.nullable if col.nullable is not None else True,
|
||||
"comment": col.comment,
|
||||
}
|
||||
)
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"catalog": catalog,
|
||||
"schema": schema,
|
||||
"table": table,
|
||||
"full_name": full_name,
|
||||
"table_type": str(table_info.table_type) if table_info.table_type else None,
|
||||
"columns": columns,
|
||||
"comment": table_info.comment,
|
||||
}
|
||||
|
||||
if table_info.storage_location:
|
||||
result["storage_location"] = table_info.storage_location
|
||||
|
||||
return result
|
||||
|
||||
except ImportError as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"help": "Install the dependency by running: pip install 'databricks-sdk>=0.30.0'",
|
||||
}
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
if "TEMPORARILY_UNAVAILABLE" in error_msg or "401" in error_msg:
|
||||
return {
|
||||
"error": "Databricks authentication failed",
|
||||
"help": "Check that DATABRICKS_HOST and DATABRICKS_TOKEN are set correctly.",
|
||||
}
|
||||
if "NOT_FOUND" in error_msg or "DOES_NOT_EXIST" in error_msg:
|
||||
return {
|
||||
"error": f"Table not found: {full_name}",
|
||||
"help": "Check that the catalog, schema, and table names are correct. "
|
||||
f"Full error: {error_msg}",
|
||||
}
|
||||
if "403" in error_msg or "PERMISSION_DENIED" in error_msg:
|
||||
return {
|
||||
"error": f"Permission denied for table: {full_name}",
|
||||
"help": "Ensure your token has the 'USE CATALOG', 'USE SCHEMA', "
|
||||
"and 'SELECT' privileges on the target table.",
|
||||
}
|
||||
|
||||
return {"error": f"Failed to describe table: {error_msg}"}
|
||||
|
||||
# Register managed MCP server tools
|
||||
register_mcp_tools(mcp, credentials)
|
||||
@@ -0,0 +1,783 @@
"""
Tests for Databricks tools.

Tests cover:
- Custom SQL tool: read-only enforcement, row limiting, query execution, error handling
- Describe table: input validation, successful description, error handling
- Managed MCP tools: SQL, UC functions, Vector Search, Genie, tool discovery
- Credential resolution from CredentialStoreAdapter and environment
"""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest
from fastmcp import FastMCP

from aden_tools.credentials import CredentialStoreAdapter
from aden_tools.tools.databricks_tool import register_tools


@pytest.fixture
def mcp():
    """Create a FastMCP instance for testing."""
    return FastMCP("test-server")


@pytest.fixture
def mock_credentials():
    """Create mock credentials for testing."""
    return CredentialStoreAdapter.for_testing(
        {
            "databricks_host": "https://test-workspace.cloud.databricks.com",
            "databricks_token": "dapi_test_token_12345",
            "databricks_warehouse": "test-warehouse-id",
        }
    )


@pytest.fixture
def registered_mcp(mcp, mock_credentials):
    """Register Databricks tools with mock credentials."""
    register_tools(mcp, credentials=mock_credentials)
    return mcp
# ===========================================================================
|
||||
# Custom SQL Tool — Read-Only Enforcement
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestReadOnlyEnforcement:
|
||||
"""Tests for SQL write operation blocking in run_databricks_sql."""
|
||||
|
||||
def test_blocks_insert(self, registered_mcp):
|
||||
"""INSERT statements should be blocked."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="INSERT INTO table VALUES (1, 2)")
|
||||
assert "error" in result
|
||||
assert "Write operations are not allowed" in result["error"]
|
||||
|
||||
def test_blocks_update(self, registered_mcp):
|
||||
"""UPDATE statements should be blocked."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="UPDATE table SET col = 1")
|
||||
assert "error" in result
|
||||
assert "Write operations are not allowed" in result["error"]
|
||||
|
||||
def test_blocks_delete(self, registered_mcp):
|
||||
"""DELETE statements should be blocked."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="DELETE FROM table WHERE id = 1")
|
||||
assert "error" in result
|
||||
assert "Write operations are not allowed" in result["error"]
|
||||
|
||||
def test_blocks_drop(self, registered_mcp):
|
||||
"""DROP statements should be blocked."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="DROP TABLE my_table")
|
||||
assert "error" in result
|
||||
assert "Write operations are not allowed" in result["error"]
|
||||
|
||||
def test_blocks_create(self, registered_mcp):
|
||||
"""CREATE statements should be blocked."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="CREATE TABLE my_table (id INT)")
|
||||
assert "error" in result
|
||||
assert "Write operations are not allowed" in result["error"]
|
||||
|
||||
def test_blocks_alter(self, registered_mcp):
|
||||
"""ALTER statements should be blocked."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="ALTER TABLE my_table ADD COLUMN new_col INT")
|
||||
assert "error" in result
|
||||
assert "Write operations are not allowed" in result["error"]
|
||||
|
||||
def test_blocks_truncate(self, registered_mcp):
|
||||
"""TRUNCATE statements should be blocked."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="TRUNCATE TABLE my_table")
|
||||
assert "error" in result
|
||||
assert "Write operations are not allowed" in result["error"]
|
||||
|
||||
def test_blocks_merge(self, registered_mcp):
|
||||
"""MERGE statements should be blocked."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="MERGE INTO target USING source ON condition WHEN MATCHED THEN UPDATE")
|
||||
assert "error" in result
|
||||
assert "Write operations are not allowed" in result["error"]
|
||||
|
||||
def test_blocks_case_insensitive(self, registered_mcp):
|
||||
"""Write detection should be case-insensitive."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="insert into table values (1)")
|
||||
assert "error" in result
|
||||
assert "Write operations are not allowed" in result["error"]
|
||||
|
||||
def test_allows_select(self, registered_mcp):
|
||||
"""SELECT statements should be allowed (will fail on client, not validation)."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
with patch(
|
||||
"aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
|
||||
) as mock_create:
|
||||
mock_create.side_effect = Exception("Mock error")
|
||||
result = tool.fn(sql="SELECT * FROM table")
|
||||
assert "Write operations are not allowed" not in result.get("error", "")
|
||||
|
||||
def test_allows_select_with_subquery(self, registered_mcp):
|
||||
"""Complex SELECT with subqueries should be allowed."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
with patch(
|
||||
"aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
|
||||
) as mock_create:
|
||||
mock_create.side_effect = Exception("Mock error")
|
||||
result = tool.fn(
|
||||
sql="""
|
||||
SELECT a.*, b.count
|
||||
FROM (SELECT id, name FROM users) a
|
||||
JOIN (SELECT user_id, COUNT(*) as count FROM orders GROUP BY user_id) b
|
||||
ON a.id = b.user_id
|
||||
"""
|
||||
)
|
||||
assert "Write operations are not allowed" not in result.get("error", "")
|
||||
|
||||
def test_ignores_write_keywords_in_comments(self, registered_mcp):
|
||||
"""Write keywords inside comments should not trigger blocking."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
with patch(
|
||||
"aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
|
||||
) as mock_create:
|
||||
mock_create.side_effect = Exception("Mock error")
|
||||
result = tool.fn(sql="SELECT * FROM t -- INSERT comment")
|
||||
assert "Write operations are not allowed" not in result.get("error", "")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Custom SQL Tool — Row Limits
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestRowLimits:
|
||||
"""Tests for row limit validation in run_databricks_sql."""
|
||||
|
||||
def test_rejects_zero_max_rows(self, registered_mcp):
|
||||
"""max_rows of 0 should be rejected."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="SELECT 1", max_rows=0)
|
||||
assert "error" in result
|
||||
assert "max_rows must be at least 1" in result["error"]
|
||||
|
||||
def test_rejects_negative_max_rows(self, registered_mcp):
|
||||
"""Negative max_rows should be rejected."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="SELECT 1", max_rows=-1)
|
||||
assert "error" in result
|
||||
assert "max_rows must be at least 1" in result["error"]
|
||||
|
||||
def test_rejects_excessive_max_rows(self, registered_mcp):
|
||||
"""max_rows over 10000 should be rejected."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
result = tool.fn(sql="SELECT 1", max_rows=10001)
|
||||
assert "error" in result
|
||||
assert "max_rows cannot exceed 10000" in result["error"]
|
||||
|
||||
def test_accepts_valid_max_rows(self, registered_mcp):
|
||||
"""Valid max_rows values should be accepted."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
with patch(
|
||||
"aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
|
||||
) as mock_create:
|
||||
mock_create.side_effect = Exception("Mock error")
|
||||
for max_rows in [1, 100, 1000, 10000]:
|
||||
result = tool.fn(sql="SELECT 1", max_rows=max_rows)
|
||||
assert "max_rows" not in result.get("error", "")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Custom SQL Tool — Query Execution
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestQueryExecution:
|
||||
"""Tests for successful query execution with mocked Databricks client."""
|
||||
|
||||
def test_successful_query(self, registered_mcp):
|
||||
"""Test successful query execution with mocked WorkspaceClient."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
|
||||
with patch(
|
||||
"aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
|
||||
) as mock_create:
|
||||
mock_client = MagicMock()
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
# Mock the statement execution response
|
||||
mock_col1 = MagicMock()
|
||||
mock_col1.name = "id"
|
||||
mock_col1.type_name = "INT"
|
||||
mock_col2 = MagicMock()
|
||||
mock_col2.name = "name"
|
||||
mock_col2.type_name = "STRING"
|
||||
|
||||
mock_schema = MagicMock()
|
||||
mock_schema.columns = [mock_col1, mock_col2]
|
||||
|
||||
mock_manifest = MagicMock()
|
||||
mock_manifest.schema = mock_schema
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.data_array = [["1", "Alice"], ["2", "Bob"]]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status.error = None
|
||||
mock_response.manifest = mock_manifest
|
||||
mock_response.result = mock_result
|
||||
|
||||
mock_client.statement_execution.execute_statement.return_value = mock_response
|
||||
|
||||
result = tool.fn(sql="SELECT id, name FROM users")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["rows"] == [
|
||||
{"id": "1", "name": "Alice"},
|
||||
{"id": "2", "name": "Bob"},
|
||||
]
|
||||
assert result["total_rows"] == 2
|
||||
assert result["rows_returned"] == 2
|
||||
assert result["query_truncated"] is False
|
||||
assert len(result["schema"]) == 2
|
||||
|
||||
def test_query_truncation(self, registered_mcp):
|
||||
"""Test that results are truncated when exceeding max_rows."""
|
||||
tool = registered_mcp._tool_manager._tools["run_databricks_sql"]
|
||||
|
||||
with patch(
|
||||
"aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
|
||||
) as mock_create:
|
||||
mock_client = MagicMock()
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
mock_col = MagicMock()
|
||||
mock_col.name = "id"
|
||||
mock_col.type_name = "INT"
|
||||
|
||||
mock_schema = MagicMock()
|
||||
mock_schema.columns = [mock_col]
|
||||
|
||||
mock_manifest = MagicMock()
|
||||
mock_manifest.schema = mock_schema
|
||||
|
||||
# Create 10 rows of data
|
||||
mock_result = MagicMock()
|
||||
mock_result.data_array = [[str(i)] for i in range(10)]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status.error = None
|
||||
mock_response.manifest = mock_manifest
|
||||
mock_response.result = mock_result
|
||||
|
||||
mock_client.statement_execution.execute_statement.return_value = mock_response
|
||||
|
||||
# Request only 5 rows
|
||||
result = tool.fn(sql="SELECT id FROM users", max_rows=5)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["total_rows"] == 10
|
||||
assert result["rows_returned"] == 5
|
||||
assert result["query_truncated"] is True
|
||||
assert len(result["rows"]) == 5
|
||||
|
||||
def test_missing_warehouse_id(self, mcp):
|
||||
"""Test error when no warehouse ID is configured."""
|
||||
creds = CredentialStoreAdapter.for_testing(
|
||||
{
|
||||
"databricks_host": "https://test.cloud.databricks.com",
|
||||
"databricks_token": "dapi_test",
|
||||
}
|
||||
)
|
||||
register_tools(mcp, credentials=creds)
|
||||
|
||||
tool = mcp._tool_manager._tools["run_databricks_sql"]
|
||||
|
||||
with patch(
|
||||
"aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
|
||||
) as mock_create:
|
||||
mock_create.return_value = MagicMock()
|
||||
result = tool.fn(sql="SELECT 1")
|
||||
assert "error" in result
|
||||
assert "No SQL Warehouse ID provided" in result["error"]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Describe Table
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestDescribeTable:
    """Tests for describe_databricks_table tool."""

    def test_empty_catalog(self, registered_mcp):
        """Empty catalog should be rejected."""
        tool = registered_mcp._tool_manager._tools["describe_databricks_table"]
        result = tool.fn(catalog="", schema="default", table="users")
        assert "error" in result
        assert "catalog is required" in result["error"]

    def test_empty_schema(self, registered_mcp):
        """Empty schema should be rejected."""
        tool = registered_mcp._tool_manager._tools["describe_databricks_table"]
        result = tool.fn(catalog="main", schema="", table="users")
        assert "error" in result
        assert "schema is required" in result["error"]

    def test_empty_table(self, registered_mcp):
        """Empty table name should be rejected."""
        tool = registered_mcp._tool_manager._tools["describe_databricks_table"]
        result = tool.fn(catalog="main", schema="default", table="")
        assert "error" in result
        assert "table name is required" in result["error"]

    def test_successful_describe(self, registered_mcp):
        """Test successful table description with mocked client."""
        tool = registered_mcp._tool_manager._tools["describe_databricks_table"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as mock_create:
            mock_client = MagicMock()
            mock_create.return_value = mock_client

            # Mock table info
            mock_col1 = MagicMock()
            mock_col1.name = "id"
            mock_col1.type_name = "LONG"
            mock_col1.nullable = False
            mock_col1.comment = "User ID"

            mock_col2 = MagicMock()
            mock_col2.name = "email"
            mock_col2.type_name = "STRING"
            mock_col2.nullable = True
            mock_col2.comment = None

            mock_table_info = MagicMock()
            mock_table_info.columns = [mock_col1, mock_col2]
            mock_table_info.table_type = "MANAGED"
            mock_table_info.comment = "User accounts table"
            mock_table_info.storage_location = "s3://bucket/path"

            mock_client.tables.get.return_value = mock_table_info

            result = tool.fn(catalog="main", schema="default", table="users")

            assert result["success"] is True
            assert result["catalog"] == "main"
            assert result["schema"] == "default"
            assert result["table"] == "users"
            assert result["full_name"] == "main.default.users"
            assert result["table_type"] == "MANAGED"
            assert result["comment"] == "User accounts table"
            assert result["storage_location"] == "s3://bucket/path"
            assert len(result["columns"]) == 2
            assert result["columns"][0]["name"] == "id"
            assert result["columns"][0]["nullable"] is False
            assert result["columns"][1]["name"] == "email"
            assert result["columns"][1]["nullable"] is True


# ===========================================================================
# Managed MCP Tools
# ===========================================================================

class TestMCPQuerySQL:
    """Tests for databricks_mcp_query_sql tool."""

    def test_empty_sql(self, registered_mcp):
        """Empty SQL should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_sql"]
        result = tool.fn(sql="")
        assert "error" in result
        assert "sql is required" in result["error"]

    def test_successful_mcp_sql(self, registered_mcp):
        """Test successful MCP SQL query with mocked client."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_sql"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as mock_get_client:
            mock_client = MagicMock()
            mock_get_client.return_value = mock_client

            mock_content = MagicMock()
            mock_content.text = "Query results: 42 rows returned"
            mock_response = MagicMock()
            mock_response.content = [mock_content]
            mock_client.call_tool.return_value = mock_response

            result = tool.fn(sql="SELECT * FROM main.default.users")

            assert result["success"] is True
            assert "Query results" in result["result"]


class TestMCPUCFunction:
    """Tests for databricks_mcp_query_uc_function tool."""

    def test_empty_catalog(self, registered_mcp):
        """Empty catalog should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_uc_function"]
        result = tool.fn(catalog="", schema="default", function_name="my_func")
        assert "error" in result
        assert "catalog is required" in result["error"]

    def test_empty_function_name(self, registered_mcp):
        """Empty function name should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_uc_function"]
        result = tool.fn(catalog="main", schema="default", function_name="")
        assert "error" in result
        assert "function_name is required" in result["error"]

    def test_successful_uc_function(self, registered_mcp):
        """Test successful UC function call with mocked client."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_uc_function"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as mock_get_client:
            mock_client = MagicMock()
            mock_get_client.return_value = mock_client

            mock_content = MagicMock()
            mock_content.text = "Revenue: $1.2M"
            mock_response = MagicMock()
            mock_response.content = [mock_content]
            mock_client.call_tool.return_value = mock_response

            result = tool.fn(
                catalog="main",
                schema="analytics",
                function_name="get_revenue",
                arguments={"year": "2024"},
            )

            assert result["success"] is True
            assert "Revenue" in result["result"]

class TestMCPVectorSearch:
    """Tests for databricks_mcp_vector_search tool."""

    def test_empty_query(self, registered_mcp):
        """Empty query should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_vector_search"]
        result = tool.fn(catalog="prod", schema="kb", index_name="idx", query="")
        assert "error" in result
        assert "query is required" in result["error"]

    def test_empty_index_name(self, registered_mcp):
        """Empty index name should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_vector_search"]
        result = tool.fn(catalog="prod", schema="kb", index_name="", query="test")
        assert "error" in result
        assert "index_name is required" in result["error"]

    def test_successful_vector_search(self, registered_mcp):
        """Test successful vector search with mocked client."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_vector_search"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as mock_get_client:
            mock_client = MagicMock()
            mock_get_client.return_value = mock_client

            mock_tool = MagicMock()
            mock_tool.name = "search_index"
            mock_client.list_tools.return_value = [mock_tool]

            mock_content = MagicMock()
            mock_content.text = "Found 3 relevant documents"
            mock_response = MagicMock()
            mock_response.content = [mock_content]
            mock_client.call_tool.return_value = mock_response

            result = tool.fn(
                catalog="prod",
                schema="kb",
                index_name="docs_index",
                query="How to configure auth?",
                num_results=5,
            )

            assert result["success"] is True
            assert "relevant documents" in result["result"]

class TestMCPGenie:
    """Tests for databricks_mcp_query_genie tool."""

    def test_empty_genie_space_id(self, registered_mcp):
        """Empty genie space ID should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_genie"]
        result = tool.fn(genie_space_id="", question="What is revenue?")
        assert "error" in result
        assert "genie_space_id is required" in result["error"]

    def test_empty_question(self, registered_mcp):
        """Empty question should be rejected."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_genie"]
        result = tool.fn(genie_space_id="abc123", question="")
        assert "error" in result
        assert "question is required" in result["error"]

    def test_successful_genie_query(self, registered_mcp):
        """Test successful Genie query with mocked client."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_genie"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as mock_get_client:
            mock_client = MagicMock()
            mock_get_client.return_value = mock_client

            mock_tool = MagicMock()
            mock_tool.name = "ask_genie"
            mock_client.list_tools.return_value = [mock_tool]

            mock_content = MagicMock()
            mock_content.text = "Total revenue last quarter was $1.2M"
            mock_response = MagicMock()
            mock_response.content = [mock_content]
            mock_client.call_tool.return_value = mock_response

            result = tool.fn(
                genie_space_id="abc123",
                question="What was revenue last quarter?",
            )

            assert result["success"] is True
            assert "$1.2M" in result["result"]

class TestMCPListTools:
    """Tests for databricks_mcp_list_tools tool."""

    def test_missing_server_url_and_type(self, registered_mcp):
        """Should return error when neither server_url nor server_type is given."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_list_tools"]
        result = tool.fn()
        assert "error" in result
        assert "server_url or server_type is required" in result["error"]

    def test_invalid_server_type(self, registered_mcp):
        """Should return error for invalid server type."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_list_tools"]
        result = tool.fn(server_type="invalid")
        assert "error" in result
        assert "Invalid server_type" in result["error"]

    def test_successful_list_tools(self, registered_mcp):
        """Test successful tool listing with mocked client."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_list_tools"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as mock_get_client:
            mock_client = MagicMock()
            mock_get_client.return_value = mock_client

            mock_tool1 = MagicMock()
            mock_tool1.name = "execute_sql"
            mock_tool1.description = "Execute SQL queries"
            mock_tool1.inputSchema = {"type": "object", "properties": {}}

            mock_tool2 = MagicMock()
            mock_tool2.name = "list_tables"
            mock_tool2.description = "List available tables"
            mock_tool2.inputSchema = None

            mock_client.list_tools.return_value = [mock_tool1, mock_tool2]

            result = tool.fn(server_type="sql")

            assert result["success"] is True
            assert len(result["tools"]) == 2
            assert result["tools"][0]["name"] == "execute_sql"
            assert result["tools"][1]["name"] == "list_tables"
            assert "https://test-workspace.cloud.databricks.com" in result["server_url"]

    def test_with_direct_server_url(self, registered_mcp):
        """Test tool listing with a direct server URL."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_list_tools"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as mock_get_client:
            mock_client = MagicMock()
            mock_get_client.return_value = mock_client
            mock_client.list_tools.return_value = []

            url = "https://my-workspace.cloud.databricks.com/api/2.0/mcp/sql"
            result = tool.fn(server_url=url)

            assert result["success"] is True
            assert result["server_url"] == url


# ===========================================================================
# Error Handling
# ===========================================================================

class TestErrorHandling:
    """Tests for error handling and user-friendly messages."""

    def test_authentication_error(self, registered_mcp):
        """Authentication errors should provide helpful messages."""
        tool = registered_mcp._tool_manager._tools["run_databricks_sql"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as mock_create:
            mock_create.side_effect = Exception("401 Unauthorized")
            result = tool.fn(sql="SELECT 1")

            assert "error" in result
            assert "authentication failed" in result["error"].lower()
            assert "help" in result

    def test_permission_error(self, registered_mcp):
        """Permission errors should provide helpful messages."""
        tool = registered_mcp._tool_manager._tools["run_databricks_sql"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as mock_create:
            mock_create.side_effect = Exception("PERMISSION_DENIED: access denied")
            result = tool.fn(sql="SELECT 1")

            assert "error" in result
            assert "permission denied" in result["error"].lower()
            assert "help" in result

    def test_not_found_error(self, registered_mcp):
        """Not found errors should provide helpful messages."""
        tool = registered_mcp._tool_manager._tools["run_databricks_sql"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as mock_create:
            mock_create.side_effect = Exception("NOT_FOUND: warehouse xyz not found")
            result = tool.fn(sql="SELECT 1")

            assert "error" in result
            assert "not found" in result["error"].lower()
            assert "help" in result

    def test_table_not_found_error(self, registered_mcp):
        """Table not found errors should provide helpful messages."""
        tool = registered_mcp._tool_manager._tools["describe_databricks_table"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as mock_create:
            mock_create.side_effect = Exception("DOES_NOT_EXIST: table not found")
            result = tool.fn(catalog="main", schema="default", table="nonexistent")

            assert "error" in result
            assert "not found" in result["error"].lower()

    def test_mcp_import_error(self, registered_mcp):
        """Import errors for managed MCP tools should be helpful."""
        tool = registered_mcp._tool_manager._tools["databricks_mcp_query_sql"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client"
        ) as mock_get:
            mock_get.side_effect = ImportError("databricks-mcp and databricks-sdk are required")
            result = tool.fn(sql="SELECT 1")

            assert "error" in result
            assert "databricks" in result["error"].lower()


# ===========================================================================
# Credential Resolution
# ===========================================================================

class TestCredentialResolution:
    """Tests for credential resolution from different sources."""

    def test_uses_credential_store(self, mcp):
        """Should use credentials from CredentialStoreAdapter."""
        mock_creds = CredentialStoreAdapter.for_testing(
            {
                "databricks_host": "https://custom.cloud.databricks.com",
                "databricks_token": "dapi_custom_token",
                "databricks_warehouse": "custom-warehouse",
            }
        )
        register_tools(mcp, credentials=mock_creds)

        assert mock_creds.get("databricks_host") == "https://custom.cloud.databricks.com"
        assert mock_creds.get("databricks_token") == "dapi_custom_token"
        assert mock_creds.get("databricks_warehouse") == "custom-warehouse"

    def test_falls_back_to_env_vars(self, mcp):
        """Should fall back to environment variables when no credential store."""
        register_tools(mcp, credentials=None)

        # Tools are registered and will use os.getenv internally
        assert "run_databricks_sql" in mcp._tool_manager._tools
        assert "describe_databricks_table" in mcp._tool_manager._tools
        assert "databricks_mcp_query_sql" in mcp._tool_manager._tools
        assert "databricks_mcp_query_uc_function" in mcp._tool_manager._tools
        assert "databricks_mcp_vector_search" in mcp._tool_manager._tools
        assert "databricks_mcp_query_genie" in mcp._tool_manager._tools
        assert "databricks_mcp_list_tools" in mcp._tool_manager._tools

    def test_all_tools_registered(self, registered_mcp):
        """All 7 Databricks tools should be registered."""
        expected_tools = [
            "run_databricks_sql",
            "describe_databricks_table",
            "databricks_mcp_query_sql",
            "databricks_mcp_query_uc_function",
            "databricks_mcp_vector_search",
            "databricks_mcp_query_genie",
            "databricks_mcp_list_tools",
        ]
        for tool_name in expected_tools:
            assert tool_name in registered_mcp._tool_manager._tools, (
                f"Tool '{tool_name}' not registered"
            )


# ===========================================================================
# Import Error Handling
# ===========================================================================

class TestImportError:
    """Tests for handling missing databricks-sdk package."""

    def test_sdk_import_error_message(self, registered_mcp):
        """Should provide helpful message when databricks-sdk not installed."""
        tool = registered_mcp._tool_manager._tools["run_databricks_sql"]

        with patch(
            "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client"
        ) as mock_create:
            mock_create.side_effect = ImportError(
                "databricks-sdk is required for Databricks tools. "
                "Install it with: pip install 'databricks-sdk>=0.30.0'"
            )
            result = tool.fn(sql="SELECT 1")

            assert "error" in result
            assert "databricks-sdk" in result["error"]
            assert "pip install" in result["error"]
@@ -536,6 +536,32 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/1c/7c/996760c30f1302704af57c66ff2d723f7d656d0d0b93563b5528a51484bb/cyclopts-4.5.1-py3-none-any.whl", hash = "sha256:0642c93601e554ca6b7b9abd81093847ea4448b2616280f2a0952416574e8c7a", size = 199807, upload-time = "2026-01-25T15:23:55.219Z" },
]

[[package]]
name = "databricks-mcp"
version = "0.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "databricks-sdk" },
    { name = "mcp" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/49/10/9f12e8d3322327a8edac8b906a0f1cece3dfbc58efdd2198538b4ba56957/databricks_mcp-0.1.0-py3-none-any.whl", hash = "sha256:d02ed6cdd24a862a365b4fae2892223d8209d14e5be500ebc14d086f646408cc", size = 2307, upload-time = "2025-06-05T20:39:11.724Z" },
]

[[package]]
name = "databricks-sdk"
version = "0.94.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "google-auth" },
    { name = "protobuf" },
    { name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3e/39/18b3a4a69851f583d7970ae71f19975e06ef892ea17393bcafb773088f96/databricks_sdk-0.94.0.tar.gz", hash = "sha256:1b29a9d63e2dc9808ed79253ba3668cd6d130dabb880a576d647fa3b1b1a7b30", size = 861465, upload-time = "2026-02-26T08:17:12.063Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/aa/da/072fe4df10aec583eb72327856b4743c7f5fb50a16772d5f86fcc20e2457/databricks_sdk-0.94.0-py3-none-any.whl", hash = "sha256:ae980186ee2d99f3639d14a94457776dca22980a04f65f89497a79889b2b9149", size = 811079, upload-time = "2026-02-26T08:17:10.242Z" },
]

[[package]]
name = "diff-match-patch"
version = "20241021"
@@ -3490,6 +3516,8 @@ dependencies = [

[package.optional-dependencies]
all = [
    { name = "databricks-mcp" },
    { name = "databricks-sdk" },
    { name = "duckdb" },
    { name = "google-cloud-bigquery" },
    { name = "openpyxl" },
@@ -3500,6 +3528,10 @@ all = [
bigquery = [
    { name = "google-cloud-bigquery" },
]
databricks = [
    { name = "databricks-mcp" },
    { name = "databricks-sdk" },
]
dev = [
    { name = "pytest" },
    { name = "pytest-asyncio" },
@@ -3529,6 +3561,10 @@ dev = [
requires-dist = [
    { name = "arxiv", specifier = ">=2.1.0" },
    { name = "beautifulsoup4", specifier = ">=4.12.0" },
    { name = "databricks-mcp", marker = "extra == 'all'", specifier = ">=0.1.0" },
    { name = "databricks-mcp", marker = "extra == 'databricks'", specifier = ">=0.1.0" },
    { name = "databricks-sdk", marker = "extra == 'all'", specifier = ">=0.30.0" },
    { name = "databricks-sdk", marker = "extra == 'databricks'", specifier = ">=0.30.0" },
    { name = "diff-match-patch", specifier = ">=20230430" },
    { name = "dnspython", specifier = ">=2.4.0" },
    { name = "duckdb", marker = "extra == 'all'", specifier = ">=1.0.0" },
@@ -3561,7 +3597,7 @@ requires-dist = [
    { name = "restrictedpython", marker = "extra == 'sandbox'", specifier = ">=7.0" },
    { name = "stripe", specifier = ">=14.3.0" },
]
provides-extras = ["dev", "sandbox", "ocr", "excel", "sql", "bigquery", "all"]
provides-extras = ["dev", "sandbox", "ocr", "excel", "sql", "bigquery", "databricks", "all"]

[package.metadata.requires-dev]
dev = [