feat(tools): add Weights & Biases ML experiment tracking integration (#6963)

* feat(tools): add Weights & Biases experiment tracking and model monitoring integration

* style: fix ruff formatting in wandb_tool.py

* feat(tools): add Weights & Biases ML experiment tracking integration

* fix(tools): address CodeRabbit review comments on wandb_tool

* fix(tools): rewrite wandb_tool to use official Python SDK instead of undocumented REST endpoints

* fix(tools): address Hundao review — remove .coverage, switch to GraphQL/httpx, fix wandb_host, add README

* fix(tools): wire filters to GraphQL, validate empty metric_keys, fix line lengths

* fix(tools): check credentials before input validation in wandb_get_run_metrics

Move _get_creds() call before run_id/metric_keys checks so the
framework credential test receives the expected {error, help} response
instead of a bare input-validation error.
This commit is contained in:
Gaurav Rai
2026-04-08 02:57:03 -04:00
committed by GitHub
parent df29c49bd0
commit 8f608048f9
8 changed files with 900 additions and 0 deletions
+1
View File
@@ -79,3 +79,4 @@ core/tests/*dumps/*
screenshots/*
.gemini/*
.coverage
@@ -139,6 +139,7 @@ from .trello import TRELLO_CREDENTIALS
from .twilio import TWILIO_CREDENTIALS
from .twitter import TWITTER_CREDENTIALS
from .vercel import VERCEL_CREDENTIALS
from .wandb import WANDB_CREDENTIALS
from .youtube import YOUTUBE_CREDENTIALS
from .zendesk import ZENDESK_CREDENTIALS
from .zoho_crm import ZOHO_CRM_CREDENTIALS
@@ -219,6 +220,7 @@ CREDENTIAL_SPECS = {
**TWITTER_CREDENTIALS,
**VERCEL_CREDENTIALS,
**YOUTUBE_CREDENTIALS,
**WANDB_CREDENTIALS,
**ZENDESK_CREDENTIALS,
**ZOHO_CRM_CREDENTIALS,
**ZOOM_CREDENTIALS,
@@ -313,6 +315,7 @@ __all__ = [
"TWILIO_CREDENTIALS",
"TWITTER_CREDENTIALS",
"VERCEL_CREDENTIALS",
"WANDB_CREDENTIALS",
"YOUTUBE_CREDENTIALS",
"ZENDESK_CREDENTIALS",
"ZOHO_CRM_CREDENTIALS",
+38
View File
@@ -0,0 +1,38 @@
"""
Weights & Biases integration credentials.
Contains credentials for the W&B GraphQL API.
Requires WANDB_API_KEY only — no host configuration is needed.
"""
from __future__ import annotations
from .base import CredentialSpec
WANDB_CREDENTIALS = {
"wandb_api_key": CredentialSpec(
env_var="WANDB_API_KEY",
tools=[
"wandb_list_projects",
"wandb_list_runs",
"wandb_get_run",
"wandb_get_run_metrics",
"wandb_list_artifacts",
"wandb_get_summary",
],
required=True,
startup_required=False,
help_url="https://wandb.ai/authorize",
description="Weights & Biases API Key",
direct_api_key_supported=True,
api_key_instructions="""To set up W&B API access:
1. Create a W&B account at https://wandb.ai
2. Go to https://wandb.ai/authorize
3. Copy your API key
4. Set environment variable:
export WANDB_API_KEY=your-api-key""",
health_check_endpoint="",
credential_id="wandb_api_key",
credential_key="api_key",
),
}
+2
View File
@@ -133,6 +133,7 @@ from .twilio_tool import register_tools as register_twilio
from .twitter_tool import register_tools as register_twitter
from .vercel_tool import register_tools as register_vercel
from .vision_tool import register_tools as register_vision
from .wandb_tool import register_tools as register_wandb
try:
from .web_scrape_tool import register_tools as register_web_scrape
@@ -306,6 +307,7 @@ def _register_unverified(
register_zendesk(mcp, credentials=credentials)
register_zoho_crm(mcp, credentials=credentials)
register_zoom(mcp, credentials=credentials)
register_wandb(mcp, credentials=credentials)
register_freshdesk(mcp, credentials=credentials)
@@ -0,0 +1,94 @@
# Weights & Biases Tool
Query ML experiment runs, metrics, and artifacts from Weights & Biases using the W&B GraphQL API.
## Tools
| Tool | Description |
|------|-------------|
| `wandb_list_projects` | List all projects for a W&B entity (user or organization) |
| `wandb_list_runs` | List runs in a project with optional filters |
| `wandb_get_run` | Get full details of a specific run (config, state, summary) |
| `wandb_get_run_metrics` | Get sampled metric history for a run |
| `wandb_list_artifacts` | List output artifacts logged by a run |
| `wandb_get_summary` | Get final summary metrics for a run |
## Setup
Requires a W&B account and API key.
1. Create a W&B account at [wandb.ai](https://wandb.ai)
2. Get your API key at [wandb.ai/authorize](https://wandb.ai/authorize)
3. Set the environment variable:
```bash
export WANDB_API_KEY=your-api-key
```
Or configure via the Aden credential store as `wandb_api_key`.
## Usage Examples
### List projects for an entity
```python
wandb_list_projects(entity="my-team")
```
### List recent runs in a project
```python
wandb_list_runs(entity="my-team", project="my-project", per_page=10)
```
### Filter runs by state
```python
wandb_list_runs(
entity="my-team",
project="my-project",
filters='{"state": "finished"}',
)
```
### Get details of a specific run
```python
wandb_get_run(entity="my-team", project="my-project", run_id="abc123")
```
### Get training metrics for a run
```python
wandb_get_run_metrics(
entity="my-team",
project="my-project",
run_id="abc123",
metric_keys="loss,accuracy",
)
```
### Get final summary metrics
```python
wandb_get_summary(entity="my-team", project="my-project", run_id="abc123")
```
### List artifacts produced by a run
```python
wandb_list_artifacts(entity="my-team", project="my-project", run_id="abc123")
```
## Error Handling
All tools return error dicts on failure:
```python
{"error": "Weights & Biases credentials not configured", "help": "Set WANDB_API_KEY..."}
{"error": "Invalid Weights & Biases API key"}
{"error": "Weights & Biases resource not found"}
{"error": "Request timed out"}
{"error": "filters must be a valid JSON string"}
{"error": "metric_keys is required (comma-separated, e.g. 'loss,accuracy')"}
```
@@ -0,0 +1,5 @@
"""Weights & Biases experiment tracking tool for Aden Tools."""
from .wandb_tool import register_tools
__all__ = ["register_tools"]
@@ -0,0 +1,440 @@
"""
Weights & Biases ML experiment tracking tool.
Uses the W&B GraphQL API via httpx — no SDK dependency.
Authentication: Bearer token (WANDB_API_KEY)
GraphQL endpoint: https://api.wandb.ai/graphql
API Reference: https://github.com/wandb/wandb/blob/main/wandb/proto/wandb_internal.proto
"""
from __future__ import annotations
import json
import os
from typing import TYPE_CHECKING, Any
import httpx
from fastmcp import FastMCP
if TYPE_CHECKING:
from aden_tools.credentials import CredentialStoreAdapter
# W&B's public GraphQL endpoint; every tool query is POSTed here by _graphql().
GRAPHQL_URL = "https://api.wandb.ai/graphql"
def _get_creds(
credentials: CredentialStoreAdapter | None,
) -> tuple[str] | dict[str, Any]:
"""Return (api_key,) or an error dict."""
if credentials is not None:
api_key = credentials.get("wandb_api_key")
else:
api_key = os.getenv("WANDB_API_KEY")
if not api_key:
return {
"error": "Weights & Biases credentials not configured",
"help": (
"Set WANDB_API_KEY environment variable or configure via credential store. "
"Get your API key at https://wandb.ai/authorize"
),
}
return (api_key,)
def _graphql(
    api_key: str,
    query: str,
    variables: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """POST a GraphQL query to the W&B API and return its ``data`` section.

    Any failure (network, HTTP status, or a GraphQL ``errors`` block) is
    converted into a dict with a single ``error`` key.
    """
    request_body = {"query": query, "variables": variables or {}}
    try:
        resp = httpx.post(
            GRAPHQL_URL,
            headers={"Authorization": f"Bearer {api_key}"},
            json=request_body,
            timeout=30.0,
        )
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except httpx.RequestError as e:
        return {"error": f"Network error: {e}"}

    status = resp.status_code
    if status == 401:
        return {"error": "Invalid Weights & Biases API key"}
    if status == 403:
        return {"error": "Insufficient permissions for this Weights & Biases resource"}
    if status >= 400:
        # Best effort: surface the first GraphQL error message, else raw text.
        try:
            detail = resp.json().get("errors", [{}])[0].get("message", resp.text)
        except Exception:
            detail = resp.text
        return {"error": f"Weights & Biases API error (HTTP {status}): {detail}"}

    payload = resp.json()
    if "errors" in payload:
        msg = payload["errors"][0].get("message", str(payload["errors"]))
        return {"error": f"Weights & Biases GraphQL error: {msg}"}
    return payload.get("data", {})
def register_tools(
mcp: FastMCP,
credentials: CredentialStoreAdapter | None = None,
) -> None:
"""Register Weights & Biases experiment tracking tools with the MCP server."""
@mcp.tool()
def wandb_list_projects(entity: str) -> dict:
    """
    List all projects for a Weights & Biases entity (user or organization).

    Args:
        entity: The W&B entity name (username or organization).

    Returns:
        Dict containing the list of projects for the entity.
    """
    creds = _get_creds(credentials)
    if isinstance(creds, dict):
        return creds
    (api_key,) = creds
    query = """
    query ListProjects($entity: String!) {
      projects(entityName: $entity) {
        edges {
          node {
            name
            description
            createdAt
          }
        }
      }
    }
    """
    data = _graphql(api_key, query, {"entity": entity})
    if "error" in data:
        return data
    projects = []
    for edge in data.get("projects", {}).get("edges", []):
        node = edge["node"]
        projects.append(
            {
                "name": node["name"],
                "description": node.get("description", ""),
                "created_at": node.get("createdAt", ""),
            }
        )
    return {"entity": entity, "projects": projects}
@mcp.tool()
def wandb_list_runs(
    entity: str,
    project: str,
    filters: str = "",
    per_page: int = 50,
) -> dict:
    """
    List runs in a Weights & Biases project.

    Args:
        entity: The W&B entity name (username or organization).
        project: The project name.
        filters: Optional JSON filter string to narrow results.
        per_page: Number of runs to return (default 50).

    Returns:
        Dict containing the list of runs in the project.
    """
    # Validate filters up front so bad JSON never reaches the network layer.
    parsed_filters = None
    if filters:
        try:
            parsed_filters = json.loads(filters)
        except json.JSONDecodeError:
            return {"error": "filters must be a valid JSON string"}
    creds = _get_creds(credentials)
    if isinstance(creds, dict):
        return creds
    (api_key,) = creds
    query = """
    query ListRuns($project: String!, $entity: String!, $perPage: Int!, $filters: JSONString) {
      project(name: $project, entityName: $entity) {
        runs(first: $perPage, filters: $filters) {
          edges {
            node {
              name
              id
              state
              createdAt
              config
              summaryMetrics
            }
          }
        }
      }
    }
    """
    variables: dict[str, Any] = {
        "project": project,
        "entity": entity,
        "perPage": per_page,
    }
    if parsed_filters is not None:
        # NOTE(review): $filters is declared as JSONString, but the parsed
        # dict (not a JSON-encoded string) is sent here — confirm the W&B
        # server accepts an object for JSONString variables.
        variables["filters"] = parsed_filters
    data = _graphql(api_key, query, variables)
    if "error" in data:
        return data
    runs = []
    for edge in data.get("project", {}).get("runs", {}).get("edges", []):
        node = edge["node"]
        # W&B serializes run config as a JSON string; tolerate null/garbage.
        try:
            config = json.loads(node.get("config") or "{}")
        except (json.JSONDecodeError, TypeError):
            config = {}
        runs.append(
            {
                # W&B's GraphQL "name" is the short run ID; "id" is the
                # human-readable display name.
                "id": node.get("name"),
                "display_name": node.get("id"),
                "state": node.get("state"),
                "created_at": node.get("createdAt"),
                "config": config,
            }
        )
    return {"entity": entity, "project": project, "runs": runs}
@mcp.tool()
def wandb_get_run(entity: str, project: str, run_id: str) -> dict:
    """
    Get details of a specific Weights & Biases run.

    Args:
        entity: The W&B entity name (username or organization).
        project: The project name.
        run_id: The run ID.

    Returns:
        Dict containing full run details including config and metadata.
    """
    if not run_id:
        return {"error": "run_id is required"}
    creds = _get_creds(credentials)
    if isinstance(creds, dict):
        return creds
    (api_key,) = creds
    query = """
    query GetRun($project: String!, $entity: String!, $run: String!) {
      project(name: $project, entityName: $entity) {
        run(name: $run) {
          name
          id
          state
          createdAt
          config
          summaryMetrics
          tags
          notes
        }
      }
    }
    """
    data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
    if "error" in data:
        return data
    node = data.get("project", {}).get("run")
    if not node:
        return {"error": "Weights & Biases resource not found"}

    def _parse_json_field(raw: str | None) -> dict:
        # config/summaryMetrics arrive as JSON-encoded strings; fall back to {}.
        try:
            return json.loads(raw or "{}")
        except (json.JSONDecodeError, TypeError):
            return {}

    return {
        "id": node.get("name"),
        "display_name": node.get("id"),
        "state": node.get("state"),
        "created_at": node.get("createdAt"),
        "config": _parse_json_field(node.get("config")),
        "summary": _parse_json_field(node.get("summaryMetrics")),
        "tags": node.get("tags") or [],
        "notes": node.get("notes") or "",
    }
@mcp.tool()
def wandb_get_run_metrics(
    entity: str,
    project: str,
    run_id: str,
    metric_keys: str = "",
) -> dict:
    """
    Get sampled metrics history for a specific Weights & Biases run.

    Args:
        entity: The W&B entity name (username or organization).
        project: The project name.
        run_id: The run ID.
        metric_keys: Comma-separated metric keys to sample (e.g. "loss,accuracy").
            At least one key is required.

    Returns:
        Dict containing sampled metric history per key.
    """
    # Credentials are deliberately checked before input validation so a
    # missing key always yields the {error, help} response (see commit note).
    creds = _get_creds(credentials)
    if isinstance(creds, dict):
        return creds
    (api_key,) = creds
    if not run_id:
        return {"error": "run_id is required"}
    if not metric_keys:
        return {"error": "metric_keys is required (comma-separated, e.g. 'loss,accuracy')"}
    keys = [k.strip() for k in metric_keys.split(",") if k.strip()]
    if not keys:
        return {"error": "metric_keys must include at least one non-empty key"}
    # FIX: previously the specs JSON was f-string-interpolated into the query
    # text, producing invalid GraphQL (JSON object literals like {"key": "loss"}
    # are not GraphQL syntax) and allowing query injection via user-supplied
    # metric keys. Pass the specs as a proper GraphQL variable instead:
    # sampledHistory takes a list of JSON-encoded spec strings.
    specs = [json.dumps({"key": k}) for k in keys]
    query = """
    query GetRunMetrics($project: String!, $entity: String!, $run: String!, $specs: [JSONString!]!) {
      project(name: $project, entityName: $entity) {
        run(name: $run) {
          sampledHistory(specs: $specs)
        }
      }
    }
    """
    data = _graphql(
        api_key,
        query,
        {"project": project, "entity": entity, "run": run_id, "specs": specs},
    )
    if "error" in data:
        return data
    node = data.get("project", {}).get("run")
    if not node:
        return {"error": "Weights & Biases resource not found"}
    return {
        "run_id": run_id,
        "metric_keys": keys,
        "history": node.get("sampledHistory", []),
    }
@mcp.tool()
def wandb_list_artifacts(entity: str, project: str, run_id: str) -> dict:
    """
    List artifacts logged by a specific Weights & Biases run.

    Args:
        entity: The W&B entity name (username or organization).
        project: The project name.
        run_id: The run ID.

    Returns:
        Dict containing the list of output artifacts for the run.
    """
    if not run_id:
        return {"error": "run_id is required"}
    creds = _get_creds(credentials)
    if isinstance(creds, dict):
        return creds
    (api_key,) = creds
    query = """
    query ListArtifacts($project: String!, $entity: String!, $run: String!) {
      project(name: $project, entityName: $entity) {
        run(name: $run) {
          outputArtifacts {
            edges {
              node {
                name
                type
                description
                createdAt
              }
            }
          }
        }
      }
    }
    """
    variables = {"project": project, "entity": entity, "run": run_id}
    data = _graphql(api_key, query, variables)
    if "error" in data:
        return data
    node = data.get("project", {}).get("run")
    if not node:
        return {"error": "Weights & Biases resource not found"}
    artifacts = []
    for edge in node.get("outputArtifacts", {}).get("edges", []):
        art = edge["node"]
        artifacts.append(
            {
                "name": art["name"],
                "type": art["type"],
                "description": art.get("description", ""),
                "created_at": art.get("createdAt", ""),
            }
        )
    return {"run_id": run_id, "artifacts": artifacts}
@mcp.tool()
def wandb_get_summary(entity: str, project: str, run_id: str) -> dict:
    """
    Get summary metrics for a specific Weights & Biases run.

    Args:
        entity: The W&B entity name (username or organization).
        project: The project name.
        run_id: The run ID.

    Returns:
        Dict containing the run's final summary metrics.
    """
    if not run_id:
        return {"error": "run_id is required"}
    creds = _get_creds(credentials)
    if isinstance(creds, dict):
        return creds
    (api_key,) = creds
    query = """
    query GetSummary($project: String!, $entity: String!, $run: String!) {
      project(name: $project, entityName: $entity) {
        run(name: $run) {
          summaryMetrics
        }
      }
    }
    """
    data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
    if "error" in data:
        return data
    node = data.get("project", {}).get("run")
    if not node:
        return {"error": "Weights & Biases resource not found"}
    try:
        raw_summary = json.loads(node.get("summaryMetrics") or "{}")
    except (json.JSONDecodeError, TypeError):
        raw_summary = {}
    # Drop W&B-internal bookkeeping keys such as "_step".
    summary = {key: val for key, val in raw_summary.items() if not key.startswith("_")}
    return {"run_id": run_id, "summary": summary}
+317
View File
@@ -0,0 +1,317 @@
"""Tests for wandb_tool - Weights & Biases integration (GraphQL/httpx)."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock, patch
import httpx
import pytest
from fastmcp import FastMCP
from aden_tools.tools.wandb_tool.wandb_tool import register_tools
ENV = {"WANDB_API_KEY": "test-key-abcdefghij"}
_PATCH_POST = "aden_tools.tools.wandb_tool.wandb_tool.httpx.post"
def _mock_resp(data: Any, status_code: int = 200) -> MagicMock:
resp = MagicMock()
resp.status_code = status_code
resp.json.return_value = data
resp.text = str(data)
return resp
def _gql_ok(data: dict[str, Any]) -> MagicMock:
    """Return a 200 mock response with *data* in the GraphQL envelope {"data": ...}."""
    envelope = {"data": data}
    return _mock_resp(envelope)
@pytest.fixture
def tool_fns(mcp: FastMCP) -> dict[str, Any]:
    """Register the wandb tools on the mcp fixture and expose their raw functions."""
    register_tools(mcp, credentials=None)
    registered = mcp._tool_manager._tools
    return {name: tool.fn for name, tool in registered.items()}
class TestWandbTool:
    """Unit tests for the wandb_* MCP tools.

    All HTTP traffic is mocked by patching httpx.post (``_PATCH_POST``), and
    a fake API key is injected via ``patch.dict("os.environ", ENV)`` — no
    network access or real credentials are needed.
    """

    # --- Credential tests ---
    def test_missing_credentials_returns_error(self, tool_fns: dict[str, Any]) -> None:
        """Missing WANDB_API_KEY must return a descriptive error dict with help."""
        # clear=True wipes the environment so the env-var fallback fails too.
        with patch.dict("os.environ", {}, clear=True):
            result = tool_fns["wandb_list_projects"](entity="test-entity")
        assert "error" in result
        assert "credentials not configured" in result["error"]
        assert "help" in result

    # --- wandb_list_projects ---
    def test_wandb_list_projects_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_projects returns projects list from GraphQL."""
        gql_data = {
            "projects": {
                "edges": [
                    {
                        "node": {
                            "name": "proj-a",
                            "description": "Desc A",
                            "createdAt": "2024-01-01",
                        }
                    },
                    {
                        "node": {
                            "name": "proj-b",
                            "description": "",
                            "createdAt": "2024-02-01",
                        }
                    },
                ]
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_list_projects"](entity="test-entity")
        assert result["entity"] == "test-entity"
        assert len(result["projects"]) == 2
        assert result["projects"][0]["name"] == "proj-a"

    def test_wandb_list_projects_http_401(self, tool_fns: dict[str, Any]) -> None:
        """HTTP 401 returns an invalid key error."""
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_mock_resp({}, status_code=401)),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert result["error"] == "Invalid Weights & Biases API key"

    def test_wandb_list_projects_graphql_error(self, tool_fns: dict[str, Any]) -> None:
        """GraphQL error block is surfaced as an error dict."""
        # HTTP 200 but the body carries a GraphQL "errors" block.
        gql_err = {"errors": [{"message": "entity not found"}]}
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_mock_resp(gql_err)),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert "error" in result
        assert "entity not found" in result["error"]

    # --- wandb_list_runs ---
    def test_wandb_list_runs_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_runs returns runs list."""
        gql_data = {
            "project": {
                "runs": {
                    "edges": [
                        {
                            "node": {
                                # "name" is the short run ID; "id" is the display name.
                                "name": "w854ckuu",
                                "id": "ferengi-directive-1",
                                "state": "finished",
                                "createdAt": "2024-01-01",
                                "config": "{}",
                                "summaryMetrics": "{}",
                            }
                        }
                    ]
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)) as mock_post,
        ):
            result = tool_fns["wandb_list_runs"](
                entity="testentity",
                project="testproject",
                filters='{"key": "value"}',
                per_page=50,
            )
        assert result["project"] == "testproject"
        assert len(result["runs"]) == 1
        assert result["runs"][0]["id"] == "w854ckuu"
        # Verify filters and per_page were forwarded in GraphQL variables
        call_json = mock_post.call_args[1]["json"]
        assert call_json["variables"]["perPage"] == 50
        assert call_json["variables"]["filters"] == {"key": "value"}

    def test_wandb_list_runs_invalid_filters_json(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_runs returns error for invalid JSON filters before any HTTP call."""
        with patch.dict("os.environ", ENV):
            result = tool_fns["wandb_list_runs"](entity="e", project="p", filters="not-json")
        assert "error" in result
        assert "valid JSON" in result["error"]

    # --- wandb_get_run ---
    def test_wandb_get_run_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run returns run details."""
        gql_data = {
            "project": {
                "run": {
                    "name": "run-123",
                    "id": "my-run",
                    "state": "finished",
                    "createdAt": "2024-01-01",
                    # config/summaryMetrics arrive as JSON-encoded strings.
                    "config": '{"lr": 0.001}',
                    "summaryMetrics": '{"accuracy": 0.9}',
                    "tags": ["v1"],
                    "notes": "test",
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="run-123")
        assert result["id"] == "run-123"
        assert result["config"] == {"lr": 0.001}
        assert result["summary"] == {"accuracy": 0.9}

    def test_wandb_get_run_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run with empty run_id returns error before HTTP call."""
        result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    def test_wandb_get_run_not_found(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run returns not-found error when run is null."""
        gql_data = {"project": {"run": None}}
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="nope")
        assert "error" in result
        assert "not found" in result["error"]

    # --- wandb_get_run_metrics ---
    def test_wandb_get_run_metrics_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run_metrics returns sampled history."""
        gql_data = {
            "project": {
                "run": {
                    "sampledHistory": [{"loss": 0.5}, {"loss": 0.3}],
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_run_metrics"](
                entity="e", project="p", run_id="r1", metric_keys="loss"
            )
        assert result["run_id"] == "r1"
        assert result["metric_keys"] == ["loss"]
        assert result["history"] == [{"loss": 0.5}, {"loss": 0.3}]

    def test_wandb_get_run_metrics_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run_metrics with empty run_id returns error."""
        # ENV is set: credential check runs first, then input validation fires.
        with patch.dict("os.environ", ENV):
            result = tool_fns["wandb_get_run_metrics"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    def test_wandb_get_run_metrics_missing_keys(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run_metrics with no metric_keys returns error."""
        with patch.dict("os.environ", ENV):
            result = tool_fns["wandb_get_run_metrics"](entity="e", project="p", run_id="r1")
        assert "error" in result
        assert "metric_keys is required" in result["error"]

    # --- wandb_list_artifacts ---
    def test_wandb_list_artifacts_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_artifacts returns artifact list."""
        gql_data = {
            "project": {
                "run": {
                    "outputArtifacts": {
                        "edges": [
                            {
                                "node": {
                                    "name": "model:v0",
                                    "type": "model",
                                    "description": "",
                                    "createdAt": "2024-01-01",
                                }
                            }
                        ]
                    }
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_list_artifacts"](entity="e", project="p", run_id="r1")
        assert result["run_id"] == "r1"
        assert result["artifacts"][0]["name"] == "model:v0"

    def test_wandb_list_artifacts_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_artifacts with empty run_id returns error."""
        result = tool_fns["wandb_list_artifacts"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    # --- wandb_get_summary ---
    def test_wandb_get_summary_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_summary returns summary filtering out _-prefixed keys."""
        gql_data = {
            "project": {"run": {"summaryMetrics": '{"accuracy": 0.9, "loss": 0.1, "_step": 5}'}}
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_summary"](entity="e", project="p", run_id="r1")
        assert result["run_id"] == "r1"
        assert result["summary"]["accuracy"] == 0.9
        # Internal W&B bookkeeping keys must be stripped.
        assert "_step" not in result["summary"]

    def test_wandb_get_summary_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_summary with empty run_id returns error."""
        result = tool_fns["wandb_get_summary"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    # --- Network/timeout errors ---
    def test_timeout_returns_error(self, tool_fns: dict[str, Any]) -> None:
        """httpx.TimeoutException is caught and returns a timeout message."""
        with (
            patch.dict("os.environ", ENV),
            patch(
                "aden_tools.tools.wandb_tool.wandb_tool.httpx.post",
                side_effect=httpx.TimeoutException("timeout"),
            ),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert result["error"] == "Request timed out"

    def test_network_error_returns_error(self, tool_fns: dict[str, Any]) -> None:
        """httpx.RequestError is caught and returns a network error message."""
        with (
            patch.dict("os.environ", ENV),
            patch(
                "aden_tools.tools.wandb_tool.wandb_tool.httpx.post",
                side_effect=httpx.RequestError("Connection refused"),
            ),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert "Network error" in result["error"]