* feat(tools): add Discord integration (#2913) - discord_list_guilds: list servers the bot is in - discord_list_channels: list channels for a guild - discord_send_message: send message to channel - discord_get_messages: get recent messages Auth: DISCORD_BOT_TOKEN, credential spec, health checker. Uses Discord API v10 (Bot token). Co-authored-by: Cursor <cursoragent@cursor.com> * style: apply ruff format to discord tool files Co-authored-by: Cursor <cursoragent@cursor.com> * feat(discord): add rate limit handling, message validation, channel filter - Rate limit (429): return clear error with retry_after from API - Message length: validate before send, max 2000 chars per Discord limit - Channel filter: text_only param (default True) for list_channels - Add 6 new tests for rate limit, validation, filtering Co-authored-by: Cursor <cursoragent@cursor.com> * feat(discord): add retry on 429 rate limit - Retry up to 2 times using Discord's retry_after - Cap wait at 60s, fallback to exponential backoff if no retry_after - Add _request_with_retry helper for all API calls - Add 3 tests: retry then success, retry exhausted, tool-level retry Co-authored-by: Cursor <cursoragent@cursor.com> * fix(discord): remove unused DISCORD_API_BASE import Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: mishrapravin114 <mishrapravin114@users.noreply.github.com> Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -37,6 +37,7 @@ Credential categories:
|
||||
- search.py: Search tool credentials (brave_search, google_search, etc.)
|
||||
- email.py: Email provider credentials (resend, google/gmail)
|
||||
- apollo.py: Apollo.io API credentials
|
||||
- discord.py: Discord bot credentials
|
||||
- github.py: GitHub API credentials
|
||||
- hubspot.py: HubSpot CRM credentials
|
||||
- slack.py: Slack workspace credentials
|
||||
@@ -57,6 +58,7 @@ from .base import CredentialError, CredentialSpec
|
||||
from .bigquery import BIGQUERY_CREDENTIALS
|
||||
from .browser import get_aden_auth_url, get_aden_setup_url, open_browser
|
||||
from .calcom import CALCOM_CREDENTIALS
|
||||
from .discord import DISCORD_CREDENTIALS
|
||||
from .email import EMAIL_CREDENTIALS
|
||||
from .gcp_vision import GCP_VISION_CREDENTIALS
|
||||
from .github import GITHUB_CREDENTIALS
|
||||
@@ -87,6 +89,7 @@ CREDENTIAL_SPECS = {
|
||||
**EMAIL_CREDENTIALS,
|
||||
**GCP_VISION_CREDENTIALS,
|
||||
**APOLLO_CREDENTIALS,
|
||||
**DISCORD_CREDENTIALS,
|
||||
**GITHUB_CREDENTIALS,
|
||||
**GOOGLE_MAPS_CREDENTIALS,
|
||||
**HUBSPOT_CREDENTIALS,
|
||||
@@ -137,4 +140,5 @@ __all__ = [
|
||||
"TELEGRAM_CREDENTIALS",
|
||||
"BIGQUERY_CREDENTIALS",
|
||||
"CALCOM_CREDENTIALS",
|
||||
"DISCORD_CREDENTIALS",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
Discord tool credentials.
|
||||
|
||||
Contains credentials for Discord bot integration.
|
||||
"""
|
||||
|
||||
from .base import CredentialSpec
|
||||
|
||||
DISCORD_CREDENTIALS = {
|
||||
"discord": CredentialSpec(
|
||||
env_var="DISCORD_BOT_TOKEN",
|
||||
tools=[
|
||||
"discord_list_guilds",
|
||||
"discord_list_channels",
|
||||
"discord_send_message",
|
||||
"discord_get_messages",
|
||||
],
|
||||
required=True,
|
||||
startup_required=False,
|
||||
help_url="https://discord.com/developers/applications",
|
||||
description="Discord Bot Token",
|
||||
aden_supported=True,
|
||||
aden_provider_name="discord",
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""To get a Discord Bot Token:
|
||||
1. Go to https://discord.com/developers/applications
|
||||
2. Create a new application or select an existing one
|
||||
3. Go to the "Bot" section in the sidebar
|
||||
4. Click "Add Bot" if you haven't already
|
||||
5. Copy the token (click "Reset Token" if needed)
|
||||
6. Invite the bot to your server via OAuth2 → URL Generator
|
||||
- Scopes: bot
|
||||
- Permissions: Send Messages, Read Message History, View Channels""",
|
||||
health_check_endpoint="https://discord.com/api/v10/users/@me",
|
||||
health_check_method="GET",
|
||||
credential_id="discord",
|
||||
credential_key="access_token",
|
||||
),
|
||||
}
|
||||
@@ -492,6 +492,63 @@ class GitHubHealthChecker:
|
||||
)
|
||||
|
||||
|
||||
class DiscordHealthChecker:
|
||||
"""Health checker for Discord bot tokens."""
|
||||
|
||||
ENDPOINT = "https://discord.com/api/v10/users/@me"
|
||||
TIMEOUT = 10.0
|
||||
|
||||
def check(self, bot_token: str) -> HealthCheckResult:
|
||||
"""
|
||||
Validate Discord bot token by fetching the bot's user info.
|
||||
"""
|
||||
try:
|
||||
with httpx.Client(timeout=self.TIMEOUT) as client:
|
||||
response = client.get(
|
||||
self.ENDPOINT,
|
||||
headers={"Authorization": f"Bot {bot_token}"},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
username = data.get("username", "unknown")
|
||||
return HealthCheckResult(
|
||||
valid=True,
|
||||
message=f"Discord bot token valid (bot: {username})",
|
||||
details={"username": username, "id": data.get("id")},
|
||||
)
|
||||
elif response.status_code == 401:
|
||||
return HealthCheckResult(
|
||||
valid=False,
|
||||
message="Discord bot token is invalid",
|
||||
details={"status_code": 401},
|
||||
)
|
||||
elif response.status_code == 403:
|
||||
return HealthCheckResult(
|
||||
valid=False,
|
||||
message="Discord bot token lacks required permissions",
|
||||
details={"status_code": 403},
|
||||
)
|
||||
else:
|
||||
return HealthCheckResult(
|
||||
valid=False,
|
||||
message=f"Discord API returned status {response.status_code}",
|
||||
details={"status_code": response.status_code},
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
return HealthCheckResult(
|
||||
valid=False,
|
||||
message="Discord API request timed out",
|
||||
details={"error": "timeout"},
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
return HealthCheckResult(
|
||||
valid=False,
|
||||
message=f"Failed to connect to Discord API: {e}",
|
||||
details={"error": str(e)},
|
||||
)
|
||||
|
||||
|
||||
class ResendHealthChecker:
|
||||
"""Health checker for Resend API credentials."""
|
||||
|
||||
@@ -624,6 +681,7 @@ class GoogleMapsHealthChecker:
|
||||
|
||||
# Registry of health checkers
|
||||
HEALTH_CHECKERS: dict[str, CredentialHealthChecker] = {
|
||||
"discord": DiscordHealthChecker(),
|
||||
"hubspot": HubSpotHealthChecker(),
|
||||
"brave_search": BraveSearchHealthChecker(),
|
||||
"google_calendar_oauth": GoogleCalendarHealthChecker(),
|
||||
|
||||
@@ -26,6 +26,7 @@ from .bigquery_tool import register_tools as register_bigquery
|
||||
from .calcom_tool import register_tools as register_calcom
|
||||
from .calendar_tool import register_tools as register_calendar
|
||||
from .csv_tool import register_tools as register_csv
|
||||
from .discord_tool import register_tools as register_discord
|
||||
|
||||
# Security scanning tools
|
||||
from .dns_security_scanner import register_tools as register_dns_security_scanner
|
||||
@@ -108,6 +109,7 @@ def register_all_tools(
|
||||
register_serpapi(mcp, credentials=credentials)
|
||||
register_calendar(mcp, credentials=credentials)
|
||||
register_calcom(mcp, credentials=credentials)
|
||||
register_discord(mcp, credentials=credentials)
|
||||
register_slack(mcp, credentials=credentials)
|
||||
register_razorpay(mcp, credentials=credentials)
|
||||
register_telegram(mcp, credentials=credentials)
|
||||
@@ -182,6 +184,10 @@ def register_all_tools(
|
||||
"calcom_list_schedules",
|
||||
"calcom_list_event_types",
|
||||
"calcom_get_event_type",
|
||||
"discord_list_guilds",
|
||||
"discord_list_channels",
|
||||
"discord_send_message",
|
||||
"discord_get_messages",
|
||||
"github_list_repos",
|
||||
"github_get_repo",
|
||||
"github_search_repos",
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
# Discord Tool
|
||||
|
||||
Send messages and interact with Discord servers via the Discord API.
|
||||
|
||||
## Supported Actions
|
||||
|
||||
- **discord_list_guilds** – List guilds (servers) the bot is a member of
|
||||
- **discord_list_channels** – List channels for a guild (optional `text_only` filter)
|
||||
- **discord_send_message** – Send a message to a channel (validates 2000-char limit)
|
||||
- **discord_get_messages** – Get recent messages from a channel
|
||||
|
||||
## Limits & Validation
|
||||
|
||||
- **Message length**: Max 2000 characters (validated before sending)
|
||||
- **Rate limits**: Automatically retries up to 2 times on 429 using Discord's `retry_after`; returns clear error when exhausted
|
||||
- **Channel filtering**: `discord_list_channels` defaults to text channels only; use `text_only=False` for all types
|
||||
|
||||
## Setup
|
||||
|
||||
1. Create a Discord application at [Discord Developer Portal](https://discord.com/developers/applications).
|
||||
|
||||
2. Create a bot:
|
||||
- Go to **Bot** section
|
||||
- Add a bot and copy the token
|
||||
|
||||
3. Invite the bot to your server:
|
||||
- Go to **OAuth2** → **URL Generator**
|
||||
- Scopes: `bot`
|
||||
- Bot permissions: `Send Messages`, `Read Message History`, `View Channels`, `Read Messages/View Channels`
|
||||
- Use the generated URL to invite the bot
|
||||
|
||||
4. Set the environment variable:
|
||||
```bash
|
||||
export DISCORD_BOT_TOKEN=your_bot_token_here
|
||||
```
|
||||
|
||||
## Getting IDs
|
||||
|
||||
Enable **Developer Mode** in Discord (User Settings → Advanced → Developer Mode).
|
||||
Then right-click a server or channel to **Copy ID**.
|
||||
|
||||
## Use Case
|
||||
|
||||
Example: "When a production incident is resolved, post a short summary to our #incidents Discord channel."
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Discord tool package for Aden Tools."""
|
||||
|
||||
from .discord_tool import register_tools
|
||||
|
||||
__all__ = ["register_tools"]
|
||||
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Discord Tool - Send messages and interact with Discord servers via Discord API.
|
||||
|
||||
Supports:
|
||||
- Bot tokens (DISCORD_BOT_TOKEN)
|
||||
|
||||
API Reference: https://discord.com/developers/docs
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
MAX_MESSAGE_LENGTH = 2000 # Discord API limit
|
||||
# Channel types: 0 = GUILD_TEXT, 5 = GUILD_ANNOUNCEMENT (both support messages)
|
||||
TEXT_CHANNEL_TYPES = (0, 5)
|
||||
MAX_RETRIES = 2 # 3 total attempts on 429
|
||||
MAX_RETRY_WAIT = 60 # cap wait at 60s
|
||||
|
||||
|
||||
class _DiscordClient:
|
||||
"""Internal client wrapping Discord API calls."""
|
||||
|
||||
def __init__(self, bot_token: str):
|
||||
self._token = bot_token
|
||||
|
||||
@property
|
||||
def _headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bot {self._token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _request_with_retry(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Make HTTP request with retry on 429 rate limit."""
|
||||
request_kwargs = {"headers": self._headers, "timeout": 30.0, **kwargs}
|
||||
for attempt in range(MAX_RETRIES + 1):
|
||||
response = httpx.request(method, url, **request_kwargs)
|
||||
if response.status_code == 429 and attempt < MAX_RETRIES:
|
||||
try:
|
||||
data = response.json()
|
||||
wait = min(float(data.get("retry_after", 1)), MAX_RETRY_WAIT)
|
||||
except Exception:
|
||||
wait = min(2**attempt, MAX_RETRY_WAIT)
|
||||
time.sleep(wait)
|
||||
continue
|
||||
return self._handle_response(response)
|
||||
return self._handle_response(response)
|
||||
|
||||
def _handle_response(self, response: httpx.Response) -> dict[str, Any]:
|
||||
"""Handle Discord API response format."""
|
||||
if response.status_code == 204:
|
||||
return {"success": True}
|
||||
|
||||
if response.status_code == 429:
|
||||
try:
|
||||
data = response.json()
|
||||
retry_after = data.get("retry_after", 60)
|
||||
message = data.get("message", "Rate limit exceeded")
|
||||
except Exception:
|
||||
retry_after = 60
|
||||
message = "Rate limit exceeded"
|
||||
return {
|
||||
"error": f"Discord rate limit exceeded. Retry after {retry_after}s",
|
||||
"retry_after": retry_after,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
data = response.json()
|
||||
message = data.get("message", response.text)
|
||||
except Exception:
|
||||
message = response.text
|
||||
return {"error": f"HTTP {response.status_code}: {message}"}
|
||||
|
||||
return response.json()
|
||||
|
||||
def list_guilds(self) -> dict[str, Any]:
|
||||
"""List guilds (servers) the bot is a member of."""
|
||||
return self._request_with_retry("GET", f"{DISCORD_API_BASE}/users/@me/guilds")
|
||||
|
||||
def list_channels(self, guild_id: str, text_only: bool = True) -> dict[str, Any]:
|
||||
"""List channels for a guild. Optionally filter to text channels only."""
|
||||
result = self._request_with_retry("GET", f"{DISCORD_API_BASE}/guilds/{guild_id}/channels")
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
return result
|
||||
if text_only:
|
||||
result = [c for c in result if c.get("type") in TEXT_CHANNEL_TYPES]
|
||||
return result
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
channel_id: str,
|
||||
content: str,
|
||||
*,
|
||||
tts: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a message to a channel."""
|
||||
body: dict[str, Any] = {"content": content, "tts": tts}
|
||||
return self._request_with_retry(
|
||||
"POST",
|
||||
f"{DISCORD_API_BASE}/channels/{channel_id}/messages",
|
||||
json=body,
|
||||
)
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
channel_id: str,
|
||||
limit: int = 50,
|
||||
before: str | None = None,
|
||||
after: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Get recent messages from a channel."""
|
||||
params: dict[str, Any] = {"limit": min(limit, 100)}
|
||||
if before:
|
||||
params["before"] = before
|
||||
if after:
|
||||
params["after"] = after
|
||||
return self._request_with_retry(
|
||||
"GET",
|
||||
f"{DISCORD_API_BASE}/channels/{channel_id}/messages",
|
||||
params=params,
|
||||
)
|
||||
|
||||
|
||||
def register_tools(
|
||||
mcp: FastMCP,
|
||||
credentials: CredentialStoreAdapter | None = None,
|
||||
) -> None:
|
||||
"""Register Discord tools with the MCP server."""
|
||||
|
||||
def _get_token() -> str | None:
|
||||
"""Get Discord bot token from credential manager or environment."""
|
||||
if credentials is not None:
|
||||
token = credentials.get("discord")
|
||||
if token is not None and not isinstance(token, str):
|
||||
raise TypeError(
|
||||
f"Expected string from credentials.get('discord'), got {type(token).__name__}"
|
||||
)
|
||||
return token
|
||||
return os.getenv("DISCORD_BOT_TOKEN")
|
||||
|
||||
def _get_client() -> _DiscordClient | dict[str, str]:
|
||||
"""Get a Discord client, or return an error dict if no credentials."""
|
||||
token = _get_token()
|
||||
if not token:
|
||||
return {
|
||||
"error": "Discord credentials not configured",
|
||||
"help": (
|
||||
"Set DISCORD_BOT_TOKEN environment variable or configure via credential store"
|
||||
),
|
||||
}
|
||||
return _DiscordClient(token)
|
||||
|
||||
@mcp.tool()
|
||||
def discord_list_guilds() -> dict:
|
||||
"""
|
||||
List Discord guilds (servers) the bot is a member of.
|
||||
|
||||
Returns guild IDs and names. Use guild IDs with discord_list_channels.
|
||||
|
||||
Returns:
|
||||
Dict with list of guilds or error
|
||||
"""
|
||||
client = _get_client()
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
result = client.list_guilds()
|
||||
if "error" in result:
|
||||
return result
|
||||
return {"guilds": result, "success": True}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
@mcp.tool()
|
||||
def discord_list_channels(guild_id: str, text_only: bool = True) -> dict:
|
||||
"""
|
||||
List channels for a Discord guild (server).
|
||||
|
||||
Args:
|
||||
guild_id: Guild (server) ID. Enable Developer Mode in Discord and
|
||||
right-click the server to copy ID. Or use discord_list_guilds.
|
||||
text_only: If True (default), return only text channels (type 0 and 5).
|
||||
Set False to include voice, category, and other channel types.
|
||||
|
||||
Returns:
|
||||
Dict with list of channels or error
|
||||
"""
|
||||
client = _get_client()
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
result = client.list_channels(guild_id, text_only=text_only)
|
||||
if "error" in result:
|
||||
return result
|
||||
return {"channels": result, "success": True}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
@mcp.tool()
|
||||
def discord_send_message(channel_id: str, content: str, tts: bool = False) -> dict:
|
||||
"""
|
||||
Send a message to a Discord channel.
|
||||
|
||||
Args:
|
||||
channel_id: Channel ID (right-click channel > Copy ID in Dev Mode)
|
||||
content: Message text (max 2000 characters)
|
||||
tts: Whether to use text-to-speech
|
||||
|
||||
Returns:
|
||||
Dict with message details or error
|
||||
"""
|
||||
if len(content) > MAX_MESSAGE_LENGTH:
|
||||
return {
|
||||
"error": f"Message exceeds {MAX_MESSAGE_LENGTH} character limit",
|
||||
"max_length": MAX_MESSAGE_LENGTH,
|
||||
"provided": len(content),
|
||||
}
|
||||
client = _get_client()
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
result = client.send_message(channel_id, content, tts=tts)
|
||||
if "error" in result:
|
||||
return result
|
||||
return {"success": True, "message": result}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
@mcp.tool()
|
||||
def discord_get_messages(
|
||||
channel_id: str,
|
||||
limit: int = 50,
|
||||
before: str | None = None,
|
||||
after: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get recent messages from a Discord channel.
|
||||
|
||||
Args:
|
||||
channel_id: Channel ID
|
||||
limit: Max messages to return (1-100, default 50)
|
||||
before: Message ID to get messages before (for pagination)
|
||||
after: Message ID to get messages after (for pagination)
|
||||
|
||||
Returns:
|
||||
Dict with list of messages or error
|
||||
"""
|
||||
client = _get_client()
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
result = client.get_messages(channel_id, limit=limit, before=before, after=after)
|
||||
if "error" in result:
|
||||
return result
|
||||
return {"messages": result, "success": True}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
@@ -7,6 +7,7 @@ import httpx
|
||||
from aden_tools.credentials.health_check import (
|
||||
HEALTH_CHECKERS,
|
||||
AnthropicHealthChecker,
|
||||
DiscordHealthChecker,
|
||||
GitHubHealthChecker,
|
||||
GoogleCalendarHealthChecker,
|
||||
GoogleMapsHealthChecker,
|
||||
@@ -49,6 +50,11 @@ class TestHealthCheckerRegistry:
|
||||
assert "google_calendar_oauth" in HEALTH_CHECKERS
|
||||
assert isinstance(HEALTH_CHECKERS["google_calendar_oauth"], GoogleCalendarHealthChecker)
|
||||
|
||||
def test_discord_registered(self):
|
||||
"""DiscordHealthChecker is registered in HEALTH_CHECKERS."""
|
||||
assert "discord" in HEALTH_CHECKERS
|
||||
assert isinstance(HEALTH_CHECKERS["discord"], DiscordHealthChecker)
|
||||
|
||||
def test_all_expected_checkers_registered(self):
|
||||
"""All expected health checkers are in the registry."""
|
||||
expected = {
|
||||
@@ -61,6 +67,7 @@ class TestHealthCheckerRegistry:
|
||||
"resend",
|
||||
"google_calendar_oauth",
|
||||
"slack",
|
||||
"discord",
|
||||
}
|
||||
assert set(HEALTH_CHECKERS.keys()) == expected
|
||||
|
||||
|
||||
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
Tests for Discord tool.
|
||||
|
||||
Covers:
|
||||
- _DiscordClient methods (list_guilds, list_channels, send_message, get_messages)
|
||||
- Error handling (401, 403, 404, timeout)
|
||||
- Credential retrieval (CredentialStoreAdapter vs env var)
|
||||
- All 4 MCP tool functions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from aden_tools.tools.discord_tool.discord_tool import (
|
||||
MAX_MESSAGE_LENGTH,
|
||||
MAX_RETRIES,
|
||||
_DiscordClient,
|
||||
register_tools,
|
||||
)
|
||||
|
||||
# --- _DiscordClient tests ---
|
||||
|
||||
|
||||
class TestDiscordClient:
|
||||
def setup_method(self):
|
||||
self.client = _DiscordClient("test-bot-token")
|
||||
|
||||
def test_headers(self):
|
||||
headers = self.client._headers
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
assert headers["Authorization"] == "Bot test-bot-token"
|
||||
|
||||
def test_handle_response_success(self):
|
||||
response = MagicMock()
|
||||
response.status_code = 200
|
||||
response.json.return_value = {"id": "123", "username": "test-bot"}
|
||||
assert self.client._handle_response(response) == {"id": "123", "username": "test-bot"}
|
||||
|
||||
def test_handle_response_204(self):
|
||||
response = MagicMock()
|
||||
response.status_code = 204
|
||||
result = self.client._handle_response(response)
|
||||
assert result == {"success": True}
|
||||
|
||||
def test_handle_response_rate_limit_429(self):
|
||||
response = MagicMock()
|
||||
response.status_code = 429
|
||||
response.json.return_value = {"message": "Rate limit", "retry_after": 2.5}
|
||||
response.text = '{"message": "Rate limit", "retry_after": 2.5}'
|
||||
result = self.client._handle_response(response)
|
||||
assert "error" in result
|
||||
assert "rate limit" in result["error"].lower()
|
||||
assert result["retry_after"] == 2.5
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code",
|
||||
[401, 403, 404, 500],
|
||||
)
|
||||
def test_handle_response_errors(self, status_code):
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
response.json.return_value = {"message": "Test error"}
|
||||
response.text = "Test error"
|
||||
result = self.client._handle_response(response)
|
||||
assert "error" in result
|
||||
assert str(status_code) in result["error"]
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_list_guilds(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value=[
|
||||
{"id": "g1", "name": "Test Server"},
|
||||
{"id": "g2", "name": "Another Server"},
|
||||
]
|
||||
),
|
||||
)
|
||||
result = self.client.list_guilds()
|
||||
mock_request.assert_called_once()
|
||||
assert mock_request.call_args[0][0] == "GET"
|
||||
assert "users/@me/guilds" in mock_request.call_args[0][1]
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "Test Server"
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_list_channels_text_only_default(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value=[
|
||||
{"id": "c1", "name": "general", "type": 0},
|
||||
{"id": "c2", "name": "incidents", "type": 0},
|
||||
{"id": "c3", "name": "voice-chat", "type": 2},
|
||||
]
|
||||
),
|
||||
)
|
||||
result = self.client.list_channels("guild-123")
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "general"
|
||||
assert result[1]["name"] == "incidents"
|
||||
assert not any(c["type"] == 2 for c in result)
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_list_channels_all_types(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value=[
|
||||
{"id": "c1", "name": "general", "type": 0},
|
||||
{"id": "c2", "name": "voice-chat", "type": 2},
|
||||
]
|
||||
),
|
||||
)
|
||||
result = self.client.list_channels("guild-123", text_only=False)
|
||||
assert len(result) == 2
|
||||
assert result[0]["type"] == 0
|
||||
assert result[1]["type"] == 2
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_send_message(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value={
|
||||
"id": "m123",
|
||||
"channel_id": "c1",
|
||||
"content": "Hello world",
|
||||
}
|
||||
),
|
||||
)
|
||||
result = self.client.send_message("c1", "Hello world")
|
||||
mock_request.assert_called_once()
|
||||
assert mock_request.call_args[0][0] == "POST"
|
||||
assert "channels/c1/messages" in mock_request.call_args[0][1]
|
||||
assert result["content"] == "Hello world"
|
||||
assert result["channel_id"] == "c1"
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_get_messages(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value=[
|
||||
{"id": "m1", "content": "First"},
|
||||
{"id": "m2", "content": "Second"},
|
||||
]
|
||||
),
|
||||
)
|
||||
result = self.client.get_messages("c1", limit=10)
|
||||
mock_request.assert_called_once()
|
||||
assert mock_request.call_args[1]["params"] == {"limit": 10}
|
||||
assert len(result) == 2
|
||||
assert result[0]["content"] == "First"
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.time.sleep")
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_retry_on_429_then_success(self, mock_request, mock_sleep):
|
||||
mock_request.side_effect = [
|
||||
MagicMock(
|
||||
status_code=429,
|
||||
json=MagicMock(return_value={"retry_after": 0.01}),
|
||||
text="{}",
|
||||
),
|
||||
MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(return_value=[{"id": "g1", "name": "Server"}]),
|
||||
),
|
||||
]
|
||||
result = self.client.list_guilds()
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "Server"
|
||||
assert mock_request.call_count == 2
|
||||
mock_sleep.assert_called_once_with(0.01)
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.time.sleep")
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_retry_exhausted_returns_error(self, mock_request, mock_sleep):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=429,
|
||||
json=MagicMock(return_value={"retry_after": 0.01}),
|
||||
text="{}",
|
||||
)
|
||||
result = self.client.list_guilds()
|
||||
assert "error" in result
|
||||
assert "rate limit" in result["error"].lower()
|
||||
assert mock_request.call_count == MAX_RETRIES + 1
|
||||
|
||||
|
||||
# --- Tool registration tests ---
|
||||
|
||||
|
||||
class TestDiscordListGuildsTool:
|
||||
def setup_method(self):
|
||||
self.mcp = MagicMock()
|
||||
self.fns = []
|
||||
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
cred.get.return_value = "test-token"
|
||||
register_tools(self.mcp, credentials=cred)
|
||||
|
||||
def _fn(self, name):
|
||||
return next(f for f in self.fns if f.__name__ == name)
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_list_guilds_success(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(return_value=[{"id": "g1", "name": "Test Server"}]),
|
||||
)
|
||||
result = self._fn("discord_list_guilds")()
|
||||
assert result["success"] is True
|
||||
assert len(result["guilds"]) == 1
|
||||
assert result["guilds"][0]["name"] == "Test Server"
|
||||
|
||||
def test_list_guilds_no_credentials(self):
|
||||
mcp = MagicMock()
|
||||
fns = []
|
||||
mcp.tool.return_value = lambda fn: fns.append(fn) or fn
|
||||
register_tools(mcp, credentials=None)
|
||||
with patch.dict("os.environ", {"DISCORD_BOT_TOKEN": ""}, clear=False):
|
||||
result = next(f for f in fns if f.__name__ == "discord_list_guilds")()
|
||||
assert "error" in result
|
||||
assert "not configured" in result["error"]
|
||||
|
||||
|
||||
class TestDiscordListChannelsTool:
|
||||
def setup_method(self):
|
||||
self.mcp = MagicMock()
|
||||
self.fns = []
|
||||
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
cred.get.return_value = "test-token"
|
||||
register_tools(self.mcp, credentials=cred)
|
||||
|
||||
def _fn(self, name):
|
||||
return next(f for f in self.fns if f.__name__ == name)
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_list_channels_success(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value=[
|
||||
{"id": "c1", "name": "general", "type": 0},
|
||||
]
|
||||
),
|
||||
)
|
||||
result = self._fn("discord_list_channels")("guild-123")
|
||||
assert result["success"] is True
|
||||
assert len(result["channels"]) == 1
|
||||
assert result["channels"][0]["name"] == "general"
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_list_channels_text_only_filter(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value=[
|
||||
{"id": "c1", "name": "general", "type": 0},
|
||||
{"id": "c2", "name": "voice", "type": 2},
|
||||
]
|
||||
),
|
||||
)
|
||||
result = self._fn("discord_list_channels")("guild-123", text_only=True)
|
||||
assert result["success"] is True
|
||||
assert len(result["channels"]) == 1
|
||||
assert result["channels"][0]["name"] == "general"
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_list_channels_error(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=404,
|
||||
json=MagicMock(return_value={"message": "Unknown Guild"}),
|
||||
text="Unknown Guild",
|
||||
)
|
||||
result = self._fn("discord_list_channels")("bad-guild")
|
||||
assert "error" in result
|
||||
assert "404" in result["error"]
|
||||
|
||||
|
||||
class TestDiscordSendMessageTool:
|
||||
def setup_method(self):
|
||||
self.mcp = MagicMock()
|
||||
self.fns = []
|
||||
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
cred.get.return_value = "test-token"
|
||||
register_tools(self.mcp, credentials=cred)
|
||||
|
||||
def _fn(self, name):
|
||||
return next(f for f in self.fns if f.__name__ == name)
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_send_message_success(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value={
|
||||
"id": "m123",
|
||||
"channel_id": "c1",
|
||||
"content": "Incident resolved",
|
||||
}
|
||||
),
|
||||
)
|
||||
result = self._fn("discord_send_message")("c1", "Incident resolved")
|
||||
assert result["success"] is True
|
||||
assert result["message"]["content"] == "Incident resolved"
|
||||
|
||||
def test_send_message_length_validation(self):
|
||||
long_content = "x" * (MAX_MESSAGE_LENGTH + 1)
|
||||
result = self._fn("discord_send_message")("c1", long_content)
|
||||
assert "error" in result
|
||||
assert str(MAX_MESSAGE_LENGTH) in result["error"]
|
||||
assert result["max_length"] == MAX_MESSAGE_LENGTH
|
||||
assert result["provided"] == MAX_MESSAGE_LENGTH + 1
|
||||
|
||||
def test_send_message_exactly_at_limit(self):
|
||||
content = "x" * MAX_MESSAGE_LENGTH
|
||||
with patch("aden_tools.tools.discord_tool.discord_tool.httpx.request") as mock_request:
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(return_value={"id": "m1", "channel_id": "c1", "content": content}),
|
||||
)
|
||||
result = self._fn("discord_send_message")("c1", content)
|
||||
assert result["success"] is True
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_send_message_rate_limit_429_exhausted(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=429,
|
||||
json=MagicMock(return_value={"message": "Rate limit", "retry_after": 5}),
|
||||
text='{"message": "Rate limit", "retry_after": 5}',
|
||||
)
|
||||
result = self._fn("discord_send_message")("c1", "Hello")
|
||||
assert "error" in result
|
||||
assert "rate limit" in result["error"].lower()
|
||||
assert result.get("retry_after") == 5
|
||||
assert mock_request.call_count == MAX_RETRIES + 1
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_send_message_rate_limit_then_success(self, mock_request):
|
||||
mock_request.side_effect = [
|
||||
MagicMock(
|
||||
status_code=429,
|
||||
json=MagicMock(return_value={"retry_after": 0.01}),
|
||||
text="{}",
|
||||
),
|
||||
MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(return_value={"id": "m1", "channel_id": "c1", "content": "Hi"}),
|
||||
),
|
||||
]
|
||||
result = self._fn("discord_send_message")("c1", "Hi")
|
||||
assert result["success"] is True
|
||||
assert result["message"]["content"] == "Hi"
|
||||
assert mock_request.call_count == 2
|
||||
|
||||
|
||||
class TestDiscordGetMessagesTool:
|
||||
def setup_method(self):
|
||||
self.mcp = MagicMock()
|
||||
self.fns = []
|
||||
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
cred.get.return_value = "test-token"
|
||||
register_tools(self.mcp, credentials=cred)
|
||||
|
||||
def _fn(self, name):
|
||||
return next(f for f in self.fns if f.__name__ == name)
|
||||
|
||||
@patch("aden_tools.tools.discord_tool.discord_tool.httpx.request")
|
||||
def test_get_messages_success(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value=[
|
||||
{"id": "m1", "content": "First message"},
|
||||
]
|
||||
),
|
||||
)
|
||||
result = self._fn("discord_get_messages")("c1", limit=10)
|
||||
assert result["success"] is True
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0]["content"] == "First message"
|
||||
|
||||
|
||||
# --- Credential spec tests ---
|
||||
|
||||
|
||||
class TestCredentialSpec:
|
||||
def test_discord_credential_spec_exists(self):
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
assert "discord" in CREDENTIAL_SPECS
|
||||
|
||||
def test_discord_spec_env_var(self):
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
spec = CREDENTIAL_SPECS["discord"]
|
||||
assert spec.env_var == "DISCORD_BOT_TOKEN"
|
||||
|
||||
def test_discord_spec_tools(self):
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
spec = CREDENTIAL_SPECS["discord"]
|
||||
assert "discord_list_guilds" in spec.tools
|
||||
assert "discord_list_channels" in spec.tools
|
||||
assert "discord_send_message" in spec.tools
|
||||
assert "discord_get_messages" in spec.tools
|
||||
assert len(spec.tools) == 4
|
||||
Reference in New Issue
Block a user