feat(tools): add Weights & Biases ML experiment tracking integration (#6963)
* feat(tools): add Weights & Biases experiment tracking and model monitoring integration
* style: fix ruff formatting in wandb_tool.py
* feat(tools): add Weights & Biases ML experiment tracking integration
* fix(tools): address CodeRabbit review comments on wandb_tool
* fix(tools): rewrite wandb_tool to use official Python SDK instead of undocumented REST endpoints
* fix(tools): address Hundao review — remove .coverage, switch to GraphQL/httpx, fix wandb_host, add README
* fix(tools): wire filters to GraphQL, validate empty metric_keys, fix line lengths
* fix(tools): check credentials before input validation in wandb_get_run_metrics
Move _get_creds() call before run_id/metric_keys checks so the
framework credential test receives the expected {error, help} response
instead of a bare input-validation error.
This commit is contained in:
@@ -79,3 +79,4 @@ core/tests/*dumps/*
|
||||
screenshots/*
|
||||
|
||||
.gemini/*
|
||||
.coverage
|
||||
|
||||
@@ -139,6 +139,7 @@ from .trello import TRELLO_CREDENTIALS
|
||||
from .twilio import TWILIO_CREDENTIALS
|
||||
from .twitter import TWITTER_CREDENTIALS
|
||||
from .vercel import VERCEL_CREDENTIALS
|
||||
from .wandb import WANDB_CREDENTIALS
|
||||
from .youtube import YOUTUBE_CREDENTIALS
|
||||
from .zendesk import ZENDESK_CREDENTIALS
|
||||
from .zoho_crm import ZOHO_CRM_CREDENTIALS
|
||||
@@ -219,6 +220,7 @@ CREDENTIAL_SPECS = {
|
||||
**TWITTER_CREDENTIALS,
|
||||
**VERCEL_CREDENTIALS,
|
||||
**YOUTUBE_CREDENTIALS,
|
||||
**WANDB_CREDENTIALS,
|
||||
**ZENDESK_CREDENTIALS,
|
||||
**ZOHO_CRM_CREDENTIALS,
|
||||
**ZOOM_CREDENTIALS,
|
||||
@@ -313,6 +315,7 @@ __all__ = [
|
||||
"TWILIO_CREDENTIALS",
|
||||
"TWITTER_CREDENTIALS",
|
||||
"VERCEL_CREDENTIALS",
|
||||
"WANDB_CREDENTIALS",
|
||||
"YOUTUBE_CREDENTIALS",
|
||||
"ZENDESK_CREDENTIALS",
|
||||
"ZOHO_CRM_CREDENTIALS",
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Weights & Biases integration credentials.
|
||||
|
||||
Contains credentials for the W&B GraphQL API.
|
||||
Requires WANDB_API_KEY only — no host configuration needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import CredentialSpec
|
||||
|
||||
WANDB_CREDENTIALS = {
|
||||
"wandb_api_key": CredentialSpec(
|
||||
env_var="WANDB_API_KEY",
|
||||
tools=[
|
||||
"wandb_list_projects",
|
||||
"wandb_list_runs",
|
||||
"wandb_get_run",
|
||||
"wandb_get_run_metrics",
|
||||
"wandb_list_artifacts",
|
||||
"wandb_get_summary",
|
||||
],
|
||||
required=True,
|
||||
startup_required=False,
|
||||
help_url="https://wandb.ai/authorize",
|
||||
description="Weights & Biases API Key",
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""To set up W&B API access:
|
||||
1. Create a W&B account at https://wandb.ai
|
||||
2. Go to https://wandb.ai/authorize
|
||||
3. Copy your API key
|
||||
4. Set environment variable:
|
||||
export WANDB_API_KEY=your-api-key""",
|
||||
health_check_endpoint="",
|
||||
credential_id="wandb_api_key",
|
||||
credential_key="api_key",
|
||||
),
|
||||
}
|
||||
@@ -133,6 +133,7 @@ from .twilio_tool import register_tools as register_twilio
|
||||
from .twitter_tool import register_tools as register_twitter
|
||||
from .vercel_tool import register_tools as register_vercel
|
||||
from .vision_tool import register_tools as register_vision
|
||||
from .wandb_tool import register_tools as register_wandb
|
||||
|
||||
try:
|
||||
from .web_scrape_tool import register_tools as register_web_scrape
|
||||
@@ -306,6 +307,7 @@ def _register_unverified(
|
||||
register_zendesk(mcp, credentials=credentials)
|
||||
register_zoho_crm(mcp, credentials=credentials)
|
||||
register_zoom(mcp, credentials=credentials)
|
||||
register_wandb(mcp, credentials=credentials)
|
||||
register_freshdesk(mcp, credentials=credentials)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
# Weights & Biases Tool
|
||||
|
||||
Query ML experiment runs, metrics, and artifacts from Weights & Biases using the W&B GraphQL API.
|
||||
|
||||
## Tools
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `wandb_list_projects` | List all projects for a W&B entity (user or organization) |
|
||||
| `wandb_list_runs` | List runs in a project with optional filters |
|
||||
| `wandb_get_run` | Get full details of a specific run (config, state, summary) |
|
||||
| `wandb_get_run_metrics` | Get sampled metric history for a run |
|
||||
| `wandb_list_artifacts` | List output artifacts logged by a run |
|
||||
| `wandb_get_summary` | Get final summary metrics for a run |
|
||||
|
||||
## Setup
|
||||
|
||||
Requires a W&B account and API key.
|
||||
|
||||
1. Create a W&B account at [wandb.ai](https://wandb.ai)
|
||||
2. Get your API key at [wandb.ai/authorize](https://wandb.ai/authorize)
|
||||
3. Set the environment variable:
|
||||
|
||||
```bash
|
||||
export WANDB_API_KEY=your-api-key
|
||||
```
|
||||
|
||||
Or configure via the Aden credential store as `wandb_api_key`.
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### List projects for an entity
|
||||
|
||||
```python
|
||||
wandb_list_projects(entity="my-team")
|
||||
```
|
||||
|
||||
### List recent runs in a project
|
||||
|
||||
```python
|
||||
wandb_list_runs(entity="my-team", project="my-project", per_page=10)
|
||||
```
|
||||
|
||||
### Filter runs by state
|
||||
|
||||
```python
|
||||
wandb_list_runs(
|
||||
entity="my-team",
|
||||
project="my-project",
|
||||
filters='{"state": "finished"}',
|
||||
)
|
||||
```
|
||||
|
||||
### Get details of a specific run
|
||||
|
||||
```python
|
||||
wandb_get_run(entity="my-team", project="my-project", run_id="abc123")
|
||||
```
|
||||
|
||||
### Get training metrics for a run
|
||||
|
||||
```python
|
||||
wandb_get_run_metrics(
|
||||
entity="my-team",
|
||||
project="my-project",
|
||||
run_id="abc123",
|
||||
metric_keys="loss,accuracy",
|
||||
)
|
||||
```
|
||||
|
||||
### Get final summary metrics
|
||||
|
||||
```python
|
||||
wandb_get_summary(entity="my-team", project="my-project", run_id="abc123")
|
||||
```
|
||||
|
||||
### List artifacts produced by a run
|
||||
|
||||
```python
|
||||
wandb_list_artifacts(entity="my-team", project="my-project", run_id="abc123")
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
All tools return error dicts on failure:
|
||||
|
||||
```python
|
||||
{"error": "Weights & Biases credentials not configured", "help": "Set WANDB_API_KEY..."}
|
||||
{"error": "Invalid Weights & Biases API key"}
|
||||
{"error": "Weights & Biases resource not found"}
|
||||
{"error": "Request timed out"}
|
||||
{"error": "filters must be a valid JSON string"}
|
||||
{"error": "metric_keys is required (comma-separated, e.g. 'loss,accuracy')"}
|
||||
```
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Weights & Biases experiment tracking tool for Aden Tools."""
|
||||
|
||||
from .wandb_tool import register_tools
|
||||
|
||||
__all__ = ["register_tools"]
|
||||
@@ -0,0 +1,440 @@
|
||||
"""
|
||||
Weights & Biases ML experiment tracking tool.
|
||||
|
||||
Uses the W&B GraphQL API via httpx — no SDK dependency.
|
||||
|
||||
Authentication: Bearer token (WANDB_API_KEY)
|
||||
GraphQL endpoint: https://api.wandb.ai/graphql
|
||||
|
||||
API Reference: https://github.com/wandb/wandb/blob/main/wandb/proto/wandb_internal.proto
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
GRAPHQL_URL = "https://api.wandb.ai/graphql"
|
||||
|
||||
|
||||
def _get_creds(
    credentials: CredentialStoreAdapter | None,
) -> tuple[str] | dict[str, Any]:
    """Resolve the W&B API key.

    Looks in the credential store when one is provided; otherwise falls
    back to the ``WANDB_API_KEY`` environment variable.

    Returns:
        A one-element tuple ``(api_key,)`` on success, or an error dict
        with ``error`` and ``help`` keys when no key is configured.
    """
    api_key = (
        credentials.get("wandb_api_key")
        if credentials is not None
        else os.getenv("WANDB_API_KEY")
    )

    if api_key:
        return (api_key,)

    return {
        "error": "Weights & Biases credentials not configured",
        "help": (
            "Set WANDB_API_KEY environment variable or configure via credential store. "
            "Get your API key at https://wandb.ai/authorize"
        ),
    }
|
||||
|
||||
|
||||
def _graphql(
    api_key: str,
    query: str,
    variables: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Execute a GraphQL query against the W&B API.

    Args:
        api_key: W&B API key, sent as a Bearer token.
        query: GraphQL query document.
        variables: Optional GraphQL variables (defaults to ``{}``).

    Returns:
        The ``data`` portion of the GraphQL response on success, or a
        dict with a single ``error`` key describing the failure. This
        function never raises; every failure mode maps to an error dict.
    """
    try:
        resp = httpx.post(
            GRAPHQL_URL,
            headers={"Authorization": f"Bearer {api_key}"},
            json={"query": query, "variables": variables or {}},
            timeout=30.0,
        )
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except httpx.RequestError as e:
        return {"error": f"Network error: {e}"}

    if resp.status_code == 401:
        return {"error": "Invalid Weights & Biases API key"}
    if resp.status_code == 403:
        return {"error": "Insufficient permissions for this Weights & Biases resource"}
    if resp.status_code >= 400:
        # Best-effort extraction of the GraphQL error message; fall back to
        # the raw body when the error payload isn't the expected shape.
        try:
            detail = resp.json().get("errors", [{}])[0].get("message", resp.text)
        except Exception:
            detail = resp.text
        return {"error": f"Weights & Biases API error (HTTP {resp.status_code}): {detail}"}

    # A 2xx response should always carry a JSON body, but guard against
    # proxies/outages returning HTML or an empty body so callers get the
    # error-dict contract instead of an unhandled exception.
    try:
        payload = resp.json()
    except ValueError:
        return {"error": "Weights & Biases API returned a non-JSON response"}

    # GraphQL reports application errors in-band with HTTP 200.
    if "errors" in payload:
        msg = payload["errors"][0].get("message", str(payload["errors"]))
        return {"error": f"Weights & Biases GraphQL error: {msg}"}

    return payload.get("data", {})
|
||||
|
||||
|
||||
def register_tools(
    mcp: FastMCP,
    credentials: CredentialStoreAdapter | None = None,
) -> None:
    """Register Weights & Biases experiment tracking tools with the MCP server.

    Args:
        mcp: The FastMCP server instance to attach the ``wandb_*`` tools to.
        credentials: Optional credential store adapter; when omitted the
            WANDB_API_KEY environment variable is used instead.
    """

    def _parse_json_field(raw: Any) -> dict[str, Any]:
        """Leniently decode a JSON-string field (``config``/``summaryMetrics``).

        W&B returns run config and summary metrics as JSON-encoded strings;
        a null or malformed value degrades to ``{}`` rather than erroring.
        """
        try:
            return json.loads(raw or "{}")
        except (json.JSONDecodeError, TypeError):
            return {}

    @mcp.tool()
    def wandb_list_projects(entity: str) -> dict:
        """
        List all projects for a Weights & Biases entity (user or organization).

        Args:
            entity: The W&B entity name (username or organization).

        Returns:
            Dict containing the list of projects for the entity.
        """
        creds = _get_creds(credentials)
        if isinstance(creds, dict):  # error dict from _get_creds
            return creds
        (api_key,) = creds

        query = """
        query ListProjects($entity: String!) {
          projects(entityName: $entity) {
            edges {
              node {
                name
                description
                createdAt
              }
            }
          }
        }
        """
        data = _graphql(api_key, query, {"entity": entity})
        if "error" in data:
            return data

        edges = data.get("projects", {}).get("edges", [])
        return {
            "entity": entity,
            "projects": [
                {
                    "name": e["node"]["name"],
                    "description": e["node"].get("description", ""),
                    "created_at": e["node"].get("createdAt", ""),
                }
                for e in edges
            ],
        }

    @mcp.tool()
    def wandb_list_runs(
        entity: str,
        project: str,
        filters: str = "",
        per_page: int = 50,
    ) -> dict:
        """
        List runs in a Weights & Biases project.

        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            filters: Optional JSON filter string to narrow results.
            per_page: Number of runs to return (default 50).

        Returns:
            Dict containing the list of runs in the project.
        """
        # Validate filters before any network/credential work so a malformed
        # filter fails fast with a clear message.
        parsed_filters = None
        if filters:
            try:
                parsed_filters = json.loads(filters)
            except json.JSONDecodeError:
                return {"error": "filters must be a valid JSON string"}

        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds

        query = """
        query ListRuns($project: String!, $entity: String!, $perPage: Int!, $filters: JSONString) {
          project(name: $project, entityName: $entity) {
            runs(first: $perPage, filters: $filters) {
              edges {
                node {
                  name
                  id
                  state
                  createdAt
                  config
                  summaryMetrics
                }
              }
            }
          }
        }
        """
        variables: dict[str, Any] = {"project": project, "entity": entity, "perPage": per_page}
        if parsed_filters is not None:
            variables["filters"] = parsed_filters
        data = _graphql(api_key, query, variables)
        if "error" in data:
            return data

        edges = data.get("project", {}).get("runs", {}).get("edges", [])
        runs = [
            {
                # In W&B's GraphQL schema "name" is the stable run ID while
                # "id" carries the human-readable display name.
                "id": node.get("name"),
                "display_name": node.get("id"),
                "state": node.get("state"),
                "created_at": node.get("createdAt"),
                "config": _parse_json_field(node.get("config")),
            }
            for node in (e["node"] for e in edges)
        ]
        return {"entity": entity, "project": project, "runs": runs}

    @mcp.tool()
    def wandb_get_run(entity: str, project: str, run_id: str) -> dict:
        """
        Get details of a specific Weights & Biases run.

        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.

        Returns:
            Dict containing full run details including config and metadata.
        """
        if not run_id:
            return {"error": "run_id is required"}

        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds

        query = """
        query GetRun($project: String!, $entity: String!, $run: String!) {
          project(name: $project, entityName: $entity) {
            run(name: $run) {
              name
              id
              state
              createdAt
              config
              summaryMetrics
              tags
              notes
            }
          }
        }
        """
        data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
        if "error" in data:
            return data

        node = data.get("project", {}).get("run")
        if not node:
            return {"error": "Weights & Biases resource not found"}

        return {
            "id": node.get("name"),
            "display_name": node.get("id"),
            "state": node.get("state"),
            "created_at": node.get("createdAt"),
            "config": _parse_json_field(node.get("config")),
            "summary": _parse_json_field(node.get("summaryMetrics")),
            "tags": node.get("tags") or [],
            "notes": node.get("notes") or "",
        }

    @mcp.tool()
    def wandb_get_run_metrics(
        entity: str,
        project: str,
        run_id: str,
        metric_keys: str = "",
    ) -> dict:
        """
        Get sampled metrics history for a specific Weights & Biases run.

        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.
            metric_keys: Comma-separated metric keys to sample (e.g. "loss,accuracy").
                At least one key is required.

        Returns:
            Dict containing sampled metric history per key.
        """
        # Credentials are deliberately checked before input validation so the
        # framework credential test receives the {error, help} response rather
        # than a bare validation error.
        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds

        if not run_id:
            return {"error": "run_id is required"}
        if not metric_keys:
            return {"error": "metric_keys is required (comma-separated, e.g. 'loss,accuracy')"}

        keys = [k.strip() for k in metric_keys.split(",") if k.strip()]
        if not keys:
            return {"error": "metric_keys must include at least one non-empty key"}

        # Pass specs as a GraphQL variable ([JSONString!]!: one JSON-encoded
        # spec per key) instead of interpolating user-derived text into the
        # query document — interpolation was both a GraphQL-injection vector
        # and invalid GraphQL object-literal syntax (quoted field names).
        specs = [json.dumps({"key": k}) for k in keys]

        query = """
        query GetRunMetrics($project: String!, $entity: String!, $run: String!, $specs: [JSONString!]!) {
          project(name: $project, entityName: $entity) {
            run(name: $run) {
              sampledHistory(specs: $specs)
            }
          }
        }
        """
        data = _graphql(
            api_key,
            query,
            {"project": project, "entity": entity, "run": run_id, "specs": specs},
        )
        if "error" in data:
            return data

        node = data.get("project", {}).get("run")
        if not node:
            return {"error": "Weights & Biases resource not found"}

        return {
            "run_id": run_id,
            "metric_keys": keys,
            "history": node.get("sampledHistory", []),
        }

    @mcp.tool()
    def wandb_list_artifacts(entity: str, project: str, run_id: str) -> dict:
        """
        List artifacts logged by a specific Weights & Biases run.

        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.

        Returns:
            Dict containing the list of output artifacts for the run.
        """
        if not run_id:
            return {"error": "run_id is required"}

        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds

        query = """
        query ListArtifacts($project: String!, $entity: String!, $run: String!) {
          project(name: $project, entityName: $entity) {
            run(name: $run) {
              outputArtifacts {
                edges {
                  node {
                    name
                    type
                    description
                    createdAt
                  }
                }
              }
            }
          }
        }
        """
        data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
        if "error" in data:
            return data

        node = data.get("project", {}).get("run")
        if not node:
            return {"error": "Weights & Biases resource not found"}

        edges = node.get("outputArtifacts", {}).get("edges", [])
        return {
            "run_id": run_id,
            "artifacts": [
                {
                    "name": e["node"]["name"],
                    "type": e["node"]["type"],
                    "description": e["node"].get("description", ""),
                    "created_at": e["node"].get("createdAt", ""),
                }
                for e in edges
            ],
        }

    @mcp.tool()
    def wandb_get_summary(entity: str, project: str, run_id: str) -> dict:
        """
        Get summary metrics for a specific Weights & Biases run.

        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.

        Returns:
            Dict containing the run's final summary metrics.
        """
        if not run_id:
            return {"error": "run_id is required"}

        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds

        query = """
        query GetSummary($project: String!, $entity: String!, $run: String!) {
          project(name: $project, entityName: $entity) {
            run(name: $run) {
              summaryMetrics
            }
          }
        }
        """
        data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
        if "error" in data:
            return data

        node = data.get("project", {}).get("run")
        if not node:
            return {"error": "Weights & Biases resource not found"}

        summary = _parse_json_field(node.get("summaryMetrics"))

        # Filter out internal W&B keys (e.g. _step, _runtime).
        summary = {k: v for k, v in summary.items() if not k.startswith("_")}
        return {"run_id": run_id, "summary": summary}
|
||||
@@ -0,0 +1,317 @@
|
||||
"""Tests for wandb_tool - Weights & Biases integration (GraphQL/httpx)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from aden_tools.tools.wandb_tool.wandb_tool import register_tools
|
||||
|
||||
ENV = {"WANDB_API_KEY": "test-key-abcdefghij"}
|
||||
_PATCH_POST = "aden_tools.tools.wandb_tool.wandb_tool.httpx.post"
|
||||
|
||||
|
||||
def _mock_resp(data: Any, status_code: int = 200) -> MagicMock:
    """Build a fake httpx.Response exposing only what the tool code reads.

    Configures ``status_code``, ``json()`` and ``text`` — the three members
    the wandb tool touches on a response object.
    """
    fake = MagicMock(status_code=status_code, text=str(data))
    fake.json.return_value = data
    return fake


def _gql_ok(data: dict[str, Any]) -> MagicMock:
    """Wrap *data* in the GraphQL success envelope: ``{"data": {...}}``."""
    return _mock_resp({"data": data})
|
||||
|
||||
|
||||
@pytest.fixture
def tool_fns(mcp: FastMCP) -> dict[str, Any]:
    """Register the wandb tools on the shared ``mcp`` fixture and return the
    raw (undecorated) tool functions keyed by tool name.

    Reaches into FastMCP internals (``_tool_manager._tools``) so tests can
    invoke the functions directly without going through the MCP protocol.
    """
    register_tools(mcp, credentials=None)
    tools = mcp._tool_manager._tools
    return {name: tools[name].fn for name in tools}
|
||||
|
||||
|
||||
class TestWandbTool:
    """Tests for the wandb_* MCP tools with ``httpx.post`` mocked out.

    Each test patches the environment (ENV supplies a fake API key) and
    replaces ``httpx.post`` so no network traffic occurs.
    """

    # --- Credential tests ---

    def test_missing_credentials_returns_error(self, tool_fns: dict[str, Any]) -> None:
        """Missing WANDB_API_KEY must return a descriptive error dict with help."""
        with patch.dict("os.environ", {}, clear=True):
            result = tool_fns["wandb_list_projects"](entity="test-entity")
        assert "error" in result
        assert "credentials not configured" in result["error"]
        assert "help" in result

    # --- wandb_list_projects ---

    def test_wandb_list_projects_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_projects returns projects list from GraphQL."""
        gql_data = {
            "projects": {
                "edges": [
                    {
                        "node": {
                            "name": "proj-a",
                            "description": "Desc A",
                            "createdAt": "2024-01-01",
                        }
                    },
                    {
                        "node": {
                            "name": "proj-b",
                            "description": "",
                            "createdAt": "2024-02-01",
                        }
                    },
                ]
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_list_projects"](entity="test-entity")

        assert result["entity"] == "test-entity"
        assert len(result["projects"]) == 2
        assert result["projects"][0]["name"] == "proj-a"

    def test_wandb_list_projects_http_401(self, tool_fns: dict[str, Any]) -> None:
        """HTTP 401 returns an invalid key error."""
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_mock_resp({}, status_code=401)),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert result["error"] == "Invalid Weights & Biases API key"

    def test_wandb_list_projects_graphql_error(self, tool_fns: dict[str, Any]) -> None:
        """GraphQL error block is surfaced as an error dict."""
        gql_err = {"errors": [{"message": "entity not found"}]}
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_mock_resp(gql_err)),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert "error" in result
        assert "entity not found" in result["error"]

    # --- wandb_list_runs ---

    def test_wandb_list_runs_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_runs returns runs list."""
        gql_data = {
            "project": {
                "runs": {
                    "edges": [
                        {
                            "node": {
                                # "name" is the run ID; "id" is the display name.
                                "name": "w854ckuu",
                                "id": "ferengi-directive-1",
                                "state": "finished",
                                "createdAt": "2024-01-01",
                                "config": "{}",
                                "summaryMetrics": "{}",
                            }
                        }
                    ]
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)) as mock_post,
        ):
            result = tool_fns["wandb_list_runs"](
                entity="testentity",
                project="testproject",
                filters='{"key": "value"}',
                per_page=50,
            )

        assert result["project"] == "testproject"
        assert len(result["runs"]) == 1
        assert result["runs"][0]["id"] == "w854ckuu"
        # Verify filters and per_page were forwarded in GraphQL variables
        call_json = mock_post.call_args[1]["json"]
        assert call_json["variables"]["perPage"] == 50
        assert call_json["variables"]["filters"] == {"key": "value"}

    def test_wandb_list_runs_invalid_filters_json(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_runs returns error for invalid JSON filters before any HTTP call."""
        with patch.dict("os.environ", ENV):
            result = tool_fns["wandb_list_runs"](entity="e", project="p", filters="not-json")
        assert "error" in result
        assert "valid JSON" in result["error"]

    # --- wandb_get_run ---

    def test_wandb_get_run_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run returns run details."""
        gql_data = {
            "project": {
                "run": {
                    "name": "run-123",
                    "id": "my-run",
                    "state": "finished",
                    "createdAt": "2024-01-01",
                    "config": '{"lr": 0.001}',
                    "summaryMetrics": '{"accuracy": 0.9}',
                    "tags": ["v1"],
                    "notes": "test",
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="run-123")

        assert result["id"] == "run-123"
        assert result["config"] == {"lr": 0.001}
        assert result["summary"] == {"accuracy": 0.9}

    def test_wandb_get_run_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run with empty run_id returns error before HTTP call."""
        result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    def test_wandb_get_run_not_found(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run returns not-found error when run is null."""
        gql_data = {"project": {"run": None}}
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="nope")
        assert "error" in result
        assert "not found" in result["error"]

    # --- wandb_get_run_metrics ---

    def test_wandb_get_run_metrics_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run_metrics returns sampled history."""
        gql_data = {
            "project": {
                "run": {
                    "sampledHistory": [{"loss": 0.5}, {"loss": 0.3}],
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_run_metrics"](
                entity="e", project="p", run_id="r1", metric_keys="loss"
            )

        assert result["run_id"] == "r1"
        assert result["metric_keys"] == ["loss"]
        assert result["history"] == [{"loss": 0.5}, {"loss": 0.3}]

    def test_wandb_get_run_metrics_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run_metrics with empty run_id returns error."""
        with patch.dict("os.environ", ENV):
            result = tool_fns["wandb_get_run_metrics"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    def test_wandb_get_run_metrics_missing_keys(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run_metrics with no metric_keys returns error."""
        with patch.dict("os.environ", ENV):
            result = tool_fns["wandb_get_run_metrics"](entity="e", project="p", run_id="r1")
        assert "error" in result
        assert "metric_keys is required" in result["error"]

    # --- wandb_list_artifacts ---

    def test_wandb_list_artifacts_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_artifacts returns artifact list."""
        gql_data = {
            "project": {
                "run": {
                    "outputArtifacts": {
                        "edges": [
                            {
                                "node": {
                                    "name": "model:v0",
                                    "type": "model",
                                    "description": "",
                                    "createdAt": "2024-01-01",
                                }
                            }
                        ]
                    }
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_list_artifacts"](entity="e", project="p", run_id="r1")

        assert result["run_id"] == "r1"
        assert result["artifacts"][0]["name"] == "model:v0"

    def test_wandb_list_artifacts_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_artifacts with empty run_id returns error."""
        result = tool_fns["wandb_list_artifacts"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    # --- wandb_get_summary ---

    def test_wandb_get_summary_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_summary returns summary filtering out _-prefixed keys."""
        gql_data = {
            "project": {"run": {"summaryMetrics": '{"accuracy": 0.9, "loss": 0.1, "_step": 5}'}}
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_summary"](entity="e", project="p", run_id="r1")

        assert result["run_id"] == "r1"
        assert result["summary"]["accuracy"] == 0.9
        assert "_step" not in result["summary"]

    def test_wandb_get_summary_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_summary with empty run_id returns error."""
        result = tool_fns["wandb_get_summary"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    # --- Network/timeout errors ---

    def test_timeout_returns_error(self, tool_fns: dict[str, Any]) -> None:
        """httpx.TimeoutException is caught and returns a timeout message."""
        with (
            patch.dict("os.environ", ENV),
            patch(
                "aden_tools.tools.wandb_tool.wandb_tool.httpx.post",
                side_effect=httpx.TimeoutException("timeout"),
            ),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert result["error"] == "Request timed out"

    def test_network_error_returns_error(self, tool_fns: dict[str, Any]) -> None:
        """httpx.RequestError is caught and returns a network error message."""
        with (
            patch.dict("os.environ", ENV),
            patch(
                "aden_tools.tools.wandb_tool.wandb_tool.httpx.post",
                side_effect=httpx.RequestError("Connection refused"),
            ),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert "Network error" in result["error"]
|
||||
Reference in New Issue
Block a user