"""
Databricks tool credentials.

Contains credentials for Databricks workspace and SQL Warehouse access,
as well as managed MCP server connectivity.
"""

from .base import CredentialSpec

# All tool names that require Databricks credentials
_DATABRICKS_TOOLS = [
    # Custom SQL tools (databricks_tool.py)
    "run_databricks_sql",
    "describe_databricks_table",
    # Managed MCP tools (databricks_mcp_tool.py)
    "databricks_mcp_query_sql",
    "databricks_mcp_query_uc_function",
    "databricks_mcp_vector_search",
    "databricks_mcp_query_genie",
    "databricks_mcp_list_tools",
]

# Spec-key -> CredentialSpec mapping; merged into CREDENTIAL_SPECS by
# aden_tools.credentials.__init__ via **DATABRICKS_CREDENTIALS.
DATABRICKS_CREDENTIALS = {
    "databricks_host": CredentialSpec(
        env_var="DATABRICKS_HOST",
        tools=_DATABRICKS_TOOLS,
        required=True,
        startup_required=False,
        help_url="https://docs.databricks.com/en/workspace/workspace-details.html",
        description="Databricks workspace URL (e.g., https://dbc-a1b2c3d4-e5f6.cloud.databricks.com)",
        # Auth method support
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""To set up Databricks authentication:

1. Go to your Databricks workspace
2. Copy the workspace URL from your browser (e.g., https://dbc-a1b2c3d4-e5f6.cloud.databricks.com)
3. Set DATABRICKS_HOST=https://your-workspace-hostname

Note: Do not include a trailing slash.""",
        # Credential store mapping
        credential_id="databricks_host",
        credential_key="workspace_url",
    ),
    "databricks_token": CredentialSpec(
        env_var="DATABRICKS_TOKEN",
        tools=_DATABRICKS_TOOLS,
        required=True,
        startup_required=False,
        help_url="https://docs.databricks.com/en/dev-tools/auth/pat.html",
        description="Databricks personal access token for API authentication",
        # Auth method support
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""To create a Databricks personal access token:

1. Go to your Databricks workspace
2. Click your username in the top bar > Settings
3. Click 'Developer' in the sidebar
4. Under 'Access tokens', click 'Manage'
5. Click 'Generate new token'
6. Give it a description and set the lifetime
7. Copy the token and set DATABRICKS_TOKEN=dapi...

For service principals, use OAuth machine-to-machine tokens instead.""",
        # Credential store mapping
        credential_id="databricks_token",
        credential_key="access_token",
    ),
    # Optional (required=False): the SQL tools also accept an explicit
    # warehouse_id argument, so this workspace-wide default may be unset.
    "databricks_warehouse": CredentialSpec(
        env_var="DATABRICKS_WAREHOUSE_ID",
        tools=["run_databricks_sql", "databricks_mcp_query_sql"],
        required=False,
        startup_required=False,
        help_url="https://docs.databricks.com/en/sql/admin/create-sql-warehouse.html",
        description="Default Databricks SQL Warehouse ID for query execution",
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""To find your SQL Warehouse ID:

1. Go to your Databricks workspace
2. Click 'SQL Warehouses' in the sidebar
3. Click on your warehouse
4. Copy the ID from the URL or the 'Connection details' tab
5. Set DATABRICKS_WAREHOUSE_ID=your_warehouse_id""",
        # Credential store mapping
        credential_id="databricks_warehouse",
        credential_key="warehouse_id",
    ),
}
index 00000000..3e246953 --- /dev/null +++ b/tools/src/aden_tools/tools/databricks_tool/README.md @@ -0,0 +1,128 @@ +# Databricks Tool + +Query Databricks SQL Warehouses and interact with Databricks managed MCP servers. + +## Tools + +### Custom SQL Tools (Read-Only) + +| Tool | Description | +|------|-------------| +| `run_databricks_sql` | Execute read-only SQL queries against a Databricks SQL Warehouse | +| `describe_databricks_table` | Fetch table schema/metadata from Unity Catalog | + +### Managed MCP Server Tools + +| Tool | Description | +|------|-------------| +| `databricks_mcp_query_sql` | Execute SQL via the managed SQL MCP server | +| `databricks_mcp_query_uc_function` | Execute a Unity Catalog function | +| `databricks_mcp_vector_search` | Query a Vector Search index | +| `databricks_mcp_query_genie` | Query a Genie space with natural language | +| `databricks_mcp_list_tools` | Discover tools on any managed MCP server endpoint | + +## Environment Variables + +| Variable | Required | Description | +|----------|----------|-------------| +| `DATABRICKS_HOST` | Yes | Workspace URL (e.g., `https://dbc-xxx.cloud.databricks.com`) | +| `DATABRICKS_TOKEN` | Yes | Personal access token (`dapi...`) | +| `DATABRICKS_WAREHOUSE_ID` | No | Default SQL Warehouse ID | + +## Usage Examples + +### Execute a Read-Only SQL Query + +```python +run_databricks_sql( + sql="SELECT name, COUNT(*) as cnt FROM main.default.users GROUP BY name", + warehouse_id="abc123def456", + max_rows=100 +) +``` + +### Describe a Unity Catalog Table + +```python +describe_databricks_table( + catalog="main", + schema="default", + table="users" +) +``` + +### Query via Managed MCP SQL Server + +```python +databricks_mcp_query_sql( + sql="SELECT * FROM main.default.orders LIMIT 10" +) +``` + +### Execute a Unity Catalog Function + +```python +databricks_mcp_query_uc_function( + catalog="main", + schema="analytics", + function_name="get_revenue_summary", + arguments={"start_date": "2024-01-01"} +) 
+``` + +### Search a Vector Index + +```python +databricks_mcp_vector_search( + catalog="prod", + schema="knowledge_base", + index_name="docs_index", + query="How to configure authentication?", + num_results=5 +) +``` + +### Query a Genie Space + +```python +databricks_mcp_query_genie( + genie_space_id="abc123", + question="What was the total revenue last quarter?" +) +``` + +### Discover Available MCP Tools + +```python +databricks_mcp_list_tools( + server_type="functions", + resource_path="system/ai" +) +``` + +## Safety Features + +- **Read-only enforcement** on `run_databricks_sql`: INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, TRUNCATE, MERGE, and REPLACE are blocked +- **Row limits**: Configurable max_rows (1–10,000) to prevent large result sets +- **Credential isolation**: Uses CredentialStoreAdapter pattern; secrets never logged + +## Error Handling + +All tools return structured error dicts with `error` and optional `help` fields. Common errors include: + +- **Authentication failure**: Invalid or expired token +- **Permission denied**: Insufficient privileges on the target resource +- **Not found**: Invalid catalog, schema, table, or warehouse ID +- **Missing dependency**: `databricks-sdk` or `databricks-mcp` not installed + +## Installation + +```bash +pip install 'databricks-sdk>=0.30.0' 'databricks-mcp>=0.1.0' +``` + +Or via the project's optional dependencies: + +```bash +pip install '.[databricks]' +``` diff --git a/tools/src/aden_tools/tools/databricks_tool/__init__.py b/tools/src/aden_tools/tools/databricks_tool/__init__.py new file mode 100644 index 00000000..b8b308f1 --- /dev/null +++ b/tools/src/aden_tools/tools/databricks_tool/__init__.py @@ -0,0 +1,12 @@ +""" +Databricks Tool - Query Databricks SQL Warehouses and interact with managed MCP servers. 
+ +Provides MCP tools for: +- Executing read-only SQL queries via Databricks SQL Warehouses +- Describing tables via Unity Catalog +- Interacting with Databricks managed MCP servers (SQL, Vector Search, Genie, UC functions) +""" + +from .databricks_tool import register_tools + +__all__ = ["register_tools"] diff --git a/tools/src/aden_tools/tools/databricks_tool/databricks_mcp_tool.py b/tools/src/aden_tools/tools/databricks_tool/databricks_mcp_tool.py new file mode 100644 index 00000000..646e7de1 --- /dev/null +++ b/tools/src/aden_tools/tools/databricks_tool/databricks_mcp_tool.py @@ -0,0 +1,562 @@ +""" +Databricks Managed MCP Server Tools. + +Provides tools to interact with Databricks managed MCP server endpoints: +- SQL: Execute queries via the managed SQL MCP server +- Unity Catalog Functions: Execute predefined UC functions +- Vector Search: Query Vector Search indexes +- Genie: Query Genie spaces with natural language +- Discovery: List available tools on any managed MCP server + +These tools use the official databricks-mcp library for authentication +and communication with Databricks managed MCP server endpoints. +""" + +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING, Any + +from fastmcp import FastMCP + +if TYPE_CHECKING: + from aden_tools.credentials import CredentialStoreAdapter + +logger = logging.getLogger(__name__) + + +def _get_mcp_client(server_url: str, host: str | None, token: str | None) -> Any: + """ + Create a DatabricksMCPClient for the given server URL. 
+ + Args: + server_url: Full URL of the managed MCP server endpoint + host: Databricks workspace URL + token: Personal access token + + Returns: + DatabricksMCPClient instance + + Raises: + ImportError: If databricks-mcp or databricks-sdk is not installed + """ + try: + from databricks.sdk import WorkspaceClient + from databricks_mcp import DatabricksMCPClient + except ImportError: + raise ImportError( + "databricks-mcp and databricks-sdk are required for Databricks MCP tools. " + "Install them with: pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'" + ) from None + + kwargs: dict[str, str] = {} + if host: + kwargs["host"] = host + if token: + kwargs["token"] = token + + workspace_client = WorkspaceClient(**kwargs) + return DatabricksMCPClient(server_url=server_url, workspace_client=workspace_client) + + +def register_mcp_tools( + mcp: FastMCP, + credentials: CredentialStoreAdapter | None = None, +) -> None: + """Register Databricks managed MCP server tools with the MCP server.""" + + def _get_credentials() -> dict[str, str | None]: + """Get Databricks credentials from credential store or environment.""" + if credentials is not None: + try: + host = credentials.get("databricks_host") + except KeyError: + host = None + try: + token = credentials.get("databricks_token") + except KeyError: + token = None + try: + warehouse = credentials.get("databricks_warehouse") + except KeyError: + warehouse = None + return { + "host": host, + "token": token, + "warehouse_id": warehouse, + } + return { + "host": os.getenv("DATABRICKS_HOST"), + "token": os.getenv("DATABRICKS_TOKEN"), + "warehouse_id": os.getenv("DATABRICKS_WAREHOUSE_ID"), + } + + def _get_host() -> str | None: + """Get the Databricks workspace host URL.""" + creds = _get_credentials() + return creds.get("host") + + def _build_server_url(path: str) -> str | None: + """Build a full managed MCP server URL from a path suffix.""" + host = _get_host() + if not host: + return None + # Ensure host doesn't have 
trailing slash + host = host.rstrip("/") + return f"{host}{path}" + + @mcp.tool() + def databricks_mcp_query_sql( + sql: str, + warehouse_id: str | None = None, + ) -> dict: + """ + Execute a SQL query via the Databricks managed SQL MCP server. + + Unlike run_databricks_sql, this tool uses the official Databricks managed + MCP SQL server endpoint and supports both read and write operations as + permitted by the workspace. + + Args: + sql: The SQL query to execute. + warehouse_id: SQL Warehouse ID. Falls back to DATABRICKS_WAREHOUSE_ID + env var if not provided. Required for the SQL MCP server. + + Returns: + Dict with query results: + - success: True if query executed successfully + - result: The query result text from the MCP server + + Or error dict with: + - error: Error message + - help: Optional help text + + Example: + >>> databricks_mcp_query_sql("SELECT * FROM main.default.users LIMIT 10") + { + "success": True, + "result": "..." + } + """ + if not sql or not sql.strip(): + return {"error": "sql is required"} + + try: + creds = _get_credentials() + server_url = _build_server_url("/api/2.0/mcp/sql") + + if not server_url: + return { + "error": "Databricks host not configured", + "help": "Set DATABRICKS_HOST environment variable to your workspace URL.", + } + + effective_warehouse = warehouse_id or creds.get("warehouse_id") + mcp_client = _get_mcp_client( + server_url=server_url, + host=creds.get("host"), + token=creds.get("token"), + ) + + # Build arguments for the SQL tool + tool_args: dict[str, Any] = {"statement": sql} + if effective_warehouse: + tool_args["warehouse_id"] = effective_warehouse + + response = mcp_client.call_tool("execute_sql", tool_args) + result_text = "".join([c.text for c in response.content]) + + return { + "success": True, + "result": result_text, + } + + except ImportError as e: + return { + "error": str(e), + "help": "Install dependencies: " + "pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'", + } + except Exception as 
e: + return {"error": f"Databricks MCP SQL query failed: {e!s}"} + + @mcp.tool() + def databricks_mcp_query_uc_function( + catalog: str, + schema: str, + function_name: str, + arguments: dict | None = None, + ) -> dict: + """ + Execute a Unity Catalog function via the Databricks managed MCP server. + + Use this to run predefined SQL functions registered in Unity Catalog. + These functions encapsulate business logic and can be invoked as tools. + + Args: + catalog: Unity Catalog catalog name (e.g., "main"). + schema: Schema name within the catalog (e.g., "default"). + function_name: Name of the UC function to execute. + arguments: Optional dict of arguments to pass to the function. + + Returns: + Dict with function result: + - success: True if function executed successfully + - result: The function result text from the MCP server + + Or error dict with: + - error: Error message + + Example: + >>> databricks_mcp_query_uc_function( + ... catalog="main", + ... schema="analytics", + ... function_name="get_revenue_summary", + ... arguments={"start_date": "2024-01-01", "end_date": "2024-12-31"} + ... ) + { + "success": True, + "result": "Revenue summary: ..." 
+ } + """ + if not catalog or not catalog.strip(): + return {"error": "catalog is required"} + if not schema or not schema.strip(): + return {"error": "schema is required"} + if not function_name or not function_name.strip(): + return {"error": "function_name is required"} + + try: + creds = _get_credentials() + path = f"/api/2.0/mcp/functions/{catalog}/{schema}/{function_name}" + server_url = _build_server_url(path) + + if not server_url: + return { + "error": "Databricks host not configured", + "help": "Set DATABRICKS_HOST environment variable.", + } + + mcp_client = _get_mcp_client( + server_url=server_url, + host=creds.get("host"), + token=creds.get("token"), + ) + + # Construct the tool name using the UC naming convention + tool_name = f"{catalog}__{schema}__{function_name}" + tool_args = arguments or {} + + response = mcp_client.call_tool(tool_name, tool_args) + result_text = "".join([c.text for c in response.content]) + + return { + "success": True, + "result": result_text, + } + + except ImportError as e: + return { + "error": str(e), + "help": "Install dependencies: " + "pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'", + } + except Exception as e: + return {"error": f"Databricks UC function call failed: {e!s}"} + + @mcp.tool() + def databricks_mcp_vector_search( + catalog: str, + schema: str, + index_name: str, + query: str, + num_results: int = 10, + ) -> dict: + """ + Query a Databricks Vector Search index via the managed MCP server. + + Use this to find semantically relevant documents from a Vector Search + index that uses Databricks managed embeddings. + + Args: + catalog: Unity Catalog catalog name containing the index. + schema: Schema name within the catalog. + index_name: Name of the Vector Search index. + query: The search query text. + num_results: Number of results to return (default: 10). 
+ + Returns: + Dict with search results: + - success: True if search executed successfully + - result: The search result text from the MCP server + + Or error dict with: + - error: Error message + + Example: + >>> databricks_mcp_vector_search( + ... catalog="prod", + ... schema="knowledge_base", + ... index_name="docs_index", + ... query="How to configure authentication?", + ... num_results=5 + ... ) + { + "success": True, + "result": "..." + } + """ + if not catalog or not catalog.strip(): + return {"error": "catalog is required"} + if not schema or not schema.strip(): + return {"error": "schema is required"} + if not index_name or not index_name.strip(): + return {"error": "index_name is required"} + if not query or not query.strip(): + return {"error": "query is required"} + + try: + creds = _get_credentials() + path = f"/api/2.0/mcp/vector-search/{catalog}/{schema}/{index_name}" + server_url = _build_server_url(path) + + if not server_url: + return { + "error": "Databricks host not configured", + "help": "Set DATABRICKS_HOST environment variable.", + } + + mcp_client = _get_mcp_client( + server_url=server_url, + host=creds.get("host"), + token=creds.get("token"), + ) + + tool_args: dict[str, Any] = { + "query": query, + "num_results": num_results, + } + + # Discover the actual tool name from the server + tools = mcp_client.list_tools() + if not tools: + return { + "error": "No tools discovered on the Vector Search MCP server", + "help": f"Check that the index '{catalog}.{schema}.{index_name}' exists.", + } + + tool_name = tools[0].name + response = mcp_client.call_tool(tool_name, tool_args) + result_text = "".join([c.text for c in response.content]) + + return { + "success": True, + "result": result_text, + } + + except ImportError as e: + return { + "error": str(e), + "help": "Install dependencies: " + "pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'", + } + except Exception as e: + return {"error": f"Databricks Vector Search failed: {e!s}"} + + 
@mcp.tool() + def databricks_mcp_query_genie( + genie_space_id: str, + question: str, + ) -> dict: + """ + Query a Databricks Genie space via the managed MCP server. + + Genie spaces allow natural language queries against structured data. + Use this to analyze data by asking questions in plain English. + Results are read-only. + + Note: Genie queries may take longer to execute as they involve + natural language to SQL translation. + + Args: + genie_space_id: The ID of the Genie space to query. + question: Natural language question to ask the Genie space. + + Returns: + Dict with Genie results: + - success: True if query executed successfully + - result: The Genie response text + + Or error dict with: + - error: Error message + + Example: + >>> databricks_mcp_query_genie( + ... genie_space_id="abc123", + ... question="What was the total revenue last quarter?" + ... ) + { + "success": True, + "result": "The total revenue last quarter was $1.2M..." + } + """ + if not genie_space_id or not genie_space_id.strip(): + return {"error": "genie_space_id is required"} + if not question or not question.strip(): + return {"error": "question is required"} + + try: + creds = _get_credentials() + path = f"/api/2.0/mcp/genie/{genie_space_id}" + server_url = _build_server_url(path) + + if not server_url: + return { + "error": "Databricks host not configured", + "help": "Set DATABRICKS_HOST environment variable.", + } + + mcp_client = _get_mcp_client( + server_url=server_url, + host=creds.get("host"), + token=creds.get("token"), + ) + + # Discover the actual tool name from the server + tools = mcp_client.list_tools() + if not tools: + return { + "error": "No tools discovered on the Genie MCP server", + "help": f"Check that the Genie space '{genie_space_id}' exists " + "and you have access to it.", + } + + tool_name = tools[0].name + response = mcp_client.call_tool(tool_name, {"question": question}) + result_text = "".join([c.text for c in response.content]) + + return { + "success": 
True, + "result": result_text, + } + + except ImportError as e: + return { + "error": str(e), + "help": "Install dependencies: " + "pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'", + } + except Exception as e: + return {"error": f"Databricks Genie query failed: {e!s}"} + + @mcp.tool() + def databricks_mcp_list_tools( + server_url: str | None = None, + server_type: str | None = None, + resource_path: str | None = None, + ) -> dict: + """ + Discover available tools on a Databricks managed MCP server. + + Use this to explore what tools are available on a specific MCP server + endpoint before calling them. Supports both direct URL and parameterized + server type specification. + + Args: + server_url: Full URL of the MCP server endpoint. If provided, + server_type and resource_path are ignored. + server_type: Type of managed server: "sql", "vector-search", + "genie", or "functions". Used with resource_path. + resource_path: Resource path for the server type. Examples: + - For vector-search: "catalog/schema/index_name" + - For genie: "genie_space_id" + - For functions: "catalog/schema/function_name" + - For sql: not needed + + Returns: + Dict with discovered tools: + - success: True if discovery succeeded + - server_url: The MCP server URL queried + - tools: List of tool definitions (name, description, parameters) + + Or error dict with: + - error: Error message + + Example: + >>> databricks_mcp_list_tools(server_type="functions", resource_path="system/ai") + { + "success": True, + "server_url": "https://workspace.cloud.databricks.com/api/2.0/mcp/functions/system/ai", + "tools": [ + { + "name": "system__ai__python_exec", + "description": "Execute Python code", + "parameters": {...} + } + ] + } + """ + try: + creds = _get_credentials() + + # Resolve server URL + effective_url = server_url + if not effective_url: + if not server_type: + return { + "error": "Either server_url or server_type is required", + "help": "Provide a full server_url or specify 
server_type " + "(sql, vector-search, genie, functions) with resource_path.", + } + + valid_types = {"sql", "vector-search", "genie", "functions"} + if server_type not in valid_types: + return { + "error": f"Invalid server_type: {server_type}", + "help": f"Must be one of: {', '.join(sorted(valid_types))}", + } + + path = f"/api/2.0/mcp/{server_type}" + if resource_path: + path = f"{path}/{resource_path}" + + effective_url = _build_server_url(path) + + if not effective_url: + return { + "error": "Databricks host not configured", + "help": "Set DATABRICKS_HOST environment variable.", + } + + mcp_client = _get_mcp_client( + server_url=effective_url, + host=creds.get("host"), + token=creds.get("token"), + ) + + tools = mcp_client.list_tools() + tool_list = [] + for t in tools: + tool_info: dict[str, Any] = { + "name": t.name, + "description": t.description, + } + if t.inputSchema: + tool_info["parameters"] = t.inputSchema + tool_list.append(tool_info) + + return { + "success": True, + "server_url": effective_url, + "tools": tool_list, + } + + except ImportError as e: + return { + "error": str(e), + "help": "Install dependencies: " + "pip install 'databricks-mcp>=0.1.0' 'databricks-sdk>=0.30.0'", + } + except Exception as e: + return {"error": f"Failed to list MCP tools: {e!s}"} diff --git a/tools/src/aden_tools/tools/databricks_tool/databricks_tool.py b/tools/src/aden_tools/tools/databricks_tool/databricks_tool.py new file mode 100644 index 00000000..f34d3693 --- /dev/null +++ b/tools/src/aden_tools/tools/databricks_tool/databricks_tool.py @@ -0,0 +1,429 @@ +""" +Databricks Tool - Execute SQL queries and explore tables in Databricks. + +Provides two categories of tools: +1. Custom SQL tools with read-only safety guards (run_databricks_sql, describe_databricks_table) +2. 
Managed MCP server tools (databricks_mcp_*) + +Supports: +- Personal access token authentication via DATABRICKS_HOST + DATABRICKS_TOKEN +- Databricks CLI profile authentication (for managed MCP tools) + +Safety features: +- Read-only queries only (INSERT, UPDATE, DELETE, etc. are blocked) +- Configurable row limits to prevent large result sets +- SQL write-keyword detection with comment stripping +""" + +from __future__ import annotations + +import logging +import os +import re +from typing import TYPE_CHECKING, Any + +from fastmcp import FastMCP + +from .databricks_mcp_tool import register_mcp_tools + +if TYPE_CHECKING: + from aden_tools.credentials import CredentialStoreAdapter + +logger = logging.getLogger(__name__) + +# SQL keywords that indicate write operations (case-insensitive) +WRITE_KEYWORDS = [ + r"\bINSERT\b", + r"\bUPDATE\b", + r"\bDELETE\b", + r"\bDROP\b", + r"\bCREATE\b", + r"\bALTER\b", + r"\bTRUNCATE\b", + r"\bMERGE\b", + r"\bREPLACE\b", +] + +# Compiled regex pattern for detecting write operations +WRITE_PATTERN = re.compile("|".join(WRITE_KEYWORDS), re.IGNORECASE) + + +def _is_read_only_query(sql: str) -> bool: + """ + Check if a SQL query is read-only. + + Args: + sql: The SQL query string to check + + Returns: + True if the query appears to be read-only, False otherwise + """ + # Remove comments (both -- and /* */ style) + sql_no_comments = re.sub(r"--.*$", "", sql, flags=re.MULTILINE) + sql_no_comments = re.sub(r"/\*.*?\*/", "", sql_no_comments, flags=re.DOTALL) + + # Check for write keywords + return not bool(WRITE_PATTERN.search(sql_no_comments)) + + +def _create_workspace_client( + host: str | None = None, + token: str | None = None, +) -> Any: + """ + Create a Databricks WorkspaceClient with appropriate credentials. 
+ + Args: + host: Databricks workspace URL + token: Personal access token + + Returns: + WorkspaceClient instance + + Raises: + ImportError: If databricks-sdk is not installed + Exception: If authentication fails + """ + try: + from databricks.sdk import WorkspaceClient + except ImportError: + raise ImportError( + "databricks-sdk is required for Databricks tools. " + "Install it with: pip install 'databricks-sdk>=0.30.0'" + ) from None + + kwargs: dict[str, str] = {} + if host: + kwargs["host"] = host + if token: + kwargs["token"] = token + + return WorkspaceClient(**kwargs) + + +def register_tools( + mcp: FastMCP, + credentials: CredentialStoreAdapter | None = None, +) -> None: + """Register Databricks tools with the MCP server.""" + + def _get_credentials() -> dict[str, str | None]: + """Get Databricks credentials from credential store or environment.""" + if credentials is not None: + try: + host = credentials.get("databricks_host") + except KeyError: + host = None + try: + token = credentials.get("databricks_token") + except KeyError: + token = None + try: + warehouse = credentials.get("databricks_warehouse") + except KeyError: + warehouse = None + return { + "host": host, + "token": token, + "warehouse_id": warehouse, + } + return { + "host": os.getenv("DATABRICKS_HOST"), + "token": os.getenv("DATABRICKS_TOKEN"), + "warehouse_id": os.getenv("DATABRICKS_WAREHOUSE_ID"), + } + + def _get_client() -> Any: + """ + Get a Databricks WorkspaceClient with credentials resolution. + + Returns: + WorkspaceClient instance + """ + creds = _get_credentials() + return _create_workspace_client( + host=creds["host"], + token=creds["token"], + ) + + @mcp.tool() + def run_databricks_sql( + sql: str, + warehouse_id: str | None = None, + max_rows: int = 1000, + ) -> dict: + """ + Execute a read-only SQL query against a Databricks SQL Warehouse. + + This tool executes SQL queries and returns the results as structured data. 
+ Only SELECT queries are allowed - write operations (INSERT, UPDATE, DELETE, + DROP, CREATE, ALTER, TRUNCATE, MERGE) are blocked for safety. + + Args: + sql: The SQL query to execute. Must be a read-only query. + warehouse_id: SQL Warehouse ID. Falls back to DATABRICKS_WAREHOUSE_ID + env var if not provided. + max_rows: Maximum number of rows to return (default: 1000). + Use this to prevent accidentally fetching large result sets. + + Returns: + Dict with query results: + - success: True if query executed successfully + - rows: List of row dictionaries + - total_rows: Total number of rows returned + - rows_returned: Number of rows actually returned (may be limited) + - schema: List of column definitions (name, type) + - query_truncated: True if results were truncated due to max_rows + + Or error dict with: + - error: Error message + - help: Optional help text + + Example: + >>> run_databricks_sql( + ... sql="SELECT name, COUNT(*) as cnt FROM catalog.schema.users GROUP BY name", + ... max_rows=100 + ... ) + { + "success": True, + "rows": [{"name": "Alice", "cnt": 42}, ...], + "total_rows": 100, + "rows_returned": 100, + "schema": [{"name": "name", "type": "STRING"}, ...], + "query_truncated": False + } + """ + # Validate SQL is read-only + if not _is_read_only_query(sql): + return { + "error": "Write operations are not allowed", + "help": "Only SELECT queries are permitted. 
" + "INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, TRUNCATE, and MERGE are blocked.", + } + + # Validate max_rows + if max_rows < 1: + return {"error": "max_rows must be at least 1"} + if max_rows > 10000: + return { + "error": "max_rows cannot exceed 10000", + "help": "For larger result sets, consider using pagination or " + "exporting to cloud storage.", + } + + try: + client = _get_client() + creds = _get_credentials() + effective_warehouse = warehouse_id or creds.get("warehouse_id") + + if not effective_warehouse: + return { + "error": "No SQL Warehouse ID provided", + "help": "Provide warehouse_id parameter or set DATABRICKS_WAREHOUSE_ID " + "environment variable.", + } + + # Execute query via Databricks SQL Statement API + response = client.statement_execution.execute_statement( + statement=sql, + warehouse_id=effective_warehouse, + wait_timeout="30s", + row_limit=max_rows, + ) + + # Check for execution errors + if response.status and response.status.error: + return { + "error": f"Databricks SQL error: {response.status.error.message}", + } + + # Parse results + schema = [] + if response.manifest and response.manifest.schema and response.manifest.schema.columns: + schema = [ + { + "name": col.name, + "type": str(col.type_name) if col.type_name else "UNKNOWN", + } + for col in response.manifest.schema.columns + ] + + rows = [] + total_rows = 0 + if response.result and response.result.data_array: + total_rows = len(response.result.data_array) + for row_data in response.result.data_array[:max_rows]: + row = {} + for i, value in enumerate(row_data): + col_name = schema[i]["name"] if i < len(schema) else f"col_{i}" + row[col_name] = value + rows.append(row) + + query_truncated = total_rows > max_rows + + return { + "success": True, + "rows": rows, + "total_rows": total_rows, + "rows_returned": len(rows), + "schema": schema, + "query_truncated": query_truncated, + } + + except ImportError as e: + return { + "error": str(e), + "help": "Install the dependency by 
running: pip install 'databricks-sdk>=0.30.0'",
+            }
+        except Exception as e:
+            error_msg = str(e)
+
+            # Provide helpful messages for common errors
+            if "401" in error_msg or "UNAUTHENTICATED" in error_msg:
+                return {
+                    "error": "Databricks authentication failed",
+                    "help": "Check that DATABRICKS_HOST and DATABRICKS_TOKEN are set correctly. "
+                    "Token may have expired — generate a new one from workspace settings.",
+                }
+            if "403" in error_msg or "PERMISSION_DENIED" in error_msg:
+                return {
+                    "error": f"Databricks permission denied: {error_msg}",
+                    "help": "Ensure your token has permission to access the SQL Warehouse "
+                    "and the requested tables/catalogs.",
+                }
+            if "404" in error_msg or "NOT_FOUND" in error_msg:
+                return {
+                    "error": f"Databricks resource not found: {error_msg}",
+                    "help": "Check that the warehouse ID, catalog, schema, "
+                    "and table names are correct.",
+                }
+            if "INVALID_PARAMETER" in error_msg:
+                return {
+                    "error": f"Invalid parameter: {error_msg}",
+                    "help": "Check that the warehouse_id is a valid SQL Warehouse ID.",
+                }
+
+            return {"error": f"Databricks SQL query failed: {error_msg}"}
+
+    @mcp.tool()
+    def describe_databricks_table(
+        catalog: str,
+        schema: str,
+        table: str,
+    ) -> dict:
+        """
+        Describe a table in Databricks Unity Catalog, returning its schema and metadata.
+
+        Use this tool to explore table structure before writing queries.
+        Returns column names, types, nullability, and table metadata.
+
+        Args:
+            catalog: Unity Catalog catalog name (e.g., "main").
+            schema: Schema name within the catalog (e.g., "default").
+            table: Table name to describe.
+
+        Returns:
+            Dict with table information:
+            - success: True if operation succeeded
+            - catalog: The catalog name
+            - schema: The schema name
+            - table: The table name
+            - full_name: Fully qualified table name (catalog.schema.table)
+            - table_type: Type of the table (MANAGED, EXTERNAL, VIEW, etc.)
+ - columns: List of column definitions (name, type, nullable, comment) + - comment: Table comment/description if available + - storage_location: Physical storage location if applicable + + Or error dict with: + - error: Error message + - help: Optional help text + + Example: + >>> describe_databricks_table("main", "default", "users") + { + "success": True, + "catalog": "main", + "schema": "default", + "table": "users", + "full_name": "main.default.users", + "table_type": "MANAGED", + "columns": [ + {"name": "id", "type": "LONG", "nullable": False, "comment": "User ID"}, + {"name": "email", "type": "STRING", "nullable": True, "comment": None} + ], + "comment": "User accounts table" + } + """ + if not catalog or not catalog.strip(): + return {"error": "catalog is required"} + if not schema or not schema.strip(): + return {"error": "schema is required"} + if not table or not table.strip(): + return {"error": "table name is required"} + + full_name = f"{catalog}.{schema}.{table}" + + try: + client = _get_client() + + # Retrieve table metadata via Unity Catalog API + table_info = client.tables.get(full_name) + + columns = [] + if table_info.columns: + for col in table_info.columns: + columns.append( + { + "name": col.name, + "type": str(col.type_name) if col.type_name else "UNKNOWN", + "nullable": col.nullable if col.nullable is not None else True, + "comment": col.comment, + } + ) + + result = { + "success": True, + "catalog": catalog, + "schema": schema, + "table": table, + "full_name": full_name, + "table_type": str(table_info.table_type) if table_info.table_type else None, + "columns": columns, + "comment": table_info.comment, + } + + if table_info.storage_location: + result["storage_location"] = table_info.storage_location + + return result + + except ImportError as e: + return { + "error": str(e), + "help": "Install the dependency by running: pip install 'databricks-sdk>=0.30.0'", + } + except Exception as e: + error_msg = str(e) + + if "TEMPORARILY_UNAVAILABLE" 
in error_msg or "401" in error_msg: + return { + "error": "Databricks authentication failed", + "help": "Check that DATABRICKS_HOST and DATABRICKS_TOKEN are set correctly.", + } + if "NOT_FOUND" in error_msg or "DOES_NOT_EXIST" in error_msg: + return { + "error": f"Table not found: {full_name}", + "help": "Check that the catalog, schema, and table names are correct. " + f"Full error: {error_msg}", + } + if "403" in error_msg or "PERMISSION_DENIED" in error_msg: + return { + "error": f"Permission denied for table: {full_name}", + "help": "Ensure your token has the 'USE CATALOG', 'USE SCHEMA', " + "and 'SELECT' privileges on the target table.", + } + + return {"error": f"Failed to describe table: {error_msg}"} + + # Register managed MCP server tools + register_mcp_tools(mcp, credentials) diff --git a/tools/tests/tools/test_databricks_tool.py b/tools/tests/tools/test_databricks_tool.py new file mode 100644 index 00000000..6e67238f --- /dev/null +++ b/tools/tests/tools/test_databricks_tool.py @@ -0,0 +1,783 @@ +""" +Tests for Databricks tools. 
+ +Tests cover: +- Custom SQL tool: read-only enforcement, row limiting, query execution, error handling +- Describe table: input validation, successful description, error handling +- Managed MCP tools: SQL, UC functions, Vector Search, Genie, tool discovery +- Credential resolution from CredentialStoreAdapter and environment +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp import FastMCP + +from aden_tools.credentials import CredentialStoreAdapter +from aden_tools.tools.databricks_tool import register_tools + + +@pytest.fixture +def mcp(): + """Create a FastMCP instance for testing.""" + return FastMCP("test-server") + + +@pytest.fixture +def mock_credentials(): + """Create mock credentials for testing.""" + return CredentialStoreAdapter.for_testing( + { + "databricks_host": "https://test-workspace.cloud.databricks.com", + "databricks_token": "dapi_test_token_12345", + "databricks_warehouse": "test-warehouse-id", + } + ) + + +@pytest.fixture +def registered_mcp(mcp, mock_credentials): + """Register Databricks tools with mock credentials.""" + register_tools(mcp, credentials=mock_credentials) + return mcp + + +# =========================================================================== +# Custom SQL Tool — Read-Only Enforcement +# =========================================================================== + + +class TestReadOnlyEnforcement: + """Tests for SQL write operation blocking in run_databricks_sql.""" + + def test_blocks_insert(self, registered_mcp): + """INSERT statements should be blocked.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = tool.fn(sql="INSERT INTO table VALUES (1, 2)") + assert "error" in result + assert "Write operations are not allowed" in result["error"] + + def test_blocks_update(self, registered_mcp): + """UPDATE statements should be blocked.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = 
tool.fn(sql="UPDATE table SET col = 1") + assert "error" in result + assert "Write operations are not allowed" in result["error"] + + def test_blocks_delete(self, registered_mcp): + """DELETE statements should be blocked.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = tool.fn(sql="DELETE FROM table WHERE id = 1") + assert "error" in result + assert "Write operations are not allowed" in result["error"] + + def test_blocks_drop(self, registered_mcp): + """DROP statements should be blocked.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = tool.fn(sql="DROP TABLE my_table") + assert "error" in result + assert "Write operations are not allowed" in result["error"] + + def test_blocks_create(self, registered_mcp): + """CREATE statements should be blocked.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = tool.fn(sql="CREATE TABLE my_table (id INT)") + assert "error" in result + assert "Write operations are not allowed" in result["error"] + + def test_blocks_alter(self, registered_mcp): + """ALTER statements should be blocked.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = tool.fn(sql="ALTER TABLE my_table ADD COLUMN new_col INT") + assert "error" in result + assert "Write operations are not allowed" in result["error"] + + def test_blocks_truncate(self, registered_mcp): + """TRUNCATE statements should be blocked.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = tool.fn(sql="TRUNCATE TABLE my_table") + assert "error" in result + assert "Write operations are not allowed" in result["error"] + + def test_blocks_merge(self, registered_mcp): + """MERGE statements should be blocked.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = tool.fn(sql="MERGE INTO target USING source ON condition WHEN MATCHED THEN UPDATE") + assert "error" in result + assert "Write operations are not allowed" in 
result["error"] + + def test_blocks_case_insensitive(self, registered_mcp): + """Write detection should be case-insensitive.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = tool.fn(sql="insert into table values (1)") + assert "error" in result + assert "Write operations are not allowed" in result["error"] + + def test_allows_select(self, registered_mcp): + """SELECT statements should be allowed (will fail on client, not validation).""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_create.side_effect = Exception("Mock error") + result = tool.fn(sql="SELECT * FROM table") + assert "Write operations are not allowed" not in result.get("error", "") + + def test_allows_select_with_subquery(self, registered_mcp): + """Complex SELECT with subqueries should be allowed.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_create.side_effect = Exception("Mock error") + result = tool.fn( + sql=""" + SELECT a.*, b.count + FROM (SELECT id, name FROM users) a + JOIN (SELECT user_id, COUNT(*) as count FROM orders GROUP BY user_id) b + ON a.id = b.user_id + """ + ) + assert "Write operations are not allowed" not in result.get("error", "") + + def test_ignores_write_keywords_in_comments(self, registered_mcp): + """Write keywords inside comments should not trigger blocking.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_create.side_effect = Exception("Mock error") + result = tool.fn(sql="SELECT * FROM t -- INSERT comment") + assert "Write operations are not allowed" not in result.get("error", "") + + +# 
=========================================================================== +# Custom SQL Tool — Row Limits +# =========================================================================== + + +class TestRowLimits: + """Tests for row limit validation in run_databricks_sql.""" + + def test_rejects_zero_max_rows(self, registered_mcp): + """max_rows of 0 should be rejected.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = tool.fn(sql="SELECT 1", max_rows=0) + assert "error" in result + assert "max_rows must be at least 1" in result["error"] + + def test_rejects_negative_max_rows(self, registered_mcp): + """Negative max_rows should be rejected.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = tool.fn(sql="SELECT 1", max_rows=-1) + assert "error" in result + assert "max_rows must be at least 1" in result["error"] + + def test_rejects_excessive_max_rows(self, registered_mcp): + """max_rows over 10000 should be rejected.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + result = tool.fn(sql="SELECT 1", max_rows=10001) + assert "error" in result + assert "max_rows cannot exceed 10000" in result["error"] + + def test_accepts_valid_max_rows(self, registered_mcp): + """Valid max_rows values should be accepted.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_create.side_effect = Exception("Mock error") + for max_rows in [1, 100, 1000, 10000]: + result = tool.fn(sql="SELECT 1", max_rows=max_rows) + assert "max_rows" not in result.get("error", "") + + +# =========================================================================== +# Custom SQL Tool — Query Execution +# =========================================================================== + + +class TestQueryExecution: + """Tests for successful query execution with mocked Databricks client.""" + + def 
test_successful_query(self, registered_mcp): + """Test successful query execution with mocked WorkspaceClient.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_client = MagicMock() + mock_create.return_value = mock_client + + # Mock the statement execution response + mock_col1 = MagicMock() + mock_col1.name = "id" + mock_col1.type_name = "INT" + mock_col2 = MagicMock() + mock_col2.name = "name" + mock_col2.type_name = "STRING" + + mock_schema = MagicMock() + mock_schema.columns = [mock_col1, mock_col2] + + mock_manifest = MagicMock() + mock_manifest.schema = mock_schema + + mock_result = MagicMock() + mock_result.data_array = [["1", "Alice"], ["2", "Bob"]] + + mock_response = MagicMock() + mock_response.status.error = None + mock_response.manifest = mock_manifest + mock_response.result = mock_result + + mock_client.statement_execution.execute_statement.return_value = mock_response + + result = tool.fn(sql="SELECT id, name FROM users") + + assert result["success"] is True + assert result["rows"] == [ + {"id": "1", "name": "Alice"}, + {"id": "2", "name": "Bob"}, + ] + assert result["total_rows"] == 2 + assert result["rows_returned"] == 2 + assert result["query_truncated"] is False + assert len(result["schema"]) == 2 + + def test_query_truncation(self, registered_mcp): + """Test that results are truncated when exceeding max_rows.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_client = MagicMock() + mock_create.return_value = mock_client + + mock_col = MagicMock() + mock_col.name = "id" + mock_col.type_name = "INT" + + mock_schema = MagicMock() + mock_schema.columns = [mock_col] + + mock_manifest = MagicMock() + mock_manifest.schema = mock_schema + + # Create 10 rows of data + mock_result 
= MagicMock() + mock_result.data_array = [[str(i)] for i in range(10)] + + mock_response = MagicMock() + mock_response.status.error = None + mock_response.manifest = mock_manifest + mock_response.result = mock_result + + mock_client.statement_execution.execute_statement.return_value = mock_response + + # Request only 5 rows + result = tool.fn(sql="SELECT id FROM users", max_rows=5) + + assert result["success"] is True + assert result["total_rows"] == 10 + assert result["rows_returned"] == 5 + assert result["query_truncated"] is True + assert len(result["rows"]) == 5 + + def test_missing_warehouse_id(self, mcp): + """Test error when no warehouse ID is configured.""" + creds = CredentialStoreAdapter.for_testing( + { + "databricks_host": "https://test.cloud.databricks.com", + "databricks_token": "dapi_test", + } + ) + register_tools(mcp, credentials=creds) + + tool = mcp._tool_manager._tools["run_databricks_sql"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_create.return_value = MagicMock() + result = tool.fn(sql="SELECT 1") + assert "error" in result + assert "No SQL Warehouse ID provided" in result["error"] + + +# =========================================================================== +# Describe Table +# =========================================================================== + + +class TestDescribeTable: + """Tests for describe_databricks_table tool.""" + + def test_empty_catalog(self, registered_mcp): + """Empty catalog should be rejected.""" + tool = registered_mcp._tool_manager._tools["describe_databricks_table"] + result = tool.fn(catalog="", schema="default", table="users") + assert "error" in result + assert "catalog is required" in result["error"] + + def test_empty_schema(self, registered_mcp): + """Empty schema should be rejected.""" + tool = registered_mcp._tool_manager._tools["describe_databricks_table"] + result = tool.fn(catalog="main", schema="", table="users") + 
assert "error" in result + assert "schema is required" in result["error"] + + def test_empty_table(self, registered_mcp): + """Empty table name should be rejected.""" + tool = registered_mcp._tool_manager._tools["describe_databricks_table"] + result = tool.fn(catalog="main", schema="default", table="") + assert "error" in result + assert "table name is required" in result["error"] + + def test_successful_describe(self, registered_mcp): + """Test successful table description with mocked client.""" + tool = registered_mcp._tool_manager._tools["describe_databricks_table"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_client = MagicMock() + mock_create.return_value = mock_client + + # Mock table info + mock_col1 = MagicMock() + mock_col1.name = "id" + mock_col1.type_name = "LONG" + mock_col1.nullable = False + mock_col1.comment = "User ID" + + mock_col2 = MagicMock() + mock_col2.name = "email" + mock_col2.type_name = "STRING" + mock_col2.nullable = True + mock_col2.comment = None + + mock_table_info = MagicMock() + mock_table_info.columns = [mock_col1, mock_col2] + mock_table_info.table_type = "MANAGED" + mock_table_info.comment = "User accounts table" + mock_table_info.storage_location = "s3://bucket/path" + + mock_client.tables.get.return_value = mock_table_info + + result = tool.fn(catalog="main", schema="default", table="users") + + assert result["success"] is True + assert result["catalog"] == "main" + assert result["schema"] == "default" + assert result["table"] == "users" + assert result["full_name"] == "main.default.users" + assert result["table_type"] == "MANAGED" + assert result["comment"] == "User accounts table" + assert result["storage_location"] == "s3://bucket/path" + assert len(result["columns"]) == 2 + assert result["columns"][0]["name"] == "id" + assert result["columns"][0]["nullable"] is False + assert result["columns"][1]["name"] == "email" + assert 
result["columns"][1]["nullable"] is True + + +# =========================================================================== +# Managed MCP Tools +# =========================================================================== + + +class TestMCPQuerySQL: + """Tests for databricks_mcp_query_sql tool.""" + + def test_empty_sql(self, registered_mcp): + """Empty SQL should be rejected.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_query_sql"] + result = tool.fn(sql="") + assert "error" in result + assert "sql is required" in result["error"] + + def test_successful_mcp_sql(self, registered_mcp): + """Test successful MCP SQL query with mocked client.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_query_sql"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client" + ) as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + mock_content = MagicMock() + mock_content.text = "Query results: 42 rows returned" + mock_response = MagicMock() + mock_response.content = [mock_content] + mock_client.call_tool.return_value = mock_response + + result = tool.fn(sql="SELECT * FROM main.default.users") + + assert result["success"] is True + assert "Query results" in result["result"] + + +class TestMCPUCFunction: + """Tests for databricks_mcp_query_uc_function tool.""" + + def test_empty_catalog(self, registered_mcp): + """Empty catalog should be rejected.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_query_uc_function"] + result = tool.fn(catalog="", schema="default", function_name="my_func") + assert "error" in result + assert "catalog is required" in result["error"] + + def test_empty_function_name(self, registered_mcp): + """Empty function name should be rejected.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_query_uc_function"] + result = tool.fn(catalog="main", schema="default", function_name="") + assert "error" in result + assert 
"function_name is required" in result["error"] + + def test_successful_uc_function(self, registered_mcp): + """Test successful UC function call with mocked client.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_query_uc_function"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client" + ) as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + mock_content = MagicMock() + mock_content.text = "Revenue: $1.2M" + mock_response = MagicMock() + mock_response.content = [mock_content] + mock_client.call_tool.return_value = mock_response + + result = tool.fn( + catalog="main", + schema="analytics", + function_name="get_revenue", + arguments={"year": "2024"}, + ) + + assert result["success"] is True + assert "Revenue" in result["result"] + + +class TestMCPVectorSearch: + """Tests for databricks_mcp_vector_search tool.""" + + def test_empty_query(self, registered_mcp): + """Empty query should be rejected.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_vector_search"] + result = tool.fn(catalog="prod", schema="kb", index_name="idx", query="") + assert "error" in result + assert "query is required" in result["error"] + + def test_empty_index_name(self, registered_mcp): + """Empty index name should be rejected.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_vector_search"] + result = tool.fn(catalog="prod", schema="kb", index_name="", query="test") + assert "error" in result + assert "index_name is required" in result["error"] + + def test_successful_vector_search(self, registered_mcp): + """Test successful vector search with mocked client.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_vector_search"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client" + ) as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + mock_tool = MagicMock() + mock_tool.name = 
"search_index" + mock_client.list_tools.return_value = [mock_tool] + + mock_content = MagicMock() + mock_content.text = "Found 3 relevant documents" + mock_response = MagicMock() + mock_response.content = [mock_content] + mock_client.call_tool.return_value = mock_response + + result = tool.fn( + catalog="prod", + schema="kb", + index_name="docs_index", + query="How to configure auth?", + num_results=5, + ) + + assert result["success"] is True + assert "relevant documents" in result["result"] + + +class TestMCPGenie: + """Tests for databricks_mcp_query_genie tool.""" + + def test_empty_genie_space_id(self, registered_mcp): + """Empty genie space ID should be rejected.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_query_genie"] + result = tool.fn(genie_space_id="", question="What is revenue?") + assert "error" in result + assert "genie_space_id is required" in result["error"] + + def test_empty_question(self, registered_mcp): + """Empty question should be rejected.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_query_genie"] + result = tool.fn(genie_space_id="abc123", question="") + assert "error" in result + assert "question is required" in result["error"] + + def test_successful_genie_query(self, registered_mcp): + """Test successful Genie query with mocked client.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_query_genie"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client" + ) as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + mock_tool = MagicMock() + mock_tool.name = "ask_genie" + mock_client.list_tools.return_value = [mock_tool] + + mock_content = MagicMock() + mock_content.text = "Total revenue last quarter was $1.2M" + mock_response = MagicMock() + mock_response.content = [mock_content] + mock_client.call_tool.return_value = mock_response + + result = tool.fn( + genie_space_id="abc123", + question="What was revenue last 
quarter?", + ) + + assert result["success"] is True + assert "$1.2M" in result["result"] + + +class TestMCPListTools: + """Tests for databricks_mcp_list_tools tool.""" + + def test_missing_server_url_and_type(self, registered_mcp): + """Should return error when neither server_url nor server_type is given.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_list_tools"] + result = tool.fn() + assert "error" in result + assert "server_url or server_type is required" in result["error"] + + def test_invalid_server_type(self, registered_mcp): + """Should return error for invalid server type.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_list_tools"] + result = tool.fn(server_type="invalid") + assert "error" in result + assert "Invalid server_type" in result["error"] + + def test_successful_list_tools(self, registered_mcp): + """Test successful tool listing with mocked client.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_list_tools"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client" + ) as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + mock_tool1 = MagicMock() + mock_tool1.name = "execute_sql" + mock_tool1.description = "Execute SQL queries" + mock_tool1.inputSchema = {"type": "object", "properties": {}} + + mock_tool2 = MagicMock() + mock_tool2.name = "list_tables" + mock_tool2.description = "List available tables" + mock_tool2.inputSchema = None + + mock_client.list_tools.return_value = [mock_tool1, mock_tool2] + + result = tool.fn(server_type="sql") + + assert result["success"] is True + assert len(result["tools"]) == 2 + assert result["tools"][0]["name"] == "execute_sql" + assert result["tools"][1]["name"] == "list_tables" + assert "https://test-workspace.cloud.databricks.com" in result["server_url"] + + def test_with_direct_server_url(self, registered_mcp): + """Test tool listing with a direct server URL.""" + tool = 
registered_mcp._tool_manager._tools["databricks_mcp_list_tools"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client" + ) as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.list_tools.return_value = [] + + url = "https://my-workspace.cloud.databricks.com/api/2.0/mcp/sql" + result = tool.fn(server_url=url) + + assert result["success"] is True + assert result["server_url"] == url + + +# =========================================================================== +# Error Handling +# =========================================================================== + + +class TestErrorHandling: + """Tests for error handling and user-friendly messages.""" + + def test_authentication_error(self, registered_mcp): + """Authentication errors should provide helpful messages.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_create.side_effect = Exception("401 Unauthorized") + result = tool.fn(sql="SELECT 1") + + assert "error" in result + assert "authentication failed" in result["error"].lower() + assert "help" in result + + def test_permission_error(self, registered_mcp): + """Permission errors should provide helpful messages.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_create.side_effect = Exception("PERMISSION_DENIED: access denied") + result = tool.fn(sql="SELECT 1") + + assert "error" in result + assert "permission denied" in result["error"].lower() + assert "help" in result + + def test_not_found_error(self, registered_mcp): + """Not found errors should provide helpful messages.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + + with patch( + 
"aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_create.side_effect = Exception("NOT_FOUND: warehouse xyz not found") + result = tool.fn(sql="SELECT 1") + + assert "error" in result + assert "not found" in result["error"].lower() + assert "help" in result + + def test_table_not_found_error(self, registered_mcp): + """Table not found errors should provide helpful messages.""" + tool = registered_mcp._tool_manager._tools["describe_databricks_table"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + mock_create.side_effect = Exception("DOES_NOT_EXIST: table not found") + result = tool.fn(catalog="main", schema="default", table="nonexistent") + + assert "error" in result + assert "not found" in result["error"].lower() + + def test_mcp_import_error(self, registered_mcp): + """Import errors for managed MCP tools should be helpful.""" + tool = registered_mcp._tool_manager._tools["databricks_mcp_query_sql"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_mcp_tool._get_mcp_client" + ) as mock_get: + mock_get.side_effect = ImportError("databricks-mcp and databricks-sdk are required") + result = tool.fn(sql="SELECT 1") + + assert "error" in result + assert "databricks" in result["error"].lower() + + +# =========================================================================== +# Credential Resolution +# =========================================================================== + + +class TestCredentialResolution: + """Tests for credential resolution from different sources.""" + + def test_uses_credential_store(self, mcp): + """Should use credentials from CredentialStoreAdapter.""" + mock_creds = CredentialStoreAdapter.for_testing( + { + "databricks_host": "https://custom.cloud.databricks.com", + "databricks_token": "dapi_custom_token", + "databricks_warehouse": "custom-warehouse", + } + ) + register_tools(mcp, credentials=mock_creds) + + 
assert mock_creds.get("databricks_host") == "https://custom.cloud.databricks.com" + assert mock_creds.get("databricks_token") == "dapi_custom_token" + assert mock_creds.get("databricks_warehouse") == "custom-warehouse" + + def test_falls_back_to_env_vars(self, mcp): + """Should fall back to environment variables when no credential store.""" + register_tools(mcp, credentials=None) + + # Tools are registered and will use os.getenv internally + assert "run_databricks_sql" in mcp._tool_manager._tools + assert "describe_databricks_table" in mcp._tool_manager._tools + assert "databricks_mcp_query_sql" in mcp._tool_manager._tools + assert "databricks_mcp_query_uc_function" in mcp._tool_manager._tools + assert "databricks_mcp_vector_search" in mcp._tool_manager._tools + assert "databricks_mcp_query_genie" in mcp._tool_manager._tools + assert "databricks_mcp_list_tools" in mcp._tool_manager._tools + + def test_all_tools_registered(self, registered_mcp): + """All 7 Databricks tools should be registered.""" + expected_tools = [ + "run_databricks_sql", + "describe_databricks_table", + "databricks_mcp_query_sql", + "databricks_mcp_query_uc_function", + "databricks_mcp_vector_search", + "databricks_mcp_query_genie", + "databricks_mcp_list_tools", + ] + for tool_name in expected_tools: + assert tool_name in registered_mcp._tool_manager._tools, ( + f"Tool '{tool_name}' not registered" + ) + + +# =========================================================================== +# Import Error Handling +# =========================================================================== + + +class TestImportError: + """Tests for handling missing databricks-sdk package.""" + + def test_sdk_import_error_message(self, registered_mcp): + """Should provide helpful message when databricks-sdk not installed.""" + tool = registered_mcp._tool_manager._tools["run_databricks_sql"] + + with patch( + "aden_tools.tools.databricks_tool.databricks_tool._create_workspace_client" + ) as mock_create: + 
mock_create.side_effect = ImportError( + "databricks-sdk is required for Databricks tools. " + "Install it with: pip install 'databricks-sdk>=0.30.0'" + ) + result = tool.fn(sql="SELECT 1") + + assert "error" in result + assert "databricks-sdk" in result["error"] + assert "pip install" in result["error"] diff --git a/uv.lock b/uv.lock index bf58df0e..d066c77e 100644 --- a/uv.lock +++ b/uv.lock @@ -536,6 +536,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/7c/996760c30f1302704af57c66ff2d723f7d656d0d0b93563b5528a51484bb/cyclopts-4.5.1-py3-none-any.whl", hash = "sha256:0642c93601e554ca6b7b9abd81093847ea4448b2616280f2a0952416574e8c7a", size = 199807, upload-time = "2026-01-25T15:23:55.219Z" }, ] +[[package]] +name = "databricks-mcp" +version = "0.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "databricks-sdk" }, + { name = "mcp" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/10/9f12e8d3322327a8edac8b906a0f1cece3dfbc58efdd2198538b4ba56957/databricks_mcp-0.1.0-py3-none-any.whl", hash = "sha256:d02ed6cdd24a862a365b4fae2892223d8209d14e5be500ebc14d086f646408cc", size = 2307, upload-time = "2025-06-05T20:39:11.724Z" }, +] + +[[package]] +name = "databricks-sdk" +version = "0.94.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/39/18b3a4a69851f583d7970ae71f19975e06ef892ea17393bcafb773088f96/databricks_sdk-0.94.0.tar.gz", hash = "sha256:1b29a9d63e2dc9808ed79253ba3668cd6d130dabb880a576d647fa3b1b1a7b30", size = 861465, upload-time = "2026-02-26T08:17:12.063Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/da/072fe4df10aec583eb72327856b4743c7f5fb50a16772d5f86fcc20e2457/databricks_sdk-0.94.0-py3-none-any.whl", hash = "sha256:ae980186ee2d99f3639d14a94457776dca22980a04f65f89497a79889b2b9149", size = 811079, upload-time 
= "2026-02-26T08:17:10.242Z" }, +] + [[package]] name = "diff-match-patch" version = "20241021" @@ -3490,6 +3516,8 @@ dependencies = [ [package.optional-dependencies] all = [ + { name = "databricks-mcp" }, + { name = "databricks-sdk" }, { name = "duckdb" }, { name = "google-cloud-bigquery" }, { name = "openpyxl" }, @@ -3500,6 +3528,10 @@ all = [ bigquery = [ { name = "google-cloud-bigquery" }, ] +databricks = [ + { name = "databricks-mcp" }, + { name = "databricks-sdk" }, +] dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, @@ -3529,6 +3561,10 @@ dev = [ requires-dist = [ { name = "arxiv", specifier = ">=2.1.0" }, { name = "beautifulsoup4", specifier = ">=4.12.0" }, + { name = "databricks-mcp", marker = "extra == 'all'", specifier = ">=0.1.0" }, + { name = "databricks-mcp", marker = "extra == 'databricks'", specifier = ">=0.1.0" }, + { name = "databricks-sdk", marker = "extra == 'all'", specifier = ">=0.30.0" }, + { name = "databricks-sdk", marker = "extra == 'databricks'", specifier = ">=0.30.0" }, { name = "diff-match-patch", specifier = ">=20230430" }, { name = "dnspython", specifier = ">=2.4.0" }, { name = "duckdb", marker = "extra == 'all'", specifier = ">=1.0.0" }, @@ -3561,7 +3597,7 @@ requires-dist = [ { name = "restrictedpython", marker = "extra == 'sandbox'", specifier = ">=7.0" }, { name = "stripe", specifier = ">=14.3.0" }, ] -provides-extras = ["dev", "sandbox", "ocr", "excel", "sql", "bigquery", "all"] +provides-extras = ["dev", "sandbox", "ocr", "excel", "sql", "bigquery", "databricks", "all"] [package.metadata.requires-dev] dev = [