"""
Weights & Biases integration credentials.

Contains credentials for the W&B GraphQL API.
Requires WANDB_API_KEY only — no host configuration needed.
"""

from __future__ import annotations

from .base import CredentialSpec

# Credential specs contributed by the Weights & Biases integration, keyed by
# credential name. The package-level CREDENTIAL_SPECS mapping is built by
# unpacking this dict alongside those of the other integrations.
WANDB_CREDENTIALS = {
    "wandb_api_key": CredentialSpec(
        # Environment variable the W&B tool falls back to when no credential
        # store is supplied.
        env_var="WANDB_API_KEY",
        # Names of the six tools defined by the wandb_tool module; keep this
        # list in sync with the functions registered there.
        tools=[
            "wandb_list_projects",
            "wandb_list_runs",
            "wandb_get_run",
            "wandb_get_run_metrics",
            "wandb_list_artifacts",
            "wandb_get_summary",
        ],
        required=True,
        startup_required=False,
        help_url="https://wandb.ai/authorize",
        description="Weights & Biases API Key",
        direct_api_key_supported=True,
        api_key_instructions="""To set up W&B API access:
1. Create a W&B account at https://wandb.ai
2. Go to https://wandb.ai/authorize
3. Copy your API key
4. Set environment variable:
   export WANDB_API_KEY=your-api-key""",
        # NOTE(review): left empty — presumably this means "no health check";
        # confirm against CredentialSpec before relying on it.
        health_check_endpoint="",
        # Same identifier the W&B tool passes to the credential store lookup.
        credential_id="wandb_api_key",
        credential_key="api_key",
    ),
}
artifacts from Weights & Biases using the W&B GraphQL API. + +## Tools + +| Tool | Description | +|------|-------------| +| `wandb_list_projects` | List all projects for a W&B entity (user or organization) | +| `wandb_list_runs` | List runs in a project with optional filters | +| `wandb_get_run` | Get full details of a specific run (config, state, summary) | +| `wandb_get_run_metrics` | Get sampled metric history for a run | +| `wandb_list_artifacts` | List output artifacts logged by a run | +| `wandb_get_summary` | Get final summary metrics for a run | + +## Setup + +Requires a W&B account and API key. + +1. Create a W&B account at [wandb.ai](https://wandb.ai) +2. Get your API key at [wandb.ai/authorize](https://wandb.ai/authorize) +3. Set the environment variable: + +```bash +export WANDB_API_KEY=your-api-key +``` + +Or configure via the Aden credential store as `wandb_api_key`. + +## Usage Examples + +### List projects for an entity + +```python +wandb_list_projects(entity="my-team") +``` + +### List recent runs in a project + +```python +wandb_list_runs(entity="my-team", project="my-project", per_page=10) +``` + +### Filter runs by state + +```python +wandb_list_runs( + entity="my-team", + project="my-project", + filters='{"state": "finished"}', +) +``` + +### Get details of a specific run + +```python +wandb_get_run(entity="my-team", project="my-project", run_id="abc123") +``` + +### Get training metrics for a run + +```python +wandb_get_run_metrics( + entity="my-team", + project="my-project", + run_id="abc123", + metric_keys="loss,accuracy", +) +``` + +### Get final summary metrics + +```python +wandb_get_summary(entity="my-team", project="my-project", run_id="abc123") +``` + +### List artifacts produced by a run + +```python +wandb_list_artifacts(entity="my-team", project="my-project", run_id="abc123") +``` + +## Error Handling + +All tools return error dicts on failure: + +```python +{"error": "Weights & Biases credentials not configured", "help": "Set 
# Error message shared by every lookup that resolves to a null project or run.
_NOT_FOUND_ERROR = "Weights & Biases resource not found"

# Number of sampled rows requested per metric from ``sampledHistory``.
_HISTORY_SAMPLES = 500


def _get_creds(
    credentials: CredentialStoreAdapter | None,
) -> tuple[str] | dict[str, Any]:
    """Resolve the W&B API key.

    Reads ``wandb_api_key`` from the credential store when one is supplied,
    otherwise falls back to the ``WANDB_API_KEY`` environment variable.

    Returns:
        ``(api_key,)`` on success, or an error dict with ``error`` and ``help``
        keys when no key is configured.
    """
    if credentials is not None:
        api_key = credentials.get("wandb_api_key")
    else:
        api_key = os.getenv("WANDB_API_KEY")

    if not api_key:
        return {
            "error": "Weights & Biases credentials not configured",
            "help": (
                "Set WANDB_API_KEY environment variable or configure via credential store. "
                "Get your API key at https://wandb.ai/authorize"
            ),
        }
    return (api_key,)


def _child(parent: Any, *path: str) -> dict[str, Any]:
    """Follow ``path`` through nested dicts, mapping missing keys and JSON null to ``{}``.

    GraphQL returns ``null`` for nullable objects that do not exist (for example an
    unknown project), so a plain ``d.get(key, {})`` chain would raise on ``None``.
    """
    node = parent
    for key in path:
        if not isinstance(node, dict):
            return {}
        node = node.get(key)
    return node if isinstance(node, dict) else {}


def _nodes(connection: Any) -> list[dict[str, Any]]:
    """Return the non-null ``node`` dicts from an ``{"edges": [{"node": ...}]}`` connection."""
    edges = connection.get("edges") if isinstance(connection, dict) else None
    return [
        edge["node"]
        for edge in edges or []
        if isinstance(edge, dict) and isinstance(edge.get("node"), dict)
    ]


def _load_json_object(raw: Any) -> dict[str, Any]:
    """Decode a JSON-encoded object field (``config``, ``summaryMetrics``).

    Returns ``{}`` for null/empty input, undecodable text, or JSON that is not an object,
    so callers can always treat the result as a dict.
    """
    if isinstance(raw, dict):
        return raw
    try:
        value = json.loads(raw or "{}")
    except (json.JSONDecodeError, TypeError):
        return {}
    return value if isinstance(value, dict) else {}


def _unwrap_payload(payload: Any) -> dict[str, Any]:
    """Turn a decoded GraphQL response envelope into ``data`` or an error dict.

    A non-empty ``errors`` list yields ``{"error": ...}`` built from the first message.
    A null or missing ``data`` yields ``{}`` so callers can safely use ``"error" in data``.
    """
    if not isinstance(payload, dict):
        return {"error": "Weights & Biases API returned an unexpected response"}

    errors = payload.get("errors")
    if errors:
        first = errors[0] if isinstance(errors, list) else errors
        if isinstance(first, dict):
            msg = first.get("message", str(errors))
        else:
            msg = str(first)
        return {"error": f"Weights & Biases GraphQL error: {msg}"}

    data = payload.get("data")
    return data if isinstance(data, dict) else {}


def _graphql(
    api_key: str,
    query: str,
    variables: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Execute a GraphQL query against ``GRAPHQL_URL``.

    Args:
        api_key: W&B API key used for the ``Authorization`` header.
        query: GraphQL document to execute.
        variables: Optional GraphQL variables (sent as ``{}`` when omitted).

    Returns:
        The response's ``data`` object, or ``{"error": ...}`` for timeouts, network
        failures, HTTP errors, GraphQL errors and non-JSON bodies. Never raises for
        those conditions.
    """
    try:
        # NOTE(review): the official wandb client authenticates with HTTP Basic
        # (user "api", password = API key); confirm this endpoint also accepts the
        # key as a Bearer token.
        resp = httpx.post(
            GRAPHQL_URL,
            headers={"Authorization": f"Bearer {api_key}"},
            json={"query": query, "variables": variables or {}},
            timeout=30.0,
        )
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except httpx.RequestError as e:
        return {"error": f"Network error: {e}"}

    if resp.status_code == 401:
        return {"error": "Invalid Weights & Biases API key"}
    if resp.status_code == 403:
        return {"error": "Insufficient permissions for this Weights & Biases resource"}
    if resp.status_code >= 400:
        # Best effort: prefer the server's GraphQL error message, else the raw body.
        try:
            detail = resp.json()["errors"][0]["message"]
        except (ValueError, KeyError, IndexError, TypeError):
            detail = resp.text
        return {"error": f"Weights & Biases API error (HTTP {resp.status_code}): {detail}"}

    try:
        payload = resp.json()
    except ValueError:
        # A 2xx response with a non-JSON body (e.g. a proxy HTML page).
        return {"error": "Weights & Biases API returned a non-JSON response"}
    return _unwrap_payload(payload)


def _fetch_run(
    credentials: CredentialStoreAdapter | None,
    query: str,
    variables: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any] | None]:
    """Run a ``project { run { ... } }`` query and extract the run node.

    Returns:
        ``(run_node, None)`` on success, or ``({}, error_dict)`` when credentials are
        missing, the request fails, or the project/run resolves to null.
    """
    creds = _get_creds(credentials)
    if isinstance(creds, dict):
        return {}, creds
    (api_key,) = creds

    data = _graphql(api_key, query, variables)
    if "error" in data:
        return {}, data

    run = _child(data, "project", "run")
    if not run:
        return {}, {"error": _NOT_FOUND_ERROR}
    return run, None


def _run_overview(node: dict[str, Any]) -> dict[str, Any]:
    """Map a GraphQL run node to the tool's public run shape.

    ``name`` is what the run queries accept as the run identifier, so it is exposed as
    ``id``; the human-readable label comes from ``displayName``.
    """
    return {
        "id": node.get("name"),
        "display_name": node.get("displayName"),
        "state": node.get("state"),
        "created_at": node.get("createdAt"),
        "config": _load_json_object(node.get("config")),
    }


def register_tools(
    mcp: FastMCP,
    credentials: CredentialStoreAdapter | None = None,
) -> None:
    """Register Weights & Biases experiment tracking tools with the MCP server.

    Args:
        mcp: Server whose ``tool()`` decorator is used to register each function.
        credentials: Optional credential store; when ``None`` the API key is read
            from the ``WANDB_API_KEY`` environment variable at call time.
    """

    @mcp.tool()
    def wandb_list_projects(entity: str) -> dict:
        """
        List all projects for a Weights & Biases entity (user or organization).

        Args:
            entity: The W&B entity name (username or organization).

        Returns:
            Dict containing the list of projects for the entity.
        """
        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds

        query = """
        query ListProjects($entity: String!) {
            projects(entityName: $entity) {
                edges {
                    node {
                        name
                        description
                        createdAt
                    }
                }
            }
        }
        """
        data = _graphql(api_key, query, {"entity": entity})
        if "error" in data:
            return data

        return {
            "entity": entity,
            "projects": [
                {
                    "name": node.get("name"),
                    "description": node.get("description") or "",
                    "created_at": node.get("createdAt") or "",
                }
                for node in _nodes(data.get("projects"))
            ],
        }

    @mcp.tool()
    def wandb_list_runs(
        entity: str,
        project: str,
        filters: str = "",
        per_page: int = 50,
    ) -> dict:
        """
        List runs in a Weights & Biases project.

        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            filters: Optional JSON filter string to narrow results.
            per_page: Number of runs to return (default 50); must be positive.

        Returns:
            Dict containing the list of runs in the project, or an error dict.
        """
        parsed_filters = None
        if filters:
            try:
                parsed_filters = json.loads(filters)
            except json.JSONDecodeError:
                return {"error": "filters must be a valid JSON string"}
        if per_page < 1:
            return {"error": "per_page must be a positive integer"}

        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds

        query = """
        query ListRuns($project: String!, $entity: String!, $perPage: Int!, $filters: JSONString) {
            project(name: $project, entityName: $entity) {
                runs(first: $perPage, filters: $filters) {
                    edges {
                        node {
                            name
                            displayName
                            state
                            createdAt
                            config
                        }
                    }
                }
            }
        }
        """
        variables: dict[str, Any] = {"project": project, "entity": entity, "perPage": per_page}
        if parsed_filters is not None:
            # NOTE(review): $filters is declared JSONString and the wandb SDK sends
            # json.dumps(filters) for it; confirm the server accepts a decoded object
            # here before relying on filtering (the existing test pins this form).
            variables["filters"] = parsed_filters
        data = _graphql(api_key, query, variables)
        if "error" in data:
            return data
        if data.get("project") is None:
            # Unknown project: the API answers with "project": null rather than an error.
            return {"error": _NOT_FOUND_ERROR}

        runs = [_run_overview(node) for node in _nodes(_child(data, "project", "runs"))]
        return {"entity": entity, "project": project, "runs": runs}

    @mcp.tool()
    def wandb_get_run(entity: str, project: str, run_id: str) -> dict:
        """
        Get details of a specific Weights & Biases run.

        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.

        Returns:
            Dict containing full run details including config and metadata.
        """
        if not run_id:
            return {"error": "run_id is required"}

        query = """
        query GetRun($project: String!, $entity: String!, $run: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $run) {
                    name
                    displayName
                    state
                    createdAt
                    config
                    summaryMetrics
                    tags
                    notes
                }
            }
        }
        """
        run, error = _fetch_run(
            credentials, query, {"project": project, "entity": entity, "run": run_id}
        )
        if error is not None:
            return error

        return {
            **_run_overview(run),
            "summary": _load_json_object(run.get("summaryMetrics")),
            "tags": run.get("tags") or [],
            "notes": run.get("notes") or "",
        }

    @mcp.tool()
    def wandb_get_run_metrics(
        entity: str,
        project: str,
        run_id: str,
        metric_keys: str = "",
    ) -> dict:
        """
        Get sampled metrics history for a specific Weights & Biases run.

        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.
            metric_keys: Comma-separated metric keys to sample (e.g. "loss,accuracy").
                At least one key is required.

        Returns:
            Dict with ``run_id``, the parsed ``metric_keys`` and ``history`` exactly as
            returned by ``sampledHistory`` (one spec is sent per key, so the API is
            expected to return one sampled series per key, in the same order).
        """
        # Validate input first so bad arguments are reported without needing credentials,
        # matching the other run-level tools.
        if not run_id:
            return {"error": "run_id is required"}
        if not metric_keys:
            return {"error": "metric_keys is required (comma-separated, e.g. 'loss,accuracy')"}

        keys = [k.strip() for k in metric_keys.split(",") if k.strip()]
        if not keys:
            return {"error": "metric_keys must include at least one non-empty key"}

        # Specs travel as GraphQL variables (a list of JSON strings). Interpolating JSON
        # into the query text is not valid GraphQL (object keys would be quoted) and would
        # let a crafted metric name alter the query.
        specs = [json.dumps({"keys": [key], "samples": _HISTORY_SAMPLES}) for key in keys]
        query = """
        query GetRunMetrics(
            $project: String!, $entity: String!, $run: String!, $specs: [JSONString!]!
        ) {
            project(name: $project, entityName: $entity) {
                run(name: $run) {
                    sampledHistory(specs: $specs)
                }
            }
        }
        """
        run, error = _fetch_run(
            credentials,
            query,
            {"project": project, "entity": entity, "run": run_id, "specs": specs},
        )
        if error is not None:
            return error

        return {
            "run_id": run_id,
            "metric_keys": keys,
            "history": run.get("sampledHistory") or [],
        }

    @mcp.tool()
    def wandb_list_artifacts(entity: str, project: str, run_id: str) -> dict:
        """
        List artifacts logged by a specific Weights & Biases run.

        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.

        Returns:
            Dict containing the list of output artifacts for the run.
        """
        if not run_id:
            return {"error": "run_id is required"}

        # NOTE(review): confirm `name` and `type` are fields of the Artifact type; the
        # wandb SDK reads artifactSequence.name and artifactType.name instead.
        query = """
        query ListArtifacts($project: String!, $entity: String!, $run: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $run) {
                    outputArtifacts {
                        edges {
                            node {
                                name
                                type
                                description
                                createdAt
                            }
                        }
                    }
                }
            }
        }
        """
        run, error = _fetch_run(
            credentials, query, {"project": project, "entity": entity, "run": run_id}
        )
        if error is not None:
            return error

        return {
            "run_id": run_id,
            "artifacts": [
                {
                    "name": node.get("name"),
                    "type": node.get("type"),
                    "description": node.get("description") or "",
                    "created_at": node.get("createdAt") or "",
                }
                for node in _nodes(run.get("outputArtifacts"))
            ],
        }

    @mcp.tool()
    def wandb_get_summary(entity: str, project: str, run_id: str) -> dict:
        """
        Get summary metrics for a specific Weights & Biases run.

        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.

        Returns:
            Dict containing the run's final summary metrics, without the
            underscore-prefixed keys.
        """
        if not run_id:
            return {"error": "run_id is required"}

        query = """
        query GetSummary($project: String!, $entity: String!, $run: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $run) {
                    summaryMetrics
                }
            }
        }
        """
        run, error = _fetch_run(
            credentials, query, {"project": project, "entity": entity, "run": run_id}
        )
        if error is not None:
            return error

        # Filter out internal W&B keys
        summary = {
            key: value
            for key, value in _load_json_object(run.get("summaryMetrics")).items()
            if not key.startswith("_")
        }
        return {"run_id": run_id, "summary": summary}
descriptive error dict with help.""" + with patch.dict("os.environ", {}, clear=True): + result = tool_fns["wandb_list_projects"](entity="test-entity") + assert "error" in result + assert "credentials not configured" in result["error"] + assert "help" in result + + # --- wandb_list_projects --- + + def test_wandb_list_projects_success(self, tool_fns: dict[str, Any]) -> None: + """wandb_list_projects returns projects list from GraphQL.""" + gql_data = { + "projects": { + "edges": [ + { + "node": { + "name": "proj-a", + "description": "Desc A", + "createdAt": "2024-01-01", + } + }, + { + "node": { + "name": "proj-b", + "description": "", + "createdAt": "2024-02-01", + } + }, + ] + } + } + with ( + patch.dict("os.environ", ENV), + patch(_PATCH_POST, return_value=_gql_ok(gql_data)), + ): + result = tool_fns["wandb_list_projects"](entity="test-entity") + + assert result["entity"] == "test-entity" + assert len(result["projects"]) == 2 + assert result["projects"][0]["name"] == "proj-a" + + def test_wandb_list_projects_http_401(self, tool_fns: dict[str, Any]) -> None: + """HTTP 401 returns an invalid key error.""" + with ( + patch.dict("os.environ", ENV), + patch(_PATCH_POST, return_value=_mock_resp({}, status_code=401)), + ): + result = tool_fns["wandb_list_projects"](entity="e") + assert result["error"] == "Invalid Weights & Biases API key" + + def test_wandb_list_projects_graphql_error(self, tool_fns: dict[str, Any]) -> None: + """GraphQL error block is surfaced as an error dict.""" + gql_err = {"errors": [{"message": "entity not found"}]} + with ( + patch.dict("os.environ", ENV), + patch(_PATCH_POST, return_value=_mock_resp(gql_err)), + ): + result = tool_fns["wandb_list_projects"](entity="e") + assert "error" in result + assert "entity not found" in result["error"] + + # --- wandb_list_runs --- + + def test_wandb_list_runs_success(self, tool_fns: dict[str, Any]) -> None: + """wandb_list_runs returns runs list.""" + gql_data = { + "project": { + "runs": { + "edges": [ + 
{ + "node": { + "name": "w854ckuu", + "id": "ferengi-directive-1", + "state": "finished", + "createdAt": "2024-01-01", + "config": "{}", + "summaryMetrics": "{}", + } + } + ] + } + } + } + with ( + patch.dict("os.environ", ENV), + patch(_PATCH_POST, return_value=_gql_ok(gql_data)) as mock_post, + ): + result = tool_fns["wandb_list_runs"]( + entity="testentity", + project="testproject", + filters='{"key": "value"}', + per_page=50, + ) + + assert result["project"] == "testproject" + assert len(result["runs"]) == 1 + assert result["runs"][0]["id"] == "w854ckuu" + # Verify filters and per_page were forwarded in GraphQL variables + call_json = mock_post.call_args[1]["json"] + assert call_json["variables"]["perPage"] == 50 + assert call_json["variables"]["filters"] == {"key": "value"} + + def test_wandb_list_runs_invalid_filters_json(self, tool_fns: dict[str, Any]) -> None: + """wandb_list_runs returns error for invalid JSON filters before any HTTP call.""" + with patch.dict("os.environ", ENV): + result = tool_fns["wandb_list_runs"](entity="e", project="p", filters="not-json") + assert "error" in result + assert "valid JSON" in result["error"] + + # --- wandb_get_run --- + + def test_wandb_get_run_success(self, tool_fns: dict[str, Any]) -> None: + """wandb_get_run returns run details.""" + gql_data = { + "project": { + "run": { + "name": "run-123", + "id": "my-run", + "state": "finished", + "createdAt": "2024-01-01", + "config": '{"lr": 0.001}', + "summaryMetrics": '{"accuracy": 0.9}', + "tags": ["v1"], + "notes": "test", + } + } + } + with ( + patch.dict("os.environ", ENV), + patch(_PATCH_POST, return_value=_gql_ok(gql_data)), + ): + result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="run-123") + + assert result["id"] == "run-123" + assert result["config"] == {"lr": 0.001} + assert result["summary"] == {"accuracy": 0.9} + + def test_wandb_get_run_missing_id(self, tool_fns: dict[str, Any]) -> None: + """wandb_get_run with empty run_id returns error before 
HTTP call.""" + result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="") + assert "error" in result + assert result["error"] == "run_id is required" + + def test_wandb_get_run_not_found(self, tool_fns: dict[str, Any]) -> None: + """wandb_get_run returns not-found error when run is null.""" + gql_data = {"project": {"run": None}} + with ( + patch.dict("os.environ", ENV), + patch(_PATCH_POST, return_value=_gql_ok(gql_data)), + ): + result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="nope") + assert "error" in result + assert "not found" in result["error"] + + # --- wandb_get_run_metrics --- + + def test_wandb_get_run_metrics_success(self, tool_fns: dict[str, Any]) -> None: + """wandb_get_run_metrics returns sampled history.""" + gql_data = { + "project": { + "run": { + "sampledHistory": [{"loss": 0.5}, {"loss": 0.3}], + } + } + } + with ( + patch.dict("os.environ", ENV), + patch(_PATCH_POST, return_value=_gql_ok(gql_data)), + ): + result = tool_fns["wandb_get_run_metrics"]( + entity="e", project="p", run_id="r1", metric_keys="loss" + ) + + assert result["run_id"] == "r1" + assert result["metric_keys"] == ["loss"] + assert result["history"] == [{"loss": 0.5}, {"loss": 0.3}] + + def test_wandb_get_run_metrics_missing_id(self, tool_fns: dict[str, Any]) -> None: + """wandb_get_run_metrics with empty run_id returns error.""" + with patch.dict("os.environ", ENV): + result = tool_fns["wandb_get_run_metrics"](entity="e", project="p", run_id="") + assert "error" in result + assert result["error"] == "run_id is required" + + def test_wandb_get_run_metrics_missing_keys(self, tool_fns: dict[str, Any]) -> None: + """wandb_get_run_metrics with no metric_keys returns error.""" + with patch.dict("os.environ", ENV): + result = tool_fns["wandb_get_run_metrics"](entity="e", project="p", run_id="r1") + assert "error" in result + assert "metric_keys is required" in result["error"] + + # --- wandb_list_artifacts --- + + def 
test_wandb_list_artifacts_success(self, tool_fns: dict[str, Any]) -> None: + """wandb_list_artifacts returns artifact list.""" + gql_data = { + "project": { + "run": { + "outputArtifacts": { + "edges": [ + { + "node": { + "name": "model:v0", + "type": "model", + "description": "", + "createdAt": "2024-01-01", + } + } + ] + } + } + } + } + with ( + patch.dict("os.environ", ENV), + patch(_PATCH_POST, return_value=_gql_ok(gql_data)), + ): + result = tool_fns["wandb_list_artifacts"](entity="e", project="p", run_id="r1") + + assert result["run_id"] == "r1" + assert result["artifacts"][0]["name"] == "model:v0" + + def test_wandb_list_artifacts_missing_id(self, tool_fns: dict[str, Any]) -> None: + """wandb_list_artifacts with empty run_id returns error.""" + result = tool_fns["wandb_list_artifacts"](entity="e", project="p", run_id="") + assert "error" in result + assert result["error"] == "run_id is required" + + # --- wandb_get_summary --- + + def test_wandb_get_summary_success(self, tool_fns: dict[str, Any]) -> None: + """wandb_get_summary returns summary filtering out _-prefixed keys.""" + gql_data = { + "project": {"run": {"summaryMetrics": '{"accuracy": 0.9, "loss": 0.1, "_step": 5}'}} + } + with ( + patch.dict("os.environ", ENV), + patch(_PATCH_POST, return_value=_gql_ok(gql_data)), + ): + result = tool_fns["wandb_get_summary"](entity="e", project="p", run_id="r1") + + assert result["run_id"] == "r1" + assert result["summary"]["accuracy"] == 0.9 + assert "_step" not in result["summary"] + + def test_wandb_get_summary_missing_id(self, tool_fns: dict[str, Any]) -> None: + """wandb_get_summary with empty run_id returns error.""" + result = tool_fns["wandb_get_summary"](entity="e", project="p", run_id="") + assert "error" in result + assert result["error"] == "run_id is required" + + # --- Network/timeout errors --- + + def test_timeout_returns_error(self, tool_fns: dict[str, Any]) -> None: + """httpx.TimeoutException is caught and returns a timeout message.""" + with ( 
+ patch.dict("os.environ", ENV), + patch( + "aden_tools.tools.wandb_tool.wandb_tool.httpx.post", + side_effect=httpx.TimeoutException("timeout"), + ), + ): + result = tool_fns["wandb_list_projects"](entity="e") + assert result["error"] == "Request timed out" + + def test_network_error_returns_error(self, tool_fns: dict[str, Any]) -> None: + """httpx.RequestError is caught and returns a network error message.""" + with ( + patch.dict("os.environ", ENV), + patch( + "aden_tools.tools.wandb_tool.wandb_tool.httpx.post", + side_effect=httpx.RequestError("Connection refused"), + ), + ): + result = tool_fns["wandb_list_projects"](entity="e") + assert "Network error" in result["error"]