Compare commits
32 Commits
| SHA1 |
|---|
| 8f55170c1e |
| 31a98a5f95 |
| 7667b773f2 |
| 49560260de |
| bb3c69cff1 |
| b15dd2f623 |
| ce308312ae |
| f757c724cc |
| a4c758403e |
| b48465b778 |
| d3baaaab24 |
| ad6077bd7b |
| c2e7afeb5e |
| d87dfca1ab |
| a79d7de482 |
| e5e57302fa |
| c69cf1aea5 |
| 2f4cd8c36f |
| 6f571e6d00 |
| 31bc84106f |
| fd79dceb0f |
| ad50139d67 |
| 12fb40c110 |
| 738e469d96 |
| 80ccbcc827 |
| 08fac31a9d |
| 89ccd66fb9 |
| 7c47e367de |
| b8741bf94c |
| c90dcbb32f |
| 1ccfdbbf7d |
| 4ad0d0e077 |
+16 -33
@@ -74,42 +74,25 @@ Install `make` using:

 ```bash
 sudo apt install make
 ```

-uv: command not found
-
-Install uv using:
-
-curl -LsSf https://astral.sh/uv/install.sh | sh
-source ~/.bashrc
-
-ruff: not found
-
-If linting fails due to a missing ruff command, install it with:
-
-uv tool install ruff
-
-WSL Path Recommendation
-
-When using WSL, it is recommended to clone the repository inside your Linux home directory (e.g., ~/hive) instead of under /mnt/c/... to avoid potential performance and permission issues.
-
----
-
-# ✅ Why This Is Good
-
-- Clear
-- Professional tone
-- No unnecessary explanation
-- Under micro-fix size
-- Based on real contributor experience
-- Won’t annoy maintainers
-
----
-
-Now:
+### `uv: command not found`
+
+Install `uv` using:
+
+```bash
+curl -LsSf https://astral.sh/uv/install.sh | sh
+source ~/.bashrc
+```
+
+### `ruff: not found`
+
+If linting fails due to a missing `ruff` command, install it with:
+
+```bash
+uv tool install ruff
+```
+
+### WSL Path Recommendation
+
+When using WSL, clone the repository inside your Linux home directory (e.g., ~/hive) rather than under /mnt/c/... to avoid performance and permission issues.

 ## Commit Convention
@@ -111,7 +111,7 @@ This sets up:

 - **LLM provider** - Interactive default model configuration
 - All required Python dependencies with `uv`

-- At last, it will initiate the open hive interface in your browser
+- Finally, it will open the Hive interface in your browser

 > **Tip:** To reopen the dashboard later, run `hive open` from the project directory.
@@ -125,18 +125,18 @@ Type the agent you want to build in the home input box

 ### Use Template Agents

-Click "Try a sample agent" and check the templates. You can run a templates directly or choose to build your version on top of the existing template.
+Click "Try a sample agent" and browse the templates. You can run a template directly or build your own version on top of an existing template.

 ### Run Agents

-Now you can run an agent by selectiing the agent (either an existing agent or example agent). You can click the Run button on the top left, or talk to the queen agent and it can run the agent for you.
+Now you can run an agent by selecting it (either an existing agent or an example agent). Click the Run button at the top left, or ask the queen agent to run it for you.

 <img width="2500" height="1214" alt="Image" src="https://github.com/user-attachments/assets/71c38206-2ad5-49aa-bde8-6698d0bc55f5" />

 ## Features

 - **Browser-Use** - Control the browser on your computer to achieve hard tasks
-- **Parallel Execution** - Execute the generated graph in parallel. This way you can have multiple agent compelteing the jobs for you
+- **Parallel Execution** - Execute the generated graph in parallel, so multiple agents can complete jobs for you
 - **[Goal-Driven Generation](docs/key_concepts/goals_outcome.md)** - Define objectives in natural language; the coding agent generates the agent graph and connection code to achieve them
 - **[Adaptiveness](docs/key_concepts/evolution.md)** - Framework captures failures, calibrates according to the objectives, and evolves the agent graph
 - **[Dynamic Node Connections](docs/key_concepts/graph.md)** - No predefined edges; connection code is generated by any capable LLM based on your goals
+2 -2
@@ -39,8 +39,8 @@ We consider security research conducted in accordance with this policy to be:

 ## Security Best Practices for Users

 1. **Keep Updated**: Always run the latest version
-2. **Secure Configuration**: Review `config.yaml` settings, especially in production
-3. **Environment Variables**: Never commit `.env` files or `config.yaml` with secrets
+2. **Secure Configuration**: Review your `~/.hive/configuration.json`, `.mcp.json`, and environment variable settings, especially in production
+3. **Environment Variables**: Never commit `.env` files or any configuration files that contain secrets
 4. **Network Security**: Use HTTPS in production, configure firewalls appropriately
 5. **Database Security**: Use strong passwords, limit network access
@@ -601,7 +601,7 @@ async def handle_ws(websocket):
     )
     node = EventLoopNode(
         event_bus=bus,
-        config=LoopConfig(max_iterations=10_000, max_history_tokens=32_000),
+        config=LoopConfig(max_iterations=10_000, max_context_tokens=32_000),
        conversation_store=STORE,
         tool_executor=tool_executor,
     )
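This hunk and many below make the same mechanical rename, `max_history_tokens` → `max_context_tokens`. A minimal compatibility sketch for dict-shaped loop configs — the helper is hypothetical, not part of this changeset:

```python
# Hypothetical helper: accept old-style configs that still say
# "max_history_tokens" and map them onto the new "max_context_tokens" key.
def normalize_loop_config(lc: dict) -> dict:
    lc = dict(lc)  # copy so the caller's dict is not mutated
    if "max_history_tokens" in lc and "max_context_tokens" not in lc:
        lc["max_context_tokens"] = lc.pop("max_history_tokens")
    return lc

print(normalize_loop_config({"max_iterations": 30, "max_history_tokens": 64000}))
# -> {'max_iterations': 30, 'max_context_tokens': 64000}
```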
@@ -1769,7 +1769,7 @@ async def _run_pipeline(websocket, initial_message: str):
         config=LoopConfig(
             max_iterations=30,
             max_tool_calls_per_turn=30,
-            max_history_tokens=64000,
+            max_context_tokens=64000,
             max_tool_result_chars=8_000,
             spillover_dir=str(_DATA_DIR),
         ),
@@ -752,7 +752,7 @@ async def _run_pipeline(websocket, topic: str):
         config=LoopConfig(
             max_iterations=20,
             max_tool_calls_per_turn=30,
-            max_history_tokens=32_000,
+            max_context_tokens=32_000,
         ),
         conversation_store=store_a,
         tool_executor=tool_executor,
@@ -850,7 +850,7 @@ async def _run_pipeline(websocket, topic: str):
         config=LoopConfig(
             max_iterations=10,
             max_tool_calls_per_turn=30,
-            max_history_tokens=32_000,
+            max_context_tokens=32_000,
         ),
         conversation_store=store_b,
     )
@@ -1258,7 +1258,7 @@ async def _run_org_pipeline(websocket, topic: str):
         config=LoopConfig(
             max_iterations=30,
             max_tool_calls_per_turn=30,
-            max_history_tokens=32_000,
+            max_context_tokens=32_000,
         ),
         conversation_store=store,
         tool_executor=executor,
@@ -10,13 +10,14 @@ from .agent import CredentialTesterAgent


 def setup_logging(verbose=False, debug=False):
+    from framework.observability import configure_logging

     if debug:
-        level, fmt = logging.DEBUG, "%(asctime)s %(name)s: %(message)s"
+        configure_logging(level="DEBUG")
     elif verbose:
-        level, fmt = logging.INFO, "%(message)s"
+        configure_logging(level="INFO")
     else:
-        level, fmt = logging.WARNING, "%(levelname)s: %(message)s"
-    logging.basicConfig(level=level, format=fmt, stream=sys.stderr)
+        configure_logging(level="WARNING")


 def pick_account(agent: CredentialTesterAgent) -> dict | None:
@@ -19,6 +19,7 @@ from __future__ import annotations
 from pathlib import Path
 from typing import TYPE_CHECKING

+from framework.config import get_max_context_tokens
 from framework.graph import Goal, NodeSpec, SuccessCriterion
 from framework.graph.checkpoint_config import CheckpointConfig
 from framework.graph.edge import GraphSpec
@@ -455,7 +456,6 @@ identity_prompt = (
 loop_config = {
     "max_iterations": 50,
     "max_tool_calls_per_turn": 30,
-    "max_history_tokens": 32000,
 }

 # ---------------------------------------------------------------------------
@@ -541,7 +541,7 @@ class CredentialTesterAgent:
             loop_config={
                 "max_iterations": 50,
                 "max_tool_calls_per_turn": 30,
-                "max_history_tokens": 32000,
+                "max_context_tokens": get_max_context_tokens(),
             },
             conversation_mode="continuous",
             identity_prompt=(
@@ -79,7 +79,7 @@ def _extract_agent_stats(agent_path: Path) -> tuple[int, int, list[str]]:
     if agent_json.exists():
         try:
             data = json.loads(agent_json.read_text(encoding="utf-8"))
-            json_nodes = data.get("nodes", [])
+            json_nodes = data.get("graph", {}).get("nodes", []) or data.get("nodes", [])
             if node_count == 0:
                 node_count = len(json_nodes)
             tools: set[str] = set()
@@ -35,6 +35,5 @@ queen_graph = GraphSpec(
     loop_config={
         "max_iterations": 999_999,
         "max_tool_calls_per_turn": 30,
-        "max_history_tokens": 32000,
     },
 )
@@ -185,18 +185,21 @@ docs. Always run list_agent_tools() to see what actually exists.

 # Tool Discovery (MANDATORY before designing)

-Before designing any agent, run list_agent_tools() with NO arguments \
-to see ALL available tools (names + descriptions, grouped by category). \
-ONLY use tools from this list in your node definitions. \
+Before designing any agent, discover tools progressively — start compact, drill into \
+what you need. ONLY use tools from this list in your node definitions. \
 NEVER guess or fabricate tool names from memory.

-list_agent_tools()  # ALWAYS call this first (simple mode)
-list_agent_tools(group="google", output_schema="full")  # drill into a provider
+list_agent_tools()  # Step 1: provider summary (counts + credential status)
+list_agent_tools(group="google", output_schema="summary")  # Step 2: service breakdown within a provider
+list_agent_tools(group="google", service="gmail")  # Step 3: tool names for one service
+list_agent_tools(group="google", service="gmail", output_schema="full")  # Step 4: full detail for specific tools

-NEVER skip the first call. Always start with the full list \
-so you know what providers and tools exist before drilling in. \
-Simple mode truncates long descriptions — use group + "full" to \
-get the complete description and input_schema for the tools you need.
+Step 1 is MANDATORY. Returns provider names, tool counts, credential availability — very compact. \
+Step 2 breaks a provider into services (e.g. google → gmail/calendar/sheets/drive). Only do this \
+for providers that are relevant to the task. \
+Step 3 gets tool names for a specific service — no descriptions, minimal tokens. \
+Step 4 only for services you plan to actually use. \
+Use credentials="available" at any step to filter to tools whose credentials are already configured.

 # Discovery & Design Workflow
@@ -410,11 +413,10 @@ hashline=True for anchors in results
 - undo_changes(path?) — restore from git snapshot

 ## Meta-Agent
-- list_agent_tools(server_config_path?, output_schema?, group?) — discover \
-  available tools grouped by category. output_schema: "simple" (default, \
-  descriptions truncated to ~200 chars) or "full" (complete descriptions + \
-  input_schema). group: "all" (default) or a provider like "google". \
-  Call FIRST before designing.
+- list_agent_tools(group?, service?, output_schema?, credentials?) — discover tools \
+  progressively: no args=provider summary; group+output_schema="summary"=service breakdown; \
+  group+service=tool names; group+service+output_schema="full"=full details. \
+  credentials="available" filters to configured tools. Call FIRST before designing.
 - validate_agent_package(agent_name) — run ALL validation checks in one call \
   (class validation, runner load, tool validation, tests). Call after building.
 - list_agents() — list all agent packages in exports/ with session counts
@@ -551,8 +553,8 @@ but no write/edit tools.
 - run_command(command, cwd?, timeout?) — Read-only commands only (grep, ls, git log). \
   Never use this to write files, run scripts, or modify the filesystem — transition \
   to BUILDING phase for that.
-- list_agent_tools(server_config_path?, output_schema?, group?) \
-  — Discover available tools for design
+- list_agent_tools(server_config_path?, output_schema?, group?, credentials?) \
+  — Discover available tools for design (summary → names → full)
 - list_agents() — See existing agent packages for reference
 - list_agent_sessions(agent_name, status?, limit?) — Inspect past runs of an agent
 - list_agent_checkpoints(agent_name, session_id) — View execution history
@@ -180,7 +180,7 @@ terminal_nodes = []  # Forever-alive
 # Module-level vars read by AgentRunner.load()
 conversation_mode = "continuous"
 identity_prompt = "You are a helpful agent."
-loop_config = {"max_iterations": 100, "max_tool_calls_per_turn": 20, "max_history_tokens": 32000}
+loop_config = {"max_iterations": 100, "max_tool_calls_per_turn": 20, "max_context_tokens": 32000}


 class MyAgent:
@@ -226,7 +226,7 @@ Only three valid keys:
 loop_config = {
     "max_iterations": 100,          # Max LLM turns per node visit
     "max_tool_calls_per_turn": 20,  # Max tool calls per LLM response
-    "max_history_tokens": 32000,    # Triggers conversation compaction
+    "max_context_tokens": 32000,    # Triggers conversation compaction
 }
 ```
 **INVALID keys** (do NOT use): `"strategy"`, `"mode"`, `"timeout"`,
@@ -56,6 +56,14 @@ def get_max_tokens() -> int:
     return get_hive_config().get("llm", {}).get("max_tokens", DEFAULT_MAX_TOKENS)


+DEFAULT_MAX_CONTEXT_TOKENS = 32_000
+
+
+def get_max_context_tokens() -> int:
+    """Return the configured max_context_tokens, falling back to DEFAULT_MAX_CONTEXT_TOKENS."""
+    return get_hive_config().get("llm", {}).get("max_context_tokens", DEFAULT_MAX_CONTEXT_TOKENS)
+
+
 def get_api_key() -> str | None:
     """Return the API key, supporting env var, Claude Code subscription, Codex, and ZAI Code.
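For reference, a standalone sketch of the lookup `get_max_context_tokens()` performs. It assumes config lives at `~/.hive/configuration.json` (the path named in the security notes above); the real code goes through `get_hive_config()`:

```python
import json
from pathlib import Path

DEFAULT_MAX_CONTEXT_TOKENS = 32_000

def read_max_context_tokens(
    path: Path = Path.home() / ".hive" / "configuration.json",  # assumed location
) -> int:
    # Fall back to the hardcoded default when the file is absent or malformed.
    try:
        config = json.loads(path.read_text(encoding="utf-8"))
    except (FileNotFoundError, json.JSONDecodeError):
        return DEFAULT_MAX_CONTEXT_TOKENS
    return config.get("llm", {}).get("max_context_tokens", DEFAULT_MAX_CONTEXT_TOKENS)
```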
@@ -90,6 +98,17 @@ def get_api_key() -> str | None:
         except ImportError:
             pass

+    # Kimi Code subscription: read API key from ~/.kimi/config.toml
+    if llm.get("use_kimi_code_subscription"):
+        try:
+            from framework.runner.runner import get_kimi_code_token
+
+            token = get_kimi_code_token()
+            if token:
+                return token
+        except ImportError:
+            pass
+
     # Standard env-var path (covers ZAI Code and all API-key providers)
     api_key_env_var = llm.get("api_key_env_var")
     if api_key_env_var:
@@ -108,6 +127,9 @@ def get_api_base() -> str | None:
     if llm.get("use_codex_subscription"):
         # Codex subscription routes through the ChatGPT backend, not api.openai.com.
         return "https://chatgpt.com/backend-api/codex"
+    if llm.get("use_kimi_code_subscription"):
+        # Kimi Code uses an Anthropic-compatible endpoint (no /v1 suffix).
+        return "https://api.kimi.com/coding"
     return llm.get("api_base")
@@ -164,6 +186,7 @@ class RuntimeConfig:
     model: str = field(default_factory=get_preferred_model)
     temperature: float = 0.7
     max_tokens: int = field(default_factory=get_max_tokens)
+    max_context_tokens: int = field(default_factory=get_max_context_tokens)
     api_key: str | None = field(default_factory=get_api_key)
     api_base: str | None = field(default_factory=get_api_base)
     extra_kwargs: dict[str, Any] = field(default_factory=get_llm_extra_kwargs)
@@ -149,8 +149,14 @@ def delete_aden_api_key() -> None:
         storage = EncryptedFileStorage()
         storage.delete(ADEN_CREDENTIAL_ID)
+    except (FileNotFoundError, PermissionError) as e:
+        logger.debug("Could not delete %s from encrypted store: %s", ADEN_CREDENTIAL_ID, e)
     except Exception:
-        logger.debug("Could not delete %s from encrypted store", ADEN_CREDENTIAL_ID)
+        logger.warning(
+            "Unexpected error deleting %s from encrypted store",
+            ADEN_CREDENTIAL_ID,
+            exc_info=True,
+        )

     os.environ.pop(ADEN_ENV_VAR, None)
@@ -167,8 +173,10 @@ def _read_credential_key_file() -> str | None:
         value = CREDENTIAL_KEY_PATH.read_text(encoding="utf-8").strip()
         if value:
             return value
+    except (FileNotFoundError, PermissionError) as e:
+        logger.debug("Could not read %s: %s", CREDENTIAL_KEY_PATH, e)
     except Exception:
-        logger.debug("Could not read %s", CREDENTIAL_KEY_PATH)
+        logger.warning("Unexpected error reading %s", CREDENTIAL_KEY_PATH, exc_info=True)
     return None
@@ -196,6 +204,12 @@ def _read_aden_from_encrypted_store() -> str | None:
         cred = storage.load(ADEN_CREDENTIAL_ID)
         if cred:
             return cred.get_key("api_key")
+    except (FileNotFoundError, PermissionError, KeyError) as e:
+        logger.debug("Could not load %s from encrypted store: %s", ADEN_CREDENTIAL_ID, e)
     except Exception:
-        logger.debug("Could not load %s from encrypted store", ADEN_CREDENTIAL_ID)
+        logger.warning(
+            "Unexpected error loading %s from encrypted store",
+            ADEN_CREDENTIAL_ID,
+            exc_info=True,
+        )
     return None
@@ -307,13 +307,13 @@ class NodeConversation:
     def __init__(
         self,
         system_prompt: str = "",
-        max_history_tokens: int = 32000,
+        max_context_tokens: int = 32000,
         compaction_threshold: float = 0.8,
         output_keys: list[str] | None = None,
         store: ConversationStore | None = None,
     ) -> None:
         self._system_prompt = system_prompt
-        self._max_history_tokens = max_history_tokens
+        self._max_context_tokens = max_context_tokens
         self._compaction_threshold = compaction_threshold
         self._output_keys = output_keys
         self._store = store
@@ -525,16 +525,16 @@ class NodeConversation:
         self._last_api_input_tokens = actual_input_tokens

     def usage_ratio(self) -> float:
-        """Current token usage as a fraction of *max_history_tokens*.
+        """Current token usage as a fraction of *max_context_tokens*.

-        Returns 0.0 when ``max_history_tokens`` is zero (unlimited).
+        Returns 0.0 when ``max_context_tokens`` is zero (unlimited).
         """
-        if self._max_history_tokens <= 0:
+        if self._max_context_tokens <= 0:
             return 0.0
-        return self.estimate_tokens() / self._max_history_tokens
+        return self.estimate_tokens() / self._max_context_tokens

     def needs_compaction(self) -> bool:
-        return self.estimate_tokens() >= self._max_history_tokens * self._compaction_threshold
+        return self.estimate_tokens() >= self._max_context_tokens * self._compaction_threshold

     # --- Output-key extraction ---------------------------------------------
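The renamed budget drives the same arithmetic as before. A self-contained sketch of the ratio and compaction checks mirroring the logic above (the values match the tests at the end of this diff):

```python
# With max_context_tokens=1000 and compaction_threshold=0.8, compaction
# triggers once the estimate reaches 800 tokens; a budget of 0 means unlimited.
def usage_ratio(estimated_tokens: int, max_context_tokens: int) -> float:
    if max_context_tokens <= 0:
        return 0.0
    return estimated_tokens / max_context_tokens

def needs_compaction(estimated_tokens: int, max_context_tokens: int,
                     compaction_threshold: float = 0.8) -> bool:
    return estimated_tokens >= max_context_tokens * compaction_threshold

assert usage_ratio(100, 1000) == 0.1
assert needs_compaction(800, 1000) is True
assert needs_compaction(799, 1000) is False
```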
@@ -1029,7 +1029,7 @@ class NodeConversation:
         await self._store.write_meta(
             {
                 "system_prompt": self._system_prompt,
-                "max_history_tokens": self._max_history_tokens,
+                "max_context_tokens": self._max_context_tokens,
                 "compaction_threshold": self._compaction_threshold,
                 "output_keys": self._output_keys,
             }
@@ -1062,7 +1062,7 @@ class NodeConversation:

         conv = cls(
             system_prompt=meta.get("system_prompt", ""),
-            max_history_tokens=meta.get("max_history_tokens", 32000),
+            max_context_tokens=meta.get("max_context_tokens", 32000),
             compaction_threshold=meta.get("compaction_threshold", 0.8),
             output_keys=meta.get("output_keys"),
             store=store,
@@ -37,7 +37,7 @@ async def evaluate_phase_completion(
     phase_description: str,
     success_criteria: str,
     accumulator_state: dict[str, Any],
-    max_history_tokens: int = 8_196,
+    max_context_tokens: int = 8_196,
 ) -> PhaseVerdict:
     """Level 2 judge: read the conversation and evaluate quality.
@@ -50,7 +50,7 @@ async def evaluate_phase_completion(
         phase_description: Description of the phase
         success_criteria: Natural-language criteria for phase completion
         accumulator_state: Current output key values
-        max_history_tokens: Main conversation token budget (judge gets 20%)
+        max_context_tokens: Main conversation token budget (judge gets 20%)

     Returns:
         PhaseVerdict with action and optional feedback
@@ -89,7 +89,7 @@ FEEDBACK: (reason if RETRY, empty if ACCEPT)"""
     response = await llm.acomplete(
         messages=[{"role": "user", "content": user_prompt}],
         system=system_prompt,
-        max_tokens=max(1024, max_history_tokens // 5),
+        max_tokens=max(1024, max_context_tokens // 5),
         max_retries=1,
     )
     if not response.content or not response.content.strip():
@@ -73,6 +73,7 @@ class _EscalationReceiver:
     def __init__(self) -> None:
         self._event = asyncio.Event()
         self._response: str | None = None
+        self._awaiting_input = True  # So inject_worker_message() can prefer us

     async def inject_event(self, content: str, *, is_client_input: bool = False) -> None:
         """Called by ExecutionStream.inject_input() when the user responds."""
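The receiver pairs an `asyncio.Event` with a stored response. A runnable sketch of that wait/inject handshake — `wait()` is assumed from context, since only `__init__` and `inject_event` appear in the hunk:

```python
import asyncio

class Receiver:
    """Sketch: one task blocks on wait(), another delivers the reply."""

    def __init__(self) -> None:
        self._event = asyncio.Event()
        self._response: str | None = None
        self._awaiting_input = True  # lets a router prefer this receiver

    async def inject_event(self, content: str) -> None:
        self._response = content
        self._event.set()  # wake the waiter

    async def wait(self) -> str | None:
        await self._event.wait()
        return self._response

async def main() -> None:
    r = Receiver()
    waiter = asyncio.create_task(r.wait())
    await r.inject_event("queen: retry with a smaller batch")
    print(await waiter)

asyncio.run(main())
```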
@@ -169,7 +170,7 @@ class LoopConfig:
     judge_every_n_turns: int = 1
     stall_detection_threshold: int = 3
     stall_similarity_threshold: float = 0.85
-    max_history_tokens: int = 32_000
+    max_context_tokens: int = 32_000
     store_prefix: str = ""

     # Overflow margin for max_tool_calls_per_turn. Tool calls are only
@@ -511,7 +512,7 @@ class EventLoopNode(NodeProtocol):

         conversation = NodeConversation(
             system_prompt=system_prompt,
-            max_history_tokens=self._config.max_history_tokens,
+            max_context_tokens=self._config.max_context_tokens,
             output_keys=ctx.node_spec.output_keys or None,
             store=self._conversation_store,
         )
@@ -711,6 +712,7 @@ class EventLoopNode(NodeProtocol):
                 model=turn_tokens.get("model", ""),
                 input_tokens=turn_tokens.get("input", 0),
                 output_tokens=turn_tokens.get("output", 0),
+                cached_tokens=turn_tokens.get("cached", 0),
                 execution_id=execution_id,
                 iteration=iteration,
             )
@@ -1832,7 +1834,7 @@ class EventLoopNode(NodeProtocol):
         stream_id = ctx.stream_id or ctx.node_id
         node_id = ctx.node_id
         execution_id = ctx.execution_id or ""
-        token_counts: dict[str, int] = {"input": 0, "output": 0}
+        token_counts: dict[str, int] = {"input": 0, "output": 0, "cached": 0}
         tool_call_count = 0
         final_text = ""
         final_system_prompt = conversation.system_prompt
@@ -1913,6 +1915,7 @@ class EventLoopNode(NodeProtocol):
             elif isinstance(event, FinishEvent):
                 token_counts["input"] += event.input_tokens
                 token_counts["output"] += event.output_tokens
+                token_counts["cached"] += event.cached_tokens
                 token_counts["stop_reason"] = event.stop_reason
                 token_counts["model"] = event.model
@@ -2456,7 +2459,7 @@ class EventLoopNode(NodeProtocol):
             # next turn. The char-based token estimator underestimates
             # actual API tokens, so the standard compaction check in the
             # outer loop may not trigger in time.
-            protect = max(2000, self._config.max_history_tokens // 12)
+            protect = max(2000, self._config.max_context_tokens // 12)
             pruned = await conversation.prune_old_tool_results(
                 protect_tokens=protect,
                 min_prune_tokens=max(1000, protect // 3),
@@ -2465,7 +2468,7 @@ class EventLoopNode(NodeProtocol):
                 logger.info(
                     "Post-limit pruning: cleared %d old tool results (budget: %d)",
                     pruned,
-                    self._config.max_history_tokens,
+                    self._config.max_context_tokens,
                 )
             # Limit hit — return from this turn so the judge can
             # evaluate instead of looping back for another stream.
@@ -2486,7 +2489,7 @@ class EventLoopNode(NodeProtocol):

             # --- Mid-turn pruning: prevent context blowup within a single turn ---
             if conversation.usage_ratio() >= 0.6:
-                protect = max(2000, self._config.max_history_tokens // 12)
+                protect = max(2000, self._config.max_context_tokens // 12)
                 pruned = await conversation.prune_old_tool_results(
                     protect_tokens=protect,
                     min_prune_tokens=max(1000, protect // 3),
@@ -2913,7 +2916,7 @@ class EventLoopNode(NodeProtocol):
                 phase_description=ctx.node_spec.description,
                 success_criteria=ctx.node_spec.success_criteria,
                 accumulator_state=accumulator.to_dict(),
-                max_history_tokens=self._config.max_history_tokens,
+                max_context_tokens=self._config.max_context_tokens,
             )
             if verdict.action != "ACCEPT":
                 return JudgeVerdict(
@@ -3353,7 +3356,7 @@ class EventLoopNode(NodeProtocol):
         phase_grad = getattr(ctx, "continuous_mode", False)

         # --- Step 1: Prune old tool results (free, no LLM) ---
-        protect = max(2000, self._config.max_history_tokens // 12)
+        protect = max(2000, self._config.max_context_tokens // 12)
         pruned = await conversation.prune_old_tool_results(
             protect_tokens=protect,
             min_prune_tokens=max(1000, protect // 3),
@@ -3459,7 +3462,7 @@ class EventLoopNode(NodeProtocol):
             accumulator,
             formatted,
         )
-        summary_budget = max(1024, self._config.max_history_tokens // 2)
+        summary_budget = max(1024, self._config.max_context_tokens // 2)
         try:
             response = await ctx.llm.acomplete(
                 messages=[{"role": "user", "content": prompt}],
@@ -3562,7 +3565,7 @@ class EventLoopNode(NodeProtocol):
         elif spec.output_keys:
             ctx_lines.append(f"OUTPUTS STILL NEEDED: {', '.join(spec.output_keys)}")

-        target_tokens = self._config.max_history_tokens // 2
+        target_tokens = self._config.max_context_tokens // 2
         target_chars = target_tokens * 4
         node_ctx = "\n".join(ctx_lines)
@@ -4030,6 +4033,7 @@ class EventLoopNode(NodeProtocol):
         model: str,
         input_tokens: int,
         output_tokens: int,
+        cached_tokens: int = 0,
         execution_id: str = "",
         iteration: int | None = None,
     ) -> None:
@@ -4041,6 +4045,7 @@ class EventLoopNode(NodeProtocol):
             model=model,
             input_tokens=input_tokens,
             output_tokens=output_tokens,
+            cached_tokens=cached_tokens,
             execution_id=execution_id,
             iteration=iteration,
         )
@@ -4323,22 +4328,18 @@ class EventLoopNode(NodeProtocol):

         registry[escalation_id] = receiver
         try:
-            # Stream message to user (parent's node_id so TUI shows parent talking)
-            await self._event_bus.emit_client_output_delta(
-                stream_id=ctx.node_id,
-                node_id=ctx.node_id,
-                content=message,
-                snapshot=message,
-                execution_id=ctx.execution_id,
-            )
-            # Request input (escalation_id for routing response back)
-            await self._event_bus.emit_client_input_requested(
-                stream_id=ctx.node_id,
+            # Escalate to the queen instead of asking the user directly.
+            # The queen handles the request and injects the response via
+            # inject_worker_message(), which finds this receiver through
+            # its _awaiting_input flag.
+            await self._event_bus.emit_escalation_requested(
+                stream_id=ctx.stream_id or ctx.node_id,
                 node_id=escalation_id,
-                prompt=message,
+                reason=f"Subagent report (wait_for_response) from {agent_id}",
+                context=message,
                 execution_id=ctx.execution_id,
             )
-            # Block until user responds
+            # Block until queen responds
             return await receiver.wait()
         finally:
             registry.pop(escalation_id, None)
@@ -4445,7 +4446,7 @@ class EventLoopNode(NodeProtocol):
             max_iterations=max_iter,  # Tighter budget
             max_tool_calls_per_turn=self._config.max_tool_calls_per_turn,
             tool_call_overflow_margin=self._config.tool_call_overflow_margin,
-            max_history_tokens=self._config.max_history_tokens,
+            max_context_tokens=self._config.max_context_tokens,
             stall_detection_threshold=self._config.stall_detection_threshold,
             max_tool_result_chars=self._config.max_tool_result_chars,
             spillover_dir=subagent_spillover,
@@ -330,7 +330,7 @@ class GraphExecutor:
                 _depth,
             )
         else:
-            max_tokens = getattr(conversation, "_max_history_tokens", 32000)
+            max_tokens = getattr(conversation, "_max_context_tokens", 32000)
             target_tokens = max_tokens // 2
             target_chars = target_tokens * 4
@@ -1872,7 +1872,7 @@ class GraphExecutor:
             max_tool_calls_per_turn=lc.get("max_tool_calls_per_turn", 30),
             tool_call_overflow_margin=lc.get("tool_call_overflow_margin", 0.5),
             stall_detection_threshold=lc.get("stall_detection_threshold", 3),
-            max_history_tokens=lc.get("max_history_tokens", 32000),
+            max_context_tokens=lc.get("max_context_tokens", 32000),
             max_tool_result_chars=lc.get("max_tool_result_chars", 30_000),
             spillover_dir=spillover,
             hooks=lc.get("hooks", {}),
@@ -119,6 +119,19 @@ RATE_LIMIT_BACKOFF_BASE = 2  # seconds
 RATE_LIMIT_MAX_DELAY = 120  # seconds - cap to prevent absurd waits
 MINIMAX_API_BASE = "https://api.minimax.io/v1"

+# Providers that accept cache_control on message content blocks.
+# Anthropic: native ephemeral caching. MiniMax & Z-AI/GLM: pass-through to their APIs.
+# (OpenAI caches automatically server-side; Groq/Gemini/etc. strip the header.)
+_CACHE_CONTROL_PREFIXES = ("anthropic/", "claude-", "minimax/", "minimax-", "MiniMax-", "zai-glm", "glm-")
+
+
+def _model_supports_cache_control(model: str) -> bool:
+    return any(model.startswith(p) for p in _CACHE_CONTROL_PREFIXES)
+
+
+# Kimi For Coding uses an Anthropic-compatible endpoint (no /v1 suffix).
+# Claude Code integration uses this format; the /v1 OpenAI-compatible endpoint
+# enforces a coding-agent whitelist that blocks unknown User-Agents.
+KIMI_API_BASE = "https://api.kimi.com/coding"
+
 # Empty-stream retries use a short fixed delay, not the rate-limit backoff.
 # Conversation-structure issues are deterministic — long waits don't help.
 EMPTY_STREAM_MAX_RETRIES = 3
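A sketch of how the new prefix check gates the header. `_model_supports_cache_control` is copied from the hunk above; the wrapper function and the model name are illustrative only:

```python
_CACHE_CONTROL_PREFIXES = ("anthropic/", "claude-", "minimax/", "minimax-", "MiniMax-", "zai-glm", "glm-")

def _model_supports_cache_control(model: str) -> bool:
    return any(model.startswith(p) for p in _CACHE_CONTROL_PREFIXES)

def build_system_message(model: str, system: str) -> dict:
    # Mirrors the acomplete/astream hunks below: attach the ephemeral
    # cache_control marker only for providers known to accept it.
    sys_msg: dict = {"role": "system", "content": system}
    if _model_supports_cache_control(model):
        sys_msg["cache_control"] = {"type": "ephemeral"}
    return sys_msg

print(build_system_message("anthropic/claude-sonnet-4", "You are helpful."))
```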
@@ -323,9 +336,21 @@ class LiteLLMProvider(LLMProvider):
             api_base: Custom API base URL (for proxies or local deployments)
             **kwargs: Additional arguments passed to litellm.completion()
         """
+        # Kimi For Coding exposes an Anthropic-compatible endpoint at
+        # https://api.kimi.com/coding (the same format Claude Code uses natively).
+        # Translate kimi/ prefix to anthropic/ so litellm uses the Anthropic
+        # Messages API handler and routes to that endpoint — no special headers needed.
+        _original_model = model
+        if model.lower().startswith("kimi/"):
+            model = "anthropic/" + model[len("kimi/") :]
+            # Normalise api_base: litellm's Anthropic handler appends /v1/messages,
+            # so the base must be https://api.kimi.com/coding (no /v1 suffix).
+            # Strip a trailing /v1 in case the user's saved config has the old value.
+            if api_base and api_base.rstrip("/").endswith("/v1"):
+                api_base = api_base.rstrip("/")[:-3]
         self.model = model
         self.api_key = api_key
-        self.api_base = api_base or self._default_api_base_for_model(model)
+        self.api_base = api_base or self._default_api_base_for_model(_original_model)
         self.extra_kwargs = kwargs
         # The Codex ChatGPT backend (chatgpt.com/backend-api/codex) rejects
         # several standard OpenAI params: max_output_tokens, stream_options.
@@ -350,6 +375,8 @@ class LiteLLMProvider(LLMProvider):
         model_lower = model.lower()
         if model_lower.startswith("minimax/") or model_lower.startswith("minimax-"):
             return MINIMAX_API_BASE
+        if model_lower.startswith("kimi/"):
+            return KIMI_API_BASE
         return None

     def _completion_with_rate_limit_retry(
@@ -689,7 +716,10 @@ class LiteLLMProvider(LLMProvider):

         full_messages: list[dict[str, Any]] = []
         if system:
-            full_messages.append({"role": "system", "content": system})
+            sys_msg: dict[str, Any] = {"role": "system", "content": system}
+            if _model_supports_cache_control(self.model):
+                sys_msg["cache_control"] = {"type": "ephemeral"}
+            full_messages.append(sys_msg)
         full_messages.extend(messages)

         if json_mode:
@@ -860,7 +890,10 @@ class LiteLLMProvider(LLMProvider):

         full_messages: list[dict[str, Any]] = []
         if system:
-            full_messages.append({"role": "system", "content": system})
+            sys_msg: dict[str, Any] = {"role": "system", "content": system}
+            if _model_supports_cache_control(self.model):
+                sys_msg["cache_control"] = {"type": "ephemeral"}
+            full_messages.append(sys_msg)
         full_messages.extend(messages)

         # Codex Responses API requires an `instructions` field (system prompt).
@@ -925,9 +958,26 @@ class LiteLLMProvider(LLMProvider):
         response = await litellm.acompletion(**kwargs)  # type: ignore[union-attr]

         async for chunk in response:
-            choice = chunk.choices[0] if chunk.choices else None
-            if not choice:
+            # Capture usage from the trailing usage-only chunk that
+            # stream_options={"include_usage": True} sends with empty choices.
+            if not chunk.choices:
+                usage = getattr(chunk, "usage", None)
+                if usage:
+                    input_tokens = getattr(usage, "prompt_tokens", 0) or 0
+                    output_tokens = getattr(usage, "completion_tokens", 0) or 0
+                    logger.debug(
+                        "[tokens] trailing usage chunk: input=%d output=%d model=%s",
+                        input_tokens,
+                        output_tokens,
+                        self.model,
+                    )
+                else:
+                    logger.debug(
+                        "[tokens] empty-choices chunk with no usage (model=%s)",
+                        self.model,
+                    )
                 continue
+            choice = chunk.choices[0]

             delta = choice.delta
@@ -1000,19 +1050,90 @@ class LiteLLMProvider(LLMProvider):
             tail_events.append(TextEndEvent(full_text=accumulated_text))

             usage = getattr(chunk, "usage", None)
+            logger.debug(
+                "[tokens] finish-chunk raw usage: %r (type=%s)",
+                usage,
+                type(usage).__name__,
+            )
+            cached_tokens = 0
             if usage:
                 input_tokens = getattr(usage, "prompt_tokens", 0) or 0
                 output_tokens = getattr(usage, "completion_tokens", 0) or 0
+                _details = getattr(usage, "prompt_tokens_details", None)
+                cached_tokens = (
+                    getattr(_details, "cached_tokens", 0) or 0
+                    if _details is not None
+                    else getattr(usage, "cache_read_input_tokens", 0) or 0
+                )
+                logger.debug(
+                    "[tokens] finish-chunk usage: input=%d output=%d cached=%d model=%s",
+                    input_tokens,
+                    output_tokens,
+                    cached_tokens,
+                    self.model,
+                )

             logger.debug(
+                "[tokens] finish event: input=%d output=%d cached=%d stop=%s model=%s",
                 input_tokens,
                 output_tokens,
+                cached_tokens,
                 choice.finish_reason,
                 self.model,
             )
             tail_events.append(
                 FinishEvent(
                     stop_reason=choice.finish_reason,
                     input_tokens=input_tokens,
                     output_tokens=output_tokens,
+                    cached_tokens=cached_tokens,
                     model=self.model,
                 )
             )

+            # Fallback: LiteLLM strips usage from yielded chunks before
+            # returning them to us, but appends the original chunk (with
+            # usage intact) to response.chunks first. Use LiteLLM's own
+            # calculate_total_usage() on that accumulated list.
+            if input_tokens == 0 and output_tokens == 0:
+                try:
+                    from litellm.litellm_core_utils.streaming_handler import (
+                        calculate_total_usage,
+                    )
+
+                    _chunks = getattr(response, "chunks", None)
+                    if _chunks:
+                        _usage = calculate_total_usage(chunks=_chunks)
+                        input_tokens = _usage.prompt_tokens or 0
+                        output_tokens = _usage.completion_tokens or 0
+                        _details = getattr(_usage, "prompt_tokens_details", None)
+                        cached_tokens = (
+                            getattr(_details, "cached_tokens", 0) or 0
+                            if _details is not None
+                            else getattr(_usage, "cache_read_input_tokens", 0) or 0
+                        )
+                        logger.debug(
+                            "[tokens] post-loop chunks fallback:"
+                            " input=%d output=%d cached=%d model=%s",
+                            input_tokens,
+                            output_tokens,
+                            cached_tokens,
+                            self.model,
+                        )
+                        # Patch the FinishEvent already queued with 0 tokens
+                        for _i, _ev in enumerate(tail_events):
+                            if isinstance(_ev, FinishEvent) and _ev.input_tokens == 0:
+                                tail_events[_i] = FinishEvent(
+                                    stop_reason=_ev.stop_reason,
+                                    input_tokens=input_tokens,
+                                    output_tokens=output_tokens,
+                                    cached_tokens=cached_tokens,
+                                    model=_ev.model,
+                                )
+                                break
+                except Exception as _e:
+                    logger.debug("[tokens] chunks fallback failed: %s", _e)

             # Check whether the stream produced any real content.
             # (If text deltas were yielded above, has_content is True
             # and we skip the retry path — nothing was yielded in vain.)
@@ -71,6 +71,7 @@ class FinishEvent:
     stop_reason: str = ""
     input_tokens: int = 0
     output_tokens: int = 0
+    cached_tokens: int = 0
     model: str = ""
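A sketch of the cached-token extraction the streaming hunks above add, applied to plain objects. The two field shapes — `prompt_tokens_details.cached_tokens` versus top-level `cache_read_input_tokens` — are the OpenAI-style and Anthropic-style usage layouts the code has to handle:

```python
from types import SimpleNamespace

def extract_cached_tokens(usage) -> int:
    # Prefer the nested OpenAI-style field, fall back to the Anthropic-style one.
    details = getattr(usage, "prompt_tokens_details", None)
    if details is not None:
        return getattr(details, "cached_tokens", 0) or 0
    return getattr(usage, "cache_read_input_tokens", 0) or 0

openai_style = SimpleNamespace(prompt_tokens_details=SimpleNamespace(cached_tokens=512))
anthropic_style = SimpleNamespace(cache_read_input_tokens=256)
print(extract_cached_tokens(openai_style), extract_cached_tokens(anthropic_style))  # 512 256
```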
@@ -253,6 +253,6 @@ judge_graph = GraphSpec(
     loop_config={
         "max_iterations": 10,  # One check shouldn't take many turns
         "max_tool_calls_per_turn": 3,  # get_summary + optionally emit_ticket
-        "max_history_tokens": 16000,  # Compact — judge only needs recent context
+        "max_context_tokens": 16000,  # Compact — judge only needs recent context
     },
 )
@@ -148,8 +148,9 @@ class HumanReadableFormatter(logging.Formatter):
         if record_event is not None:
             event = f" [{record_event}]"

-        # Format message: [LEVEL] [trace context] message
-        return f"{color}[{level}]{reset} {context_prefix}{record.getMessage()}{event}"
+        timestamp = self.formatTime(record, "%Y-%m-%d %H:%M:%S")
+        # Format message: TIMESTAMP [LEVEL] [trace context] message
+        return f"{timestamp} {color}[{level}]{reset} {context_prefix}{record.getMessage()}{event}"


 def configure_logging(
@@ -243,6 +243,12 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
         action="store_true",
         help="Open dashboard in browser after server starts",
     )
+    serve_parser.add_argument(
+        "--verbose", "-v", action="store_true", help="Enable INFO log level"
+    )
+    serve_parser.add_argument(
+        "--debug", action="store_true", help="Enable DEBUG log level"
+    )
     serve_parser.set_defaults(func=cmd_serve)

     # open command (serve + auto-open browser)
@@ -280,6 +286,12 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
         default=None,
         help="LLM model for preloaded agents",
     )
+    open_parser.add_argument(
+        "--verbose", "-v", action="store_true", help="Enable INFO log level"
+    )
+    open_parser.add_argument(
+        "--debug", action="store_true", help="Enable DEBUG log level"
+    )
     open_parser.set_defaults(func=cmd_open)
@@ -380,13 +392,15 @@ def cmd_run(args: argparse.Namespace) -> int:
     from framework.credentials.models import CredentialError
     from framework.runner import AgentRunner

+    from framework.observability import configure_logging
+
     # Set logging level (quiet by default for cleaner output)
     if args.quiet:
-        logging.basicConfig(level=logging.ERROR, format="%(message)s")
+        configure_logging(level="ERROR")
     elif getattr(args, "verbose", False):
-        logging.basicConfig(level=logging.INFO, format="%(message)s")
+        configure_logging(level="INFO")
     else:
-        logging.basicConfig(level=logging.WARNING, format="%(message)s")
+        configure_logging(level="WARNING")

     # Load input context
     context = {}
@@ -742,6 +756,17 @@ def cmd_dispatch(args: argparse.Namespace) -> int:
     if args.agents:
         # Use specific agents
         for agent_name in args.agents:
+            # Guard against full paths: if the name contains path separators
+            # (e.g. "exports/my_agent"), it will be doubled with agents_dir
+            agent_name_path = Path(agent_name)
+            if len(agent_name_path.parts) > 1:
+                print(
+                    f"Error: --agents expects agent names, not paths. "
+                    f"Use: --agents {agent_name_path.name} "
+                    f"instead of --agents {agent_name}",
+                    file=sys.stderr,
+                )
+                return 1
             agent_path = agents_dir / agent_name
             if not _is_valid_agent_dir(agent_path):
                 print(f"Agent not found: {agent_path}", file=sys.stderr)
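The guard above leans on `Path.parts`: a bare agent name has exactly one part, so anything containing a separator is rejected before it can be doubled with `agents_dir`. A quick illustration:

```python
from pathlib import Path

print(Path("my_agent").parts)                    # ('my_agent',)
print(Path("exports/my_agent").parts)            # ('exports', 'my_agent')
print(len(Path("exports/my_agent").parts) > 1)   # True -> rejected by the guard
```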
@@ -912,11 +937,9 @@ def cmd_shell(args: argparse.Namespace) -> int:
     from framework.credentials.models import CredentialError
     from framework.runner import AgentRunner

     # Configure logging to show runtime visibility
-    logging.basicConfig(
-        level=logging.INFO,
-        format="%(message)s",  # Simple format for clean output
-    )
+    from framework.observability import configure_logging
+
+    configure_logging(level="INFO")

     agents_dir = Path(args.agents_dir)
@@ -1622,10 +1645,12 @@ def cmd_serve(args: argparse.Namespace) -> int:

     from framework.server.app import create_app

-    logging.basicConfig(
-        level=logging.INFO,
-        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
-    )
+    from framework.observability import configure_logging
+
+    if getattr(args, "debug", False):
+        configure_logging(level="DEBUG")
+    else:
+        configure_logging(level="INFO")

     model = getattr(args, "model", None)
     app = create_app(model=model)
@@ -9,7 +9,7 @@ from datetime import UTC
 from pathlib import Path
 from typing import TYPE_CHECKING, Any

-from framework.config import get_hive_config, get_preferred_model
+from framework.config import get_hive_config, get_max_context_tokens, get_preferred_model
 from framework.credentials.validation import (
     ensure_credential_key_env as _ensure_credential_key_env,
 )
@@ -517,6 +517,41 @@ def get_codex_account_id() -> str | None:
     return None


+# ---------------------------------------------------------------------------
+# Kimi Code subscription token helpers
+# ---------------------------------------------------------------------------
+
+
+def get_kimi_code_token() -> str | None:
+    """Get the API key from a Kimi Code CLI installation.
+
+    Reads the API key from ``~/.kimi/config.toml``, which is created when
+    the user runs ``kimi /login`` in the Kimi Code CLI.
+
+    Returns:
+        The API key if available, None otherwise.
+    """
+    import tomllib
+
+    config_path = Path.home() / ".kimi" / "config.toml"
+    if not config_path.exists():
+        return None
+
+    try:
+        with open(config_path, "rb") as f:
+            config = tomllib.load(f)
+        providers = config.get("providers", {})
+        # kimi-cli stores credentials under providers.kimi-for-coding
+        for provider_cfg in providers.values():
+            if isinstance(provider_cfg, dict):
+                key = provider_cfg.get("api_key")
+                if key:
+                    return key
+    except Exception:
+        pass
+    return None
+
+
 @dataclass
 class AgentInfo:
     """Information about an exported agent."""
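For context, a sketch of the config shape `get_kimi_code_token()` expects. The table name `providers.kimi-for-coding` comes from the comment in the hunk; the exact layout of `~/.kimi/config.toml` is otherwise an assumption:

```python
import tomllib

# Hypothetical file contents; the helper only needs some providers.<name>
# table that carries an api_key value.
sample = """
[providers.kimi-for-coding]
api_key = "sk-example"
"""
config = tomllib.loads(sample)
for provider_cfg in config.get("providers", {}).values():
    if isinstance(provider_cfg, dict) and provider_cfg.get("api_key"):
        print("found key:", provider_cfg["api_key"])
```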
@@ -891,10 +926,31 @@ class AgentRunner:

         if agent_config and hasattr(agent_config, "max_tokens"):
             max_tokens = agent_config.max_tokens
+            logger.info(
+                "Agent default_config overrides max_tokens: %d (configuration.json value ignored)",
+                max_tokens,
+            )
         else:
             hive_config = get_hive_config()
             max_tokens = hive_config.get("llm", {}).get("max_tokens", DEFAULT_MAX_TOKENS)

+        # Resolve max_context_tokens with priority:
+        # 1. agent loop_config["max_context_tokens"] (explicit, wins silently)
+        # 2. agent default_config.max_context_tokens (logged)
+        # 3. configuration.json llm.max_context_tokens
+        # 4. hardcoded default (32_000)
+        agent_loop_config: dict = dict(getattr(agent_module, "loop_config", {}))
+        if "max_context_tokens" not in agent_loop_config:
+            if agent_config and hasattr(agent_config, "max_context_tokens"):
+                agent_loop_config["max_context_tokens"] = agent_config.max_context_tokens
+                logger.info(
+                    "Agent default_config overrides max_context_tokens: %d"
+                    " (configuration.json value ignored)",
+                    agent_config.max_context_tokens,
+                )
+            else:
+                agent_loop_config["max_context_tokens"] = get_max_context_tokens()

         # Read intro_message from agent metadata (shown on TUI load)
         agent_metadata = getattr(agent_module, "metadata", None)
         intro_message = ""
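The four-level precedence in the comment above, condensed into an illustrative resolver. Names follow the diff; the function itself is not in the changeset (in the real code, levels 3 and 4 collapse into `get_max_context_tokens()`):

```python
def resolve_max_context_tokens(
    loop_config: dict,
    default_config_value: int | None,
    configuration_json_value: int | None,
    hard_default: int = 32_000,
) -> int:
    if "max_context_tokens" in loop_config:   # 1. explicit loop_config wins silently
        return loop_config["max_context_tokens"]
    if default_config_value is not None:      # 2. agent default_config (logged)
        return default_config_value
    if configuration_json_value is not None:  # 3. configuration.json llm section
        return configuration_json_value
    return hard_default                       # 4. hardcoded fallback

assert resolve_max_context_tokens({}, None, 64_000) == 64_000
assert resolve_max_context_tokens({"max_context_tokens": 8_000}, 16_000, 64_000) == 8_000
```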
@@ -914,7 +970,7 @@ class AgentRunner:
             "nodes": nodes,
             "edges": edges,
             "max_tokens": max_tokens,
-            "loop_config": getattr(agent_module, "loop_config", {}),
+            "loop_config": agent_loop_config,
         }
         # Only pass optional fields if explicitly defined by the agent module
         conversation_mode = getattr(agent_module, "conversation_mode", None)
@@ -1104,6 +1160,7 @@ class AgentRunner:
         llm_config = config.get("llm", {})
         use_claude_code = llm_config.get("use_claude_code_subscription", False)
         use_codex = llm_config.get("use_codex_subscription", False)
+        use_kimi_code = llm_config.get("use_kimi_code_subscription", False)
         api_base = llm_config.get("api_base")

         api_key = None
@@ -1119,6 +1176,12 @@ class AgentRunner:
             if not api_key:
                 print("Warning: Codex subscription configured but no token found.")
                 print("Run 'codex' to authenticate, then try again.")
+        elif use_kimi_code:
+            # Get API key from Kimi Code CLI config (~/.kimi/config.toml)
+            api_key = get_kimi_code_token()
+            if not api_key:
+                print("Warning: Kimi Code subscription configured but no key found.")
+                print("Run 'kimi /login' to authenticate, then try again.")

         if api_key and use_claude_code:
             # Use litellm's built-in Anthropic OAuth support.
@@ -1149,6 +1212,14 @@ class AgentRunner:
                 store=False,
                 allowed_openai_params=["store"],
             )
+        elif api_key and use_kimi_code:
+            # Kimi Code subscription uses the Kimi coding API (Anthropic-compatible).
+            # The api_base is set automatically by LiteLLMProvider for kimi/ models.
+            self._llm = LiteLLMProvider(
+                model=self.model,
+                api_key=api_key,
+                api_base=api_base,
+            )
         else:
             # Local models (e.g. Ollama) don't need an API key
             if self._is_local_model(self.model):
@@ -1314,6 +1385,8 @@ class AgentRunner:
             return "TOGETHER_API_KEY"
         elif model_lower.startswith("minimax/") or model_lower.startswith("minimax-"):
             return "MINIMAX_API_KEY"
+        elif model_lower.startswith("kimi/"):
+            return "KIMI_API_KEY"
         else:
             # Default: assume OpenAI-compatible
             return "OPENAI_API_KEY"
@@ -1334,6 +1407,8 @@ class AgentRunner:
             cred_id = "anthropic"
         elif model_lower.startswith("minimax/") or model_lower.startswith("minimax-"):
             cred_id = "minimax"
+        elif model_lower.startswith("kimi/"):
+            cred_id = "kimi"
         # Add more mappings as providers are added to LLM_CREDENTIALS

         if cred_id is None:
@@ -1531,6 +1531,11 @@ class AgentRuntime:
         for executor in stream._active_executors.values():
             for node_id, node in executor.node_registry.items():
                 if getattr(node, "_awaiting_input", False):
+                    # Skip escalation receivers — those are handled
+                    # by the queen via inject_worker_message(), not
+                    # by the user directly.
+                    if ":escalation:" in node_id:
+                        continue
                     return node_id, graph_id
         return None, None
@@ -616,6 +616,7 @@ class EventBus:
         model: str,
         input_tokens: int,
         output_tokens: int,
+        cached_tokens: int = 0,
         execution_id: str | None = None,
         iteration: int | None = None,
     ) -> None:
@@ -625,6 +626,7 @@ class EventBus:
             "model": model,
             "input_tokens": input_tokens,
             "output_tokens": output_tokens,
+            "cached_tokens": cached_tokens,
         }
         if iteration is not None:
             data["iteration"] = iteration
@@ -137,6 +137,11 @@ async def create_queen(
     phase_state.staging_tools = [t for t in queen_tools if t.name in staging_names]
     phase_state.running_tools = [t for t in queen_tools if t.name in running_names]

+    # ---- Cross-session memory ----------------------------------------
+    from framework.agents.queen.queen_memory import seed_if_missing
+
+    seed_if_missing()
+
     # ---- Compose phase-specific prompts ------------------------------
     _orig_node = _queen_graph.nodes[0]
@@ -203,8 +208,7 @@ async def create_queen(
                 data={"persona": persona},
             )
         )
-        body = _planning_body if phase_state.phase == "planning" else _building_body
-        return HookResult(system_prompt=persona + "\n\n" + body)
+        return HookResult(system_prompt=persona + "\n\n" + phase_state.get_current_prompt())

     # ---- Graph preparation -------------------------------------------
     initial_prompt_text = phase_state.get_current_prompt()
@@ -101,14 +101,20 @@ class QueenPhaseState:
         return list(self.building_tools)

     def get_current_prompt(self) -> str:
-        """Return the system prompt for the current phase."""
+        """Return the system prompt for the current phase, with fresh memory appended."""
         if self.phase == "planning":
-            return self.prompt_planning
-        if self.phase == "running":
-            return self.prompt_running
-        if self.phase == "staging":
-            return self.prompt_staging
-        return self.prompt_building
+            base = self.prompt_planning
+        elif self.phase == "running":
+            base = self.prompt_running
+        elif self.phase == "staging":
+            base = self.prompt_staging
+        else:
+            base = self.prompt_building
+
+        from framework.agents.queen.queen_memory import format_for_injection
+
+        memory = format_for_injection()
+        return base + ("\n\n" + memory if memory else "")

     async def _emit_phase_event(self) -> None:
         """Publish a QUEEN_PHASE_CHANGED event so the frontend updates the tag."""
@@ -1446,7 +1452,23 @@ def register_queen_lifecycle_tools(
         if reg is None:
             return json.dumps({"error": "Worker graph not found"})

-        # Find an active node that can accept injected input
+        # Prefer nodes that are actively waiting (e.g. escalation receivers
+        # blocked on queen guidance) over the main event-loop node.
+        for stream in reg.streams.values():
+            waiting = stream.get_waiting_nodes()
+            if waiting:
+                target_node_id = waiting[0]["node_id"]
+                ok = await stream.inject_input(target_node_id, content, is_client_input=True)
+                if ok:
+                    return json.dumps(
+                        {
+                            "status": "delivered",
+                            "node_id": target_node_id,
+                            "content_preview": content[:100],
+                        }
+                    )
+
+        # Fallback: inject into any injectable node
         for stream in reg.streams.values():
             injectable = stream.get_injectable_nodes()
             if injectable:
@@ -1498,6 +1520,15 @@ def register_queen_lifecycle_tools(
         Returns credential IDs, aliases, status, and identity metadata.
         Never returns secret values. Optionally filter by credential_id.
         """
+        # Load shell config vars into os.environ — same first step as check-agent.
+        # Ensures keys set in ~/.zshrc/~/.bashrc are visible to is_available() checks.
+        try:
+            from framework.credentials.validation import ensure_credential_key_env
+
+            ensure_credential_key_env()
+        except Exception:
+            pass
+
         try:
             # Primary: CredentialStoreAdapter sees both Aden OAuth and local accounts
             from aden_tools.credentials import CredentialStoreAdapter
@@ -1505,13 +1536,24 @@ def register_queen_lifecycle_tools(
         store = CredentialStoreAdapter.default()
         all_accounts = store.get_all_account_info()

-        # Filter by credential_id / provider if requested
+        # Filter by credential_id / provider if requested.
+        # A spec name like "gmail_oauth" maps to provider "google" via
+        # credential_id field — resolve that alias before filtering.
         if credential_id:
+            try:
+                from aden_tools.credentials import CREDENTIAL_SPECS
+
+                spec = CREDENTIAL_SPECS.get(credential_id)
+                resolved_provider = (
+                    (spec.credential_id or credential_id) if spec else credential_id
+                )
+            except Exception:
+                resolved_provider = credential_id
             all_accounts = [
                 a
                 for a in all_accounts
                 if a.get("credential_id", "").startswith(credential_id)
-                or a.get("provider", "") == credential_id
+                or a.get("provider", "") in (credential_id, resolved_provider)
             ]

         return json.dumps(
@@ -1528,13 +1570,43 @@ def register_queen_lifecycle_tools(

         # Fallback: local encrypted store only
         try:
+            from framework.credentials.local.models import LocalAccountInfo
             from framework.credentials.local.registry import LocalCredentialRegistry
+            from framework.credentials.storage import EncryptedFileStorage

             registry = LocalCredentialRegistry.default()
             accounts = registry.list_accounts(
                 credential_id=credential_id or None,
             )

+            # Also include flat-file credentials saved by the GUI (no "/" separator).
+            # LocalCredentialRegistry.list_accounts() skips these — read them directly.
+            seen_cred_ids = {info.credential_id for info in accounts}
+            storage = EncryptedFileStorage()
+            for storage_id in storage.list_all():
+                if "/" in storage_id:
+                    continue  # already handled by LocalCredentialRegistry above
+                if credential_id and storage_id != credential_id:
+                    continue
+                if storage_id in seen_cred_ids:
+                    continue
+                try:
+                    cred_obj = storage.load(storage_id)
+                except Exception:
+                    continue
+                if cred_obj is None:
+                    continue
+                accounts.append(
+                    LocalAccountInfo(
+                        credential_id=storage_id,
+                        alias="default",
+                        status="unknown",
+                        identity=cred_obj.identity,
+                        last_validated=cred_obj.last_refreshed,
+                        created_at=cred_obj.created_at,
+                    )
+                )
+
             credentials = []
             for info in accounts:
                 entry: dict[str, Any] = {
@@ -572,7 +572,7 @@ async def test_event_loop_conversation_compaction():
     judge = CountingJudge(retry_count=3)
     node = EventLoopNode(
         judge=judge,
-        config=LoopConfig(max_iterations=10, max_history_tokens=200),
+        config=LoopConfig(max_iterations=10, max_context_tokens=200),
     )
     result = await node.execute(ctx)
@@ -204,8 +204,8 @@ class TestNodeConversation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_ratio(self):
|
||||
"""usage_ratio returns estimate / max_history_tokens."""
|
||||
conv = NodeConversation(max_history_tokens=1000)
|
||||
"""usage_ratio returns estimate / max_context_tokens."""
|
||||
conv = NodeConversation(max_context_tokens=1000)
|
||||
await conv.add_user_message("a" * 400)
|
||||
assert conv.usage_ratio() == pytest.approx(0.1) # 100/1000
|
||||
|
||||
@@ -214,15 +214,15 @@ class TestNodeConversation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_ratio_zero_budget(self):
|
||||
"""usage_ratio returns 0 when max_history_tokens is 0 (unlimited)."""
|
||||
conv = NodeConversation(max_history_tokens=0)
|
||||
"""usage_ratio returns 0 when max_context_tokens is 0 (unlimited)."""
|
||||
conv = NodeConversation(max_context_tokens=0)
|
||||
await conv.add_user_message("a" * 400)
|
||||
assert conv.usage_ratio() == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_needs_compaction_with_actual_tokens(self):
|
||||
"""needs_compaction uses actual API token count when available."""
|
||||
conv = NodeConversation(max_history_tokens=1000, compaction_threshold=0.8)
|
||||
conv = NodeConversation(max_context_tokens=1000, compaction_threshold=0.8)
|
||||
await conv.add_user_message("a" * 100) # chars/4 = 25, well under 800
|
||||
|
||||
assert conv.needs_compaction() is False
|
||||
@@ -233,7 +233,7 @@ class TestNodeConversation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_needs_compaction(self):
|
||||
conv = NodeConversation(max_history_tokens=100, compaction_threshold=0.8)
|
||||
conv = NodeConversation(max_context_tokens=100, compaction_threshold=0.8)
|
||||
await conv.add_user_message("x" * 320)
|
||||
assert conv.needs_compaction() is True
|
||||
|
||||
@@ -457,7 +457,7 @@ class TestPersistence:
|
||||
store = MockConversationStore()
|
||||
assert await NodeConversation.restore(store) is None
|
||||
|
||||
conv = NodeConversation(system_prompt="hello", max_history_tokens=500, store=store)
|
||||
conv = NodeConversation(system_prompt="hello", max_context_tokens=500, store=store)
|
||||
await conv.add_user_message("u1")
|
||||
await conv.add_assistant_message("a1")
|
||||
|
||||
@@ -643,7 +643,7 @@ class TestConversationIntegration:
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(
|
||||
system_prompt="You are a helpful travel agent.",
|
||||
max_history_tokens=16000,
|
||||
max_context_tokens=16000,
|
||||
store=store,
|
||||
)
|
||||
|
||||
@@ -1314,7 +1314,7 @@ class TestLlmCompact:
|
||||
"""Create a minimal EventLoopNode for testing."""
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig
|
||||
|
||||
config = LoopConfig(max_history_tokens=32000)
|
||||
config = LoopConfig(max_context_tokens=32000)
|
||||
node = EventLoopNode.__new__(EventLoopNode)
|
||||
node._config = config
|
||||
node._event_bus = None
|
||||
|
||||
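These renames track a semantic shift: the budget now describes the model's input context, not just conversation history. The arithmetic the tests rely on is a chars/4 token estimate, so 320 characters against a 100-token budget lands exactly at the 0.8 threshold. A minimal sketch of that accounting, assuming the plain chars/4 heuristic these tests imply (the class here is illustrative, not the real NodeConversation):

```python
# Minimal sketch of the budget arithmetic the tests above assume.
# The chars/4 heuristic and the 0-means-unlimited rule are inferred from this diff.
class ConversationBudget:
    def __init__(self, max_context_tokens: int, compaction_threshold: float = 0.8):
        self.max_context_tokens = max_context_tokens
        self.compaction_threshold = compaction_threshold
        self.chars = 0

    def add(self, text: str) -> None:
        self.chars += len(text)

    def usage_ratio(self) -> float:
        if self.max_context_tokens == 0:  # 0 means unlimited
            return 0.0
        return (self.chars / 4) / self.max_context_tokens

    def needs_compaction(self) -> bool:
        return self.usage_ratio() >= self.compaction_threshold

b = ConversationBudget(max_context_tokens=100)
b.add("x" * 320)             # 320 chars, roughly 80 tokens
print(b.usage_ratio())       # 0.8, right at the threshold
print(b.needs_compaction())  # True, matching test_needs_compaction
```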
@@ -970,13 +970,13 @@ class TestEscalationFlow:
        )

    @pytest.mark.asyncio
    async def test_wait_for_response_emits_client_events(
    async def test_wait_for_response_emits_escalation_event(
        self,
        runtime,
        parent_node_spec,
        subagent_node_spec,
    ):
        """Escalation should emit CLIENT_OUTPUT_DELTA and CLIENT_INPUT_REQUESTED events."""
        """Escalation should emit ESCALATION_REQUESTED to the queen."""
        from framework.graph.event_loop_node import _EscalationReceiver

        bus = EventBus()

@@ -986,7 +986,7 @@ class TestEscalationFlow:
            bus_events.append(event)

        bus.subscribe(
            event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.CLIENT_INPUT_REQUESTED],
            event_types=[EventType.ESCALATION_REQUESTED],
            handler=handler,
        )

@@ -1034,16 +1034,12 @@ class TestEscalationFlow:
        await node._execute_subagent(ctx, "researcher", "Navigate page with CAPTCHA")
        await injector

        # Should have emitted both events
        output_deltas = [e for e in bus_events if e.type == EventType.CLIENT_OUTPUT_DELTA]
        input_requests = [e for e in bus_events if e.type == EventType.CLIENT_INPUT_REQUESTED]
        # Should have emitted ESCALATION_REQUESTED
        escalation_events = [e for e in bus_events if e.type == EventType.ESCALATION_REQUESTED]

        assert len(output_deltas) >= 1, "Should emit CLIENT_OUTPUT_DELTA with the message"
        assert output_deltas[0].data["content"] == "CAPTCHA detected on page"
        assert output_deltas[0].node_id == "parent"  # Shows as parent talking

        assert len(input_requests) >= 1, "Should emit CLIENT_INPUT_REQUESTED for routing"
        assert ":escalation:" in input_requests[0].node_id  # Escalation ID for routing
        assert len(escalation_events) >= 1, "Should emit ESCALATION_REQUESTED"
        assert escalation_events[0].data["context"] == "CAPTCHA detected on page"
        assert ":escalation:" in escalation_events[0].node_id

    @pytest.mark.asyncio
    async def test_non_blocking_report_still_works(
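The renamed test pins down the new contract: a single ESCALATION_REQUESTED event replaces the CLIENT_OUTPUT_DELTA and CLIENT_INPUT_REQUESTED pair, carrying the message in data["context"] and an ":escalation:" routing token in node_id. A hedged sketch of that event shape (the dataclass below is a stand-in, not the real AgentEvent):

```python
# Stand-in event type illustrating the escalation contract asserted above.
# Field names come from this diff; the concrete values are invented.
from dataclasses import dataclass, field

@dataclass
class AgentEvent:
    type: str
    node_id: str
    data: dict = field(default_factory=dict)

def is_escalation(event: AgentEvent) -> bool:
    return event.type == "ESCALATION_REQUESTED" and ":escalation:" in event.node_id

evt = AgentEvent(
    type="ESCALATION_REQUESTED",
    node_id="researcher:escalation:42",
    data={"context": "CAPTCHA detected on page"},
)
assert is_escalation(evt)
print(evt.data["context"])  # what the queen surfaces to the user
```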
@@ -3,9 +3,8 @@
Tests the FULL routing chain:
    ExecutionStream → GraphExecutor → EventLoopNode → _execute_subagent
    → _report_callback registers _EscalationReceiver in executor.node_registry
    → emit CLIENT_INPUT_REQUESTED with escalation_id
    → subscriber calls stream.inject_input(escalation_id, "done")
    → ExecutionStream finds _EscalationReceiver in executor.node_registry
    → emit ESCALATION_REQUESTED (queen handles the escalation)
    → queen inject_worker_message() finds _EscalationReceiver via get_waiting_nodes()
    → receiver.inject_event("done") unblocks the subagent
    → subagent continues and completes
"""

@@ -227,26 +226,30 @@ async def test_escalation_e2e_through_execution_stream(tmp_path):
    stream_holder: list[ExecutionStream] = []

    async def escalation_handler(event: AgentEvent):
        """Simulate a TUI/runner: when CLIENT_INPUT_REQUESTED arrives with
        an escalation node_id, inject the user's response via the stream."""
        """Simulate the queen: when ESCALATION_REQUESTED arrives,
        find the waiting receiver and inject the response via the stream."""
        all_events.append(event)
        if event.type == EventType.CLIENT_INPUT_REQUESTED:
            node_id = event.node_id
            if ":escalation:" in node_id:
                escalation_events.append(event)
                # Small delay to simulate user typing
                await asyncio.sleep(0.05)
                # Route through the REAL inject_input chain
                stream = stream_holder[0]
                success = await stream.inject_input(node_id, "done logging in")
                assert success, (
                    f"inject_input({node_id!r}) returned False — "
                    "escalation receiver not found in executor.node_registry"
                )
                inject_called.set()
        if event.type == EventType.ESCALATION_REQUESTED:
            escalation_events.append(event)
            # Small delay to simulate queen processing
            await asyncio.sleep(0.05)
            # Route through the REAL inject_input chain — find the waiting
            # escalation receiver via get_waiting_nodes() (mirrors what
            # inject_worker_message does in the queen lifecycle tools).
            stream = stream_holder[0]
            waiting = stream.get_waiting_nodes()
            assert waiting, "Should have a waiting escalation receiver"
            target_node_id = waiting[0]["node_id"]
            assert ":escalation:" in target_node_id
            success = await stream.inject_input(target_node_id, "done logging in")
            assert success, (
                f"inject_input({target_node_id!r}) returned False — "
                "escalation receiver not found in executor.node_registry"
            )
            inject_called.set()

    bus.subscribe(
        event_types=[EventType.CLIENT_INPUT_REQUESTED, EventType.CLIENT_OUTPUT_DELTA],
        event_types=[EventType.ESCALATION_REQUESTED],
        handler=escalation_handler,
    )

@@ -297,17 +300,7 @@ async def test_escalation_e2e_through_execution_stream(tmp_path):
    # 3. Escalation event has correct structure
    esc_event = escalation_events[0]
    assert ":escalation:" in esc_event.node_id
    assert esc_event.data["prompt"] == "Login required for LinkedIn. Please log in manually."

    # 4. CLIENT_OUTPUT_DELTA was emitted for the escalation message
    output_deltas = [
        e
        for e in all_events
        if e.type == EventType.CLIENT_OUTPUT_DELTA and "Login required" in e.data.get("content", "")
    ]
    assert len(output_deltas) >= 1, (
        "Should have emitted CLIENT_OUTPUT_DELTA with escalation message"
    )
    assert esc_event.data["context"] == "Login required for LinkedIn. Please log in manually."

    # 5. The parent node got the subagent's result
    assert "result" in result.output

@@ -444,7 +437,7 @@ async def test_escalation_cleanup_after_completion(tmp_path):
    stream_holder: list[ExecutionStream] = []

    async def auto_respond(event: AgentEvent):
        if event.type == EventType.CLIENT_INPUT_REQUESTED and ":escalation:" in event.node_id:
        if event.type == EventType.ESCALATION_REQUESTED:
            stream = stream_holder[0]

            # Snapshot the active executor's node_registry BEFORE responding

@@ -462,10 +455,13 @@ async def test_escalation_cleanup_after_completion(tmp_path):
            )

            await asyncio.sleep(0.02)
            await stream.inject_input(event.node_id, "ok")
            # Find the waiting escalation receiver and inject response
            waiting = stream.get_waiting_nodes()
            if waiting:
                await stream.inject_input(waiting[0]["node_id"], "ok")

    bus.subscribe(
        event_types=[EventType.CLIENT_INPUT_REQUESTED],
        event_types=[EventType.ESCALATION_REQUESTED],
        handler=auto_respond,
    )
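The handler above routes responses by discovering the waiting receiver rather than trusting the event's own node_id. A minimal sketch of that lookup-then-inject pattern, with a stand-in stream class (the get_waiting_nodes and inject_input names come from this diff; everything else is invented):

```python
# Stand-in for ExecutionStream, illustrating the queen-side injection loop.
import asyncio

class FakeStream:
    def __init__(self):
        self._waiting = [{"node_id": "researcher:escalation:1"}]
        self._queue: asyncio.Queue[str] = asyncio.Queue()

    def get_waiting_nodes(self) -> list[dict]:
        return list(self._waiting)

    async def inject_input(self, node_id: str, text: str) -> bool:
        if any(w["node_id"] == node_id for w in self._waiting):
            await self._queue.put(text)  # unblocks the waiting receiver
            return True
        return False  # receiver not found in the registry

async def main():
    stream = FakeStream()
    waiting = stream.get_waiting_nodes()
    target = waiting[0]["node_id"]
    assert ":escalation:" in target
    print(await stream.inject_input(target, "done logging in"))  # True

asyncio.run(main())
```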
@@ -172,7 +172,7 @@ Add to `.vscode/settings.json`:
## Security Best Practices

1. **Never commit API keys** - Use environment variables or `.env` files
2. **`.env` is git-ignored** - Copy `.env.example` to `.env` at the project root and fill in your values
2. **If you use a local `.env` file, keep it private** - This repository does not include a root `.env.example`; use your own local `.env` file or shell environment variables for secrets
3. **Use real provider keys in non-production environments** - validate configuration with low-risk inputs before production rollout
4. **Credential isolation** - Each tool validates its own credentials at runtime
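A hedged sketch of the env-var-first pattern item 2 recommends; the python-dotenv import is an optional convenience for local `.env` files, not something this repository is known to ship:

```python
# Read the key from the shell environment, optionally falling back to a local
# .env file loaded via python-dotenv (pip install python-dotenv).
import os

try:
    from dotenv import load_dotenv
    load_dotenv()  # reads a local .env if present; no-op otherwise
except ImportError:
    pass

api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
    print("Set ANTHROPIC_API_KEY in your shell or a local .env file")
```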
+112
-15
@@ -911,6 +911,13 @@ $zaiKey = [System.Environment]::GetEnvironmentVariable("ZAI_API_KEY", "User")
if (-not $zaiKey) { $zaiKey = $env:ZAI_API_KEY }
if ($zaiKey) { $ZaiCredDetected = $true }

$KimiCredDetected = $false
$kimiConfigPath = Join-Path $env:USERPROFILE ".kimi\config.toml"
if (Test-Path $kimiConfigPath) { $KimiCredDetected = $true }
$kimiKey = [System.Environment]::GetEnvironmentVariable("KIMI_API_KEY", "User")
if (-not $kimiKey) { $kimiKey = $env:KIMI_API_KEY }
if ($kimiKey) { $KimiCredDetected = $true }

# Detect API key providers
$ProviderMenuEnvVars = @("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GEMINI_API_KEY", "GROQ_API_KEY", "CEREBRAS_API_KEY")
$ProviderMenuNames = @("Anthropic (Claude) - Recommended", "OpenAI (GPT)", "Google Gemini - Free tier available", "Groq - Fast, free tier", "Cerebras - Fast, free tier")

@@ -938,7 +945,9 @@ if (Test-Path $HiveConfigFile) {
        $PrevEnvVar = if ($prevLlm.api_key_env_var) { $prevLlm.api_key_env_var } else { "" }
        if ($prevLlm.use_claude_code_subscription) { $PrevSubMode = "claude_code" }
        elseif ($prevLlm.use_codex_subscription) { $PrevSubMode = "codex" }
        elseif ($prevLlm.use_kimi_code_subscription) { $PrevSubMode = "kimi_code" }
        elseif ($prevLlm.api_base -and $prevLlm.api_base -like "*api.z.ai*") { $PrevSubMode = "zai_code" }
        elseif ($prevLlm.api_base -and $prevLlm.api_base -like "*api.kimi.com*") { $PrevSubMode = "kimi_code" }
        }
    } catch { }
}

@@ -951,6 +960,7 @@ if ($PrevSubMode -or $PrevProvider) {
        "claude_code" { if ($ClaudeCredDetected) { $prevCredValid = $true } }
        "zai_code" { if ($ZaiCredDetected) { $prevCredValid = $true } }
        "codex" { if ($CodexCredDetected) { $prevCredValid = $true } }
        "kimi_code" { if ($KimiCredDetected) { $prevCredValid = $true } }
        default {
            if ($PrevEnvVar) {
                $envVal = [System.Environment]::GetEnvironmentVariable($PrevEnvVar, "Process")

@@ -964,14 +974,16 @@ if ($PrevSubMode -or $PrevProvider) {
            "claude_code" { $DefaultChoice = "1" }
            "zai_code" { $DefaultChoice = "2" }
            "codex" { $DefaultChoice = "3" }
            "kimi_code" { $DefaultChoice = "4" }
        }
        if (-not $DefaultChoice) {
            switch ($PrevProvider) {
                "anthropic" { $DefaultChoice = "4" }
                "openai" { $DefaultChoice = "5" }
                "gemini" { $DefaultChoice = "6" }
                "groq" { $DefaultChoice = "7" }
                "cerebras" { $DefaultChoice = "8" }
                "anthropic" { $DefaultChoice = "5" }
                "openai" { $DefaultChoice = "6" }
                "gemini" { $DefaultChoice = "7" }
                "groq" { $DefaultChoice = "8" }
                "cerebras" { $DefaultChoice = "9" }
                "kimi" { $DefaultChoice = "4" }
            }
        }
    }

@@ -1003,12 +1015,19 @@ Write-Host ") OpenAI Codex Subscription " -NoNewline
Write-Color -Text "(use your Codex/ChatGPT Plus plan)" -Color DarkGray -NoNewline
if ($CodexCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }

# 4) Kimi Code
Write-Host "  " -NoNewline
Write-Color -Text "4" -Color Cyan -NoNewline
Write-Host ") Kimi Code Subscription " -NoNewline
Write-Color -Text "(use your Kimi Code plan)" -Color DarkGray -NoNewline
if ($KimiCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }

Write-Host ""
Write-Color -Text "  API key providers:" -Color Cyan

# 4-8) API key providers
# 5-9) API key providers
for ($idx = 0; $idx -lt $ProviderMenuEnvVars.Count; $idx++) {
    $num = $idx + 4
    $num = $idx + 5
    $envVal = [System.Environment]::GetEnvironmentVariable($ProviderMenuEnvVars[$idx], "Process")
    if (-not $envVal) { $envVal = [System.Environment]::GetEnvironmentVariable($ProviderMenuEnvVars[$idx], "User") }
    Write-Host "  " -NoNewline

@@ -1018,7 +1037,7 @@ for ($idx = 0; $idx -lt $ProviderMenuEnvVars.Count; $idx++) {
}

Write-Host "  " -NoNewline
Write-Color -Text "9" -Color Cyan -NoNewline
Write-Color -Text "10" -Color Cyan -NoNewline
Write-Host ") Skip for now"
Write-Host ""

@@ -1029,16 +1048,16 @@ if ($DefaultChoice) {

while ($true) {
    if ($DefaultChoice) {
        $raw = Read-Host "Enter choice (1-9) [$DefaultChoice]"
        $raw = Read-Host "Enter choice (1-10) [$DefaultChoice]"
        if ([string]::IsNullOrWhiteSpace($raw)) { $raw = $DefaultChoice }
    } else {
        $raw = Read-Host "Enter choice (1-9)"
        $raw = Read-Host "Enter choice (1-10)"
    }
    if ($raw -match '^\d+$') {
        $num = [int]$raw
        if ($num -ge 1 -and $num -le 9) { break }
        if ($num -ge 1 -and $num -le 10) { break }
    }
    Write-Color -Text "Invalid choice. Please enter 1-9" -Color Red
    Write-Color -Text "Invalid choice. Please enter 1-10" -Color Red
}

switch ($num) {

@@ -1102,9 +1121,20 @@ switch ($num) {
            Write-Ok "Using OpenAI Codex subscription"
        }
    }
    { $_ -ge 4 -and $_ -le 8 } {
    4 {
        # Kimi Code Subscription
        $SubscriptionMode = "kimi_code"
        $SelectedProviderId = "kimi"
        $SelectedEnvVar = "KIMI_API_KEY"
        $SelectedModel = "kimi-k2.5"
        $SelectedMaxTokens = 32768
        Write-Host ""
        Write-Ok "Using Kimi Code subscription"
        Write-Color -Text "  Model: kimi-k2.5 | API: api.kimi.com/coding" -Color DarkGray
    }
    { $_ -ge 5 -and $_ -le 9 } {
        # API key providers
        $provIdx = $num - 4
        $provIdx = $num - 5
        $SelectedEnvVar = $ProviderMenuEnvVars[$provIdx]
        $SelectedProviderId = $ProviderMenuIds[$provIdx]
        $providerName = $ProviderMenuNames[$provIdx] -replace ' - .*', ''  # strip description

@@ -1175,7 +1205,7 @@ switch ($num) {
            }
        }
    }
    9 {
    10 {
        Write-Host ""
        Write-Warn "Skipped. An LLM API key is required to test and use worker agents."
        Write-Host "  Add your API key later by running:"

@@ -1252,6 +1282,70 @@ if ($SubscriptionMode -eq "zai_code") {
    }
}

# For Kimi Code subscription: prompt for API key with verification + retry
if ($SubscriptionMode -eq "kimi_code") {
    while ($true) {
        $existingKimi = [System.Environment]::GetEnvironmentVariable("KIMI_API_KEY", "User")
        if (-not $existingKimi) { $existingKimi = $env:KIMI_API_KEY }

        if ($existingKimi) {
            $masked = $existingKimi.Substring(0, [Math]::Min(4, $existingKimi.Length)) + "..." + $existingKimi.Substring([Math]::Max(0, $existingKimi.Length - 4))
            Write-Host ""
            Write-Color -Text "  $([char]0x2B22) Current Kimi key: $masked" -Color Green
            $apiKey = Read-Host "  Press Enter to keep, or paste a new key to replace"
        } else {
            Write-Host ""
            Write-Host "Get your API key from: " -NoNewline
            Write-Color -Text "https://www.kimi.com/code" -Color Cyan
            Write-Host ""
            $apiKey = Read-Host "Paste your Kimi API key (or press Enter to skip)"
        }

        if ($apiKey) {
            [System.Environment]::SetEnvironmentVariable("KIMI_API_KEY", $apiKey, "User")
            $env:KIMI_API_KEY = $apiKey
            Write-Host ""
            Write-Ok "Kimi API key saved as User environment variable"

            # Health check the new key
            Write-Host "  Verifying Kimi API key... " -NoNewline
            try {
                $hcResult = & uv run python (Join-Path $ScriptDir "scripts/check_llm_key.py") "kimi" $apiKey "https://api.kimi.com/coding" 2>$null
                $hcJson = $hcResult | ConvertFrom-Json
                if ($hcJson.valid -eq $true) {
                    Write-Color -Text "ok" -Color Green
                    break
                } elseif ($hcJson.valid -eq $false) {
                    Write-Color -Text "failed" -Color Red
                    Write-Warn $hcJson.message
                    [System.Environment]::SetEnvironmentVariable("KIMI_API_KEY", $null, "User")
                    Remove-Item -Path "Env:\KIMI_API_KEY" -ErrorAction SilentlyContinue
                    Write-Host ""
                    Read-Host "  Press Enter to try again"
                } else {
                    Write-Color -Text "--" -Color Yellow
                    Write-Color -Text "  Could not verify key (network issue). The key has been saved." -Color DarkGray
                    break
                }
            } catch {
                Write-Color -Text "--" -Color Yellow
                Write-Color -Text "  Could not verify key (network issue). The key has been saved." -Color DarkGray
                break
            }
        } elseif (-not $existingKimi) {
            Write-Host ""
            Write-Warn "Skipped. Add your Kimi API key later:"
            Write-Color -Text "  [System.Environment]::SetEnvironmentVariable('KIMI_API_KEY', 'your-key', 'User')" -Color Cyan
            $SelectedEnvVar = ""
            $SelectedProviderId = ""
            $SubscriptionMode = ""
            break
        } else {
            break
        }
    }
}

# Prompt for model if not already selected (manual provider path)
if ($SelectedProviderId -and -not $SelectedModel) {
    $modelSel = Get-ModelSelection $SelectedProviderId

@@ -1287,6 +1381,9 @@ if ($SelectedProviderId) {
} elseif ($SubscriptionMode -eq "zai_code") {
    $config.llm["api_base"] = "https://api.z.ai/api/coding/paas/v4"
    $config.llm["api_key_env_var"] = $SelectedEnvVar
} elseif ($SubscriptionMode -eq "kimi_code") {
    $config.llm["api_base"] = "https://api.kimi.com/coding"
    $config.llm["api_key_env_var"] = $SelectedEnvVar
} else {
    $config.llm["api_key_env_var"] = $SelectedEnvVar
}
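The prompt above masks an existing key before asking whether to replace it. A hedged Python equivalent of that masking rule (first four and last four characters with an ellipsis between; purely illustrative):

```python
# Mirror of the PowerShell masking logic: head 4 chars + "..." + tail 4 chars,
# degrading gracefully for very short keys.
def mask_key(key: str) -> str:
    head = key[: min(4, len(key))]
    tail = key[max(0, len(key) - 4):]
    return f"{head}...{tail}"

print(mask_key("sk-abc123456789xyz"))  # sk-a...9xyz
```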
+125
-35
@@ -410,7 +410,7 @@ if [ "$USE_ASSOC_ARRAYS" = true ]; then
    declare -A DEFAULT_MODELS=(
        ["anthropic"]="claude-haiku-4-5-20251001"
        ["openai"]="gpt-5-mini"
        ["minimax"]="MiniMax-M2.1"
        ["minimax"]="MiniMax-M2.5"
        ["gemini"]="gemini-3-flash-preview"
        ["groq"]="moonshotai/kimi-k2-instruct-0905"
        ["cerebras"]="zai-glm-4.7"

@@ -466,6 +466,23 @@ if [ "$USE_ASSOC_ARRAYS" = true ]; then
        ["cerebras:1"]=8192
    )

    # Max context tokens (input history budget) per model, based on actual context windows.
    # Leave ~10% headroom for system prompt and output tokens.
    declare -A MODEL_CHOICES_MAXCONTEXTTOKENS=(
        ["anthropic:0"]=180000  # Claude Haiku 4.5 — 200k context window
        ["anthropic:1"]=180000  # Claude Sonnet 4 — 200k context window
        ["anthropic:2"]=180000  # Claude Sonnet 4.5 — 200k context window
        ["anthropic:3"]=180000  # Claude Opus 4.6 — 200k context window
        ["openai:0"]=120000     # GPT-5 Mini — 128k context window
        ["openai:1"]=120000     # GPT-5.2 — 128k context window
        ["gemini:0"]=900000     # Gemini 3 Flash — 1M context window
        ["gemini:1"]=900000     # Gemini 3.1 Pro — 1M context window
        ["groq:0"]=120000       # Kimi K2 — 128k context window
        ["groq:1"]=120000       # GPT-OSS 120B — 128k context window
        ["cerebras:0"]=120000   # ZAI-GLM 4.7 — 128k context window
        ["cerebras:1"]=120000   # Qwen3 235B — 128k context window
    )

    declare -A MODEL_CHOICES_COUNT=(
        ["anthropic"]=4
        ["openai"]=2

@@ -502,6 +519,10 @@ if [ "$USE_ASSOC_ARRAYS" = true ]; then
    get_model_choice_maxtokens() {
        echo "${MODEL_CHOICES_MAXTOKENS[$1:$2]}"
    }

    get_model_choice_maxcontexttokens() {
        echo "${MODEL_CHOICES_MAXCONTEXTTOKENS[$1:$2]}"
    }
else
    # Bash 3.2 - use parallel indexed arrays
    PROVIDER_ENV_VARS=(ANTHROPIC_API_KEY OPENAI_API_KEY MINIMAX_API_KEY GEMINI_API_KEY GOOGLE_API_KEY GROQ_API_KEY CEREBRAS_API_KEY MISTRAL_API_KEY TOGETHER_API_KEY DEEPSEEK_API_KEY)

@@ -510,7 +531,7 @@ else

    # Default models by provider id (parallel arrays)
    MODEL_PROVIDER_IDS=(anthropic openai minimax gemini groq cerebras mistral together_ai deepseek)
    MODEL_DEFAULTS=("claude-haiku-4-5-20251001" "gpt-5-mini" "MiniMax-M2.1" "gemini-3-flash-preview" "moonshotai/kimi-k2-instruct-0905" "zai-glm-4.7" "mistral-large-latest" "meta-llama/Llama-3.3-70B-Instruct-Turbo" "deepseek-chat")
    MODEL_DEFAULTS=("claude-haiku-4-5-20251001" "gpt-5-mini" "MiniMax-M2.5" "gemini-3-flash-preview" "moonshotai/kimi-k2-instruct-0905" "zai-glm-4.7" "mistral-large-latest" "meta-llama/Llama-3.3-70B-Instruct-Turbo" "deepseek-chat")

    # Helper: get provider display name for an env var
    get_provider_name() {

@@ -557,6 +578,9 @@ else
    MC_IDS=("claude-haiku-4-5-20251001" "claude-sonnet-4-20250514" "claude-sonnet-4-5-20250929" "claude-opus-4-6" "gpt-5-mini" "gpt-5.2" "gemini-3-flash-preview" "gemini-3.1-pro-preview" "moonshotai/kimi-k2-instruct-0905" "openai/gpt-oss-120b" "zai-glm-4.7" "qwen3-235b-a22b-instruct-2507")
    MC_LABELS=("Haiku 4.5 - Fast + cheap (recommended)" "Sonnet 4 - Fast + capable" "Sonnet 4.5 - Best balance" "Opus 4.6 - Most capable" "GPT-5 Mini - Fast + cheap (recommended)" "GPT-5.2 - Most capable" "Gemini 3 Flash - Fast (recommended)" "Gemini 3.1 Pro - Best quality" "Kimi K2 - Best quality (recommended)" "GPT-OSS 120B - Fast reasoning" "ZAI-GLM 4.7 - Best quality (recommended)" "Qwen3 235B - Frontier reasoning")
    MC_MAXTOKENS=(8192 8192 16384 32768 16384 16384 8192 8192 8192 8192 8192 8192)
    # Max context tokens per model (same order as MC_PROVIDERS/MC_IDS above)
    # Based on actual context windows with ~10% headroom for system prompt + output.
    MC_MAXCONTEXTTOKENS=(180000 180000 180000 180000 120000 120000 900000 900000 120000 120000 120000 120000)

    # Helper: get number of model choices for a provider
    get_model_choice_count() {

@@ -625,6 +649,24 @@ else
            i=$((i + 1))
        done
    }

    # Helper: get model choice max_context_tokens by provider and index
    get_model_choice_maxcontexttokens() {
        local provider_id="$1"
        local idx="$2"
        local count=0
        local i=0
        while [ $i -lt ${#MC_PROVIDERS[@]} ]; do
            if [ "${MC_PROVIDERS[$i]}" = "$provider_id" ]; then
                if [ $count -eq "$idx" ]; then
                    echo "${MC_MAXCONTEXTTOKENS[$i]}"
                    return
                fi
                count=$((count + 1))
            fi
            i=$((i + 1))
        done
    }
fi

# Configuration directory
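A quick check of the ~10% headroom rule behind these tables: subtract the headroom from the model's context window to get the input budget. For the 128k-window models the script uses a rounder 120000, slightly tighter than a strict 10% cut:

```python
# Headroom arithmetic for the MODEL_CHOICES_MAXCONTEXTTOKENS values above.
def context_budget(window: int, headroom: float = 0.10) -> int:
    return int(window * (1 - headroom))

print(context_budget(200_000))    # 180000, matching the Claude entries
print(context_budget(1_000_000))  # 900000, matching the Gemini entries
print(context_budget(128_000))    # 115200; the script rounds this up to 120000
```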
@@ -664,7 +706,7 @@ SHELL_RC_FILE=$(detect_shell_rc)
SHELL_NAME=$(basename "$SHELL")

# Prompt the user to choose a model for their selected provider.
# Sets SELECTED_MODEL and SELECTED_MAX_TOKENS.
# Sets SELECTED_MODEL, SELECTED_MAX_TOKENS, and SELECTED_MAX_CONTEXT_TOKENS.
prompt_model_selection() {
    local provider_id="$1"
    local count

@@ -674,6 +716,7 @@ prompt_model_selection() {
        # No curated choices for this provider (e.g. Mistral, DeepSeek)
        SELECTED_MODEL="$(get_default_model "$provider_id")"
        SELECTED_MAX_TOKENS=8192
        SELECTED_MAX_CONTEXT_TOKENS=120000  # 128k context window (Mistral, DeepSeek, etc.)
        return
    fi

@@ -681,6 +724,7 @@ prompt_model_selection() {
        # Only one choice — auto-select
        SELECTED_MODEL="$(get_model_choice_id "$provider_id" 0)"
        SELECTED_MAX_TOKENS="$(get_model_choice_maxtokens "$provider_id" 0)"
        SELECTED_MAX_CONTEXT_TOKENS="$(get_model_choice_maxcontexttokens "$provider_id" 0)"
        return
    fi

@@ -726,6 +770,7 @@ prompt_model_selection() {
        local idx=$((choice - 1))
        SELECTED_MODEL="$(get_model_choice_id "$provider_id" "$idx")"
        SELECTED_MAX_TOKENS="$(get_model_choice_maxtokens "$provider_id" "$idx")"
        SELECTED_MAX_CONTEXT_TOKENS="$(get_model_choice_maxcontexttokens "$provider_id" "$idx")"
        echo ""
        echo -e "${GREEN}⬢${NC} Model: ${DIM}$SELECTED_MODEL${NC}"
        return

@@ -735,15 +780,16 @@ prompt_model_selection() {
}

# Function to save configuration
# Args: provider_id env_var model max_tokens [use_claude_code_sub] [api_base] [use_codex_sub]
# Args: provider_id env_var model max_tokens max_context_tokens [use_claude_code_sub] [api_base] [use_codex_sub]
save_configuration() {
    local provider_id="$1"
    local env_var="$2"
    local model="$3"
    local max_tokens="$4"
    local use_claude_code_sub="${5:-}"
    local api_base="${6:-}"
    local use_codex_sub="${7:-}"
    local max_context_tokens="$5"
    local use_claude_code_sub="${6:-}"
    local api_base="${7:-}"
    local use_codex_sub="${8:-}"

    # Fallbacks if not provided
    if [ -z "$model" ]; then

@@ -752,6 +798,9 @@ save_configuration() {
    if [ -z "$max_tokens" ]; then
        max_tokens=8192
    fi
    if [ -z "$max_context_tokens" ]; then
        max_context_tokens=120000
    fi

    mkdir -p "$HIVE_CONFIG_DIR"

@@ -762,6 +811,7 @@ config = {
    'provider': '$provider_id',
    'model': '$model',
    'max_tokens': $max_tokens,
    'max_context_tokens': $max_context_tokens,
    'api_key_env_var': '$env_var'
},
'created_at': '$(date -u +"%Y-%m-%dT%H:%M:%S+00:00")'
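The heredoc above now serializes the new max_context_tokens field alongside max_tokens. A hedged sketch of the resulting payload as the equivalent Python dict (field names come from this diff; the concrete values are illustrative):

```python
# Equivalent of the configuration payload save_configuration() writes.
from datetime import datetime, timezone

config = {
    "llm": {
        "provider": "anthropic",
        "model": "claude-haiku-4-5-20251001",
        "max_tokens": 8192,            # output limit
        "max_context_tokens": 180000,  # input history budget
        "api_key_env_var": "ANTHROPIC_API_KEY",
    },
    "created_at": datetime.now(timezone.utc).isoformat(),
}
```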
@@ -796,7 +846,8 @@ FOUND_ENV_VARS=()  # Corresponding env var names
SELECTED_PROVIDER_ID=""  # Will hold the chosen provider ID
SELECTED_ENV_VAR=""  # Will hold the chosen env var
SELECTED_MODEL=""  # Will hold the chosen model ID
SELECTED_MAX_TOKENS=8192  # Will hold the chosen max_tokens
SELECTED_MAX_TOKENS=8192  # Will hold the chosen max_tokens (output limit)
SELECTED_MAX_CONTEXT_TOKENS=120000  # Will hold the chosen max_context_tokens (input history budget)
SUBSCRIPTION_MODE=""  # "claude_code" | "codex" | "zai_code" | ""

# ── Credential detection (silent — just set flags) ───────────

@@ -824,6 +875,13 @@ if [ -n "${MINIMAX_API_KEY:-}" ]; then
    MINIMAX_CRED_DETECTED=true
fi

KIMI_CRED_DETECTED=false
if [ -f "$HOME/.kimi/config.toml" ]; then
    KIMI_CRED_DETECTED=true
elif [ -n "${KIMI_API_KEY:-}" ]; then
    KIMI_CRED_DETECTED=true
fi

# Detect API key providers
if [ "$USE_ASSOC_ARRAYS" = true ]; then
    for env_var in "${!PROVIDER_NAMES[@]}"; do

@@ -859,6 +917,7 @@ try:
    sub = ''
    if llm.get('use_claude_code_subscription'): sub = 'claude_code'
    elif llm.get('use_codex_subscription'): sub = 'codex'
    elif llm.get('use_kimi_code_subscription'): sub = 'kimi_code'
    elif llm.get('provider', '') == 'minimax' or 'api.minimax.io' in llm.get('api_base', ''): sub = 'minimax_code'
    elif 'api.z.ai' in llm.get('api_base', ''): sub = 'zai_code'
    print(f'PREV_SUB_MODE={sub}')

@@ -875,6 +934,7 @@ if [ -n "$PREV_SUB_MODE" ] || [ -n "$PREV_PROVIDER" ]; then
        claude_code) [ "$CLAUDE_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
        zai_code) [ "$ZAI_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
        codex) [ "$CODEX_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
        kimi_code) [ "$KIMI_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
        *)
            # API key provider — check if the env var is set
            if [ -n "$PREV_ENV_VAR" ] && [ -n "${!PREV_ENV_VAR}" ]; then

@@ -889,15 +949,17 @@ if [ -n "$PREV_SUB_MODE" ] || [ -n "$PREV_PROVIDER" ]; then
            zai_code) DEFAULT_CHOICE=2 ;;
            codex) DEFAULT_CHOICE=3 ;;
            minimax_code) DEFAULT_CHOICE=4 ;;
            kimi_code) DEFAULT_CHOICE=5 ;;
        esac
        if [ -z "$DEFAULT_CHOICE" ]; then
            case "$PREV_PROVIDER" in
                anthropic) DEFAULT_CHOICE=5 ;;
                openai) DEFAULT_CHOICE=6 ;;
                gemini) DEFAULT_CHOICE=7 ;;
                groq) DEFAULT_CHOICE=8 ;;
                cerebras) DEFAULT_CHOICE=9 ;;
                anthropic) DEFAULT_CHOICE=6 ;;
                openai) DEFAULT_CHOICE=7 ;;
                gemini) DEFAULT_CHOICE=8 ;;
                groq) DEFAULT_CHOICE=9 ;;
                cerebras) DEFAULT_CHOICE=10 ;;
                minimax) DEFAULT_CHOICE=4 ;;
                kimi) DEFAULT_CHOICE=5 ;;
            esac
        fi
    fi

@@ -936,14 +998,21 @@ else
    echo -e "  ${CYAN}4)${NC} MiniMax Coding Key ${DIM}(use your MiniMax coding key)${NC}"
fi

# 5) Kimi Code
if [ "$KIMI_CRED_DETECTED" = true ]; then
    echo -e "  ${CYAN}5)${NC} Kimi Code Subscription ${DIM}(use your Kimi Code plan)${NC} ${GREEN}(credential detected)${NC}"
else
    echo -e "  ${CYAN}5)${NC} Kimi Code Subscription ${DIM}(use your Kimi Code plan)${NC}"
fi

echo ""
echo -e "  ${CYAN}${BOLD}API key providers:${NC}"

# 5-9) API key providers — show (credential detected) if key already set
# 6-10) API key providers — show (credential detected) if key already set
PROVIDER_MENU_ENVS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GROQ_API_KEY CEREBRAS_API_KEY)
PROVIDER_MENU_NAMES=("Anthropic (Claude) - Recommended" "OpenAI (GPT)" "Google Gemini - Free tier available" "Groq - Fast, free tier" "Cerebras - Fast, free tier")
for idx in 0 1 2 3 4; do
    num=$((idx + 5))
    num=$((idx + 6))
    env_var="${PROVIDER_MENU_ENVS[$idx]}"
    if [ -n "${!env_var}" ]; then
        echo -e "  ${CYAN}$num)${NC} ${PROVIDER_MENU_NAMES[$idx]} ${GREEN}(credential detected)${NC}"

@@ -952,7 +1021,7 @@ for idx in 0 1 2 3 4; do
    fi
done

echo -e "  ${CYAN}10)${NC} Skip for now"
echo -e "  ${CYAN}11)${NC} Skip for now"
echo ""

if [ -n "$DEFAULT_CHOICE" ]; then

@@ -962,15 +1031,15 @@ fi

while true; do
    if [ -n "$DEFAULT_CHOICE" ]; then
        read -r -p "Enter choice (1-10) [$DEFAULT_CHOICE]: " choice || true
        read -r -p "Enter choice (1-11) [$DEFAULT_CHOICE]: " choice || true
        choice="${choice:-$DEFAULT_CHOICE}"
    else
        read -r -p "Enter choice (1-10): " choice || true
        read -r -p "Enter choice (1-11): " choice || true
    fi
    if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le 10 ]; then
    if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le 11 ]; then
        break
    fi
    echo -e "${RED}Invalid choice. Please enter 1-10${NC}"
    echo -e "${RED}Invalid choice. Please enter 1-11${NC}"
done

case $choice in

@@ -988,6 +1057,7 @@ case $choice in
        SELECTED_PROVIDER_ID="anthropic"
        SELECTED_MODEL="claude-opus-4-6"
        SELECTED_MAX_TOKENS=32768
        SELECTED_MAX_CONTEXT_TOKENS=180000  # Claude — 200k context window
        echo ""
        echo -e "${GREEN}⬢${NC} Using Claude Code subscription"
    fi

@@ -999,6 +1069,7 @@ case $choice in
        SELECTED_ENV_VAR="ZAI_API_KEY"
        SELECTED_MODEL="glm-5"
        SELECTED_MAX_TOKENS=32768
        SELECTED_MAX_CONTEXT_TOKENS=120000  # GLM-5 — 128k context window
        PROVIDER_NAME="ZAI"
        echo ""
        echo -e "${GREEN}⬢${NC} Using ZAI Code subscription"

@@ -1029,6 +1100,7 @@ case $choice in
        SELECTED_PROVIDER_ID="openai"
        SELECTED_MODEL="gpt-5.3-codex"
        SELECTED_MAX_TOKENS=16384
        SELECTED_MAX_CONTEXT_TOKENS=120000  # GPT Codex — 128k context window
        echo ""
        echo -e "${GREEN}⬢${NC} Using OpenAI Codex subscription"
    fi

@@ -1038,46 +1110,62 @@ case $choice in
        SUBSCRIPTION_MODE="minimax_code"
        SELECTED_ENV_VAR="MINIMAX_API_KEY"
        SELECTED_PROVIDER_ID="minimax"
        SELECTED_MODEL="MiniMax-M2.1"
        SELECTED_MAX_TOKENS=8192
        SELECTED_MODEL="MiniMax-M2.5"
        SELECTED_MAX_TOKENS=32768
        SELECTED_MAX_CONTEXT_TOKENS=900000  # MiniMax M2.5 — 1M context window
        SELECTED_API_BASE="https://api.minimax.io/v1"
        PROVIDER_NAME="MiniMax"
        SIGNUP_URL="https://platform.minimax.io/user-center/basic-information/interface-key"
        echo ""
        echo -e "${GREEN}⬢${NC} Using MiniMax coding key"
        echo -e "  ${DIM}Model: MiniMax-M2.1 | API: api.minimax.io${NC}"
        echo -e "  ${DIM}Model: MiniMax-M2.5 | API: api.minimax.io${NC}"
        ;;
    5)
        # Kimi Code Subscription
        SUBSCRIPTION_MODE="kimi_code"
        SELECTED_PROVIDER_ID="kimi"
        SELECTED_ENV_VAR="KIMI_API_KEY"
        SELECTED_MODEL="kimi-k2.5"
        SELECTED_MAX_TOKENS=32768
        SELECTED_MAX_CONTEXT_TOKENS=120000  # Kimi K2.5 — 128k context window
        SELECTED_API_BASE="https://api.kimi.com/coding"
        PROVIDER_NAME="Kimi"
        SIGNUP_URL="https://www.kimi.com/code"
        echo ""
        echo -e "${GREEN}⬢${NC} Using Kimi Code subscription"
        echo -e "  ${DIM}Model: kimi-k2.5 | API: api.kimi.com/coding${NC}"
        ;;
    6)
        SELECTED_ENV_VAR="ANTHROPIC_API_KEY"
        SELECTED_PROVIDER_ID="anthropic"
        PROVIDER_NAME="Anthropic"
        SIGNUP_URL="https://console.anthropic.com/settings/keys"
        ;;
    6)
    7)
        SELECTED_ENV_VAR="OPENAI_API_KEY"
        SELECTED_PROVIDER_ID="openai"
        PROVIDER_NAME="OpenAI"
        SIGNUP_URL="https://platform.openai.com/api-keys"
        ;;
    7)
    8)
        SELECTED_ENV_VAR="GEMINI_API_KEY"
        SELECTED_PROVIDER_ID="gemini"
        PROVIDER_NAME="Google Gemini"
        SIGNUP_URL="https://aistudio.google.com/apikey"
        ;;
    8)
    9)
        SELECTED_ENV_VAR="GROQ_API_KEY"
        SELECTED_PROVIDER_ID="groq"
        PROVIDER_NAME="Groq"
        SIGNUP_URL="https://console.groq.com/keys"
        ;;
    9)
    10)
        SELECTED_ENV_VAR="CEREBRAS_API_KEY"
        SELECTED_PROVIDER_ID="cerebras"
        PROVIDER_NAME="Cerebras"
        SIGNUP_URL="https://cloud.cerebras.ai/"
        ;;
    10)
    11)
        echo ""
        echo -e "${YELLOW}Skipped.${NC} An LLM API key is required to test and use worker agents."
        echo -e "Add your API key later by running:"

@@ -1090,7 +1178,7 @@ case $choice in
esac

# For API-key providers: prompt for key (allow replacement if already set)
if { [ -z "$SUBSCRIPTION_MODE" ] || [ "$SUBSCRIPTION_MODE" = "minimax_code" ]; } && [ -n "$SELECTED_ENV_VAR" ]; then
if { [ -z "$SUBSCRIPTION_MODE" ] || [ "$SUBSCRIPTION_MODE" = "minimax_code" ] || [ "$SUBSCRIPTION_MODE" = "kimi_code" ]; } && [ -n "$SELECTED_ENV_VAR" ]; then
    while true; do
        CURRENT_KEY="${!SELECTED_ENV_VAR}"
        if [ -n "$CURRENT_KEY" ]; then

@@ -1118,7 +1206,7 @@ if { [ -z "$SUBSCRIPTION_MODE" ] || [ "$SUBSCRIPTION_MODE" = "minimax_code" ]; }
            echo -e "${GREEN}⬢${NC} API key saved to $SHELL_RC_FILE"
            # Health check the new key
            echo -n "  Verifying API key... "
            if [ "$SUBSCRIPTION_MODE" = "minimax_code" ] && [ -n "${SELECTED_API_BASE:-}" ]; then
            if { [ "$SUBSCRIPTION_MODE" = "minimax_code" ] || [ "$SUBSCRIPTION_MODE" = "kimi_code" ]; } && [ -n "${SELECTED_API_BASE:-}" ]; then
                HC_RESULT=$(uv run python "$SCRIPT_DIR/scripts/check_llm_key.py" "$SELECTED_PROVIDER_ID" "$API_KEY" "$SELECTED_API_BASE" 2>/dev/null) || true
            else
                HC_RESULT=$(uv run python "$SCRIPT_DIR/scripts/check_llm_key.py" "$SELECTED_PROVIDER_ID" "$API_KEY" 2>/dev/null) || true

@@ -1231,15 +1319,17 @@ if [ -n "$SELECTED_PROVIDER_ID" ]; then
    echo ""
    echo -n "  Saving configuration... "
    if [ "$SUBSCRIPTION_MODE" = "claude_code" ]; then
        save_configuration "$SELECTED_PROVIDER_ID" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "true" "" > /dev/null
        save_configuration "$SELECTED_PROVIDER_ID" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "true" "" > /dev/null
    elif [ "$SUBSCRIPTION_MODE" = "codex" ]; then
        save_configuration "$SELECTED_PROVIDER_ID" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "" "" "true" > /dev/null
        save_configuration "$SELECTED_PROVIDER_ID" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "" "true" > /dev/null
    elif [ "$SUBSCRIPTION_MODE" = "zai_code" ]; then
        save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "" "https://api.z.ai/api/coding/paas/v4" > /dev/null
        save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "https://api.z.ai/api/coding/paas/v4" > /dev/null
    elif [ "$SUBSCRIPTION_MODE" = "minimax_code" ]; then
        save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "" "$SELECTED_API_BASE" > /dev/null
        save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null
    elif [ "$SUBSCRIPTION_MODE" = "kimi_code" ]; then
        save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null
    else
        save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" > /dev/null
        save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" > /dev/null
    fi
    echo -e "${GREEN}⬢${NC}"
    echo -e "  ${DIM}~/.hive/configuration.json${NC}"
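With Kimi inserted at slot 5, every later menu entry shifts by one: subscriptions occupy 1-5, API-key providers 6-10, and skip becomes 11. A small sketch of the resulting mapping (the provider lists are copied from the menu above; the function itself is illustrative):

```python
# Illustration of the renumbered bash menu: choice number to selection.
SUBSCRIPTIONS = ["claude_code", "zai_code", "codex", "minimax_code", "kimi_code"]
API_PROVIDERS = ["anthropic", "openai", "gemini", "groq", "cerebras"]

def resolve_choice(choice: int) -> str:
    if 1 <= choice <= 5:
        return SUBSCRIPTIONS[choice - 1]
    if 6 <= choice <= 10:
        return API_PROVIDERS[choice - 6]
    if choice == 11:
        return "skip"
    raise ValueError("Enter 1-11")

assert resolve_choice(5) == "kimi_code"  # the new entry
assert resolve_choice(6) == "anthropic"  # shifted down by one
```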
@@ -56,6 +56,53 @@ def check_openai_compatible(api_key: str, endpoint: str, name: str) -> dict:
    return {"valid": False, "message": f"{name} API returned status {r.status_code}"}


def check_minimax(
    api_key: str, api_base: str = "https://api.minimax.io/v1", **_: str
) -> dict:
    """Validate via chatcompletion_v2 endpoint with empty messages.

    MiniMax doesn't support GET /models; their native endpoint is
    /v1/text/chatcompletion_v2.
    """
    with httpx.Client(timeout=TIMEOUT) as client:
        r = client.post(
            f"{api_base.rstrip('/')}/text/chatcompletion_v2",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json={"model": "MiniMax-M2.5", "messages": []},
        )
    if r.status_code in (200, 400, 422, 429):
        return {"valid": True, "message": "MiniMax API key valid"}
    if r.status_code == 401:
        return {"valid": False, "message": "Invalid MiniMax API key"}
    if r.status_code == 403:
        return {"valid": False, "message": "MiniMax API key lacks permissions"}
    return {"valid": False, "message": f"MiniMax API returned status {r.status_code}"}


def check_anthropic_compatible(api_key: str, endpoint: str, name: str) -> dict:
    """POST empty messages to an Anthropic-compatible endpoint to validate key."""
    with httpx.Client(timeout=TIMEOUT) as client:
        r = client.post(
            endpoint,
            headers={
                "x-api-key": api_key,
                "anthropic-version": "2023-06-01",
                "Content-Type": "application/json",
            },
            json={"model": "kimi-k2.5", "max_tokens": 1, "messages": []},
        )
    if r.status_code in (200, 400, 429):
        return {"valid": True, "message": f"{name} API key valid"}
    if r.status_code == 401:
        return {"valid": False, "message": f"Invalid {name} API key"}
    if r.status_code == 403:
        return {"valid": False, "message": f"{name} API key lacks permissions"}
    return {"valid": False, "message": f"{name} API returned status {r.status_code}"}


def check_gemini(api_key: str, **_: str) -> dict:
    """List models with query param auth."""
    with httpx.Client(timeout=TIMEOUT) as client:

@@ -82,8 +129,11 @@ PROVIDERS = {
    "cerebras": lambda key, **kw: check_openai_compatible(
        key, "https://api.cerebras.ai/v1/models", "Cerebras"
    ),
    "minimax": lambda key, **kw: check_openai_compatible(
        key, "https://api.minimax.io/v1/models", "MiniMax"
    "minimax": lambda key, **kw: check_minimax(key),
    # Kimi For Coding uses an Anthropic-compatible endpoint; check via /v1/messages
    # with empty messages (same as check_anthropic, triggers 400 not 401).
    "kimi": lambda key, **kw: check_anthropic_compatible(
        key, "https://api.kimi.com/coding/v1/messages", "Kimi"
    ),
}

@@ -105,12 +155,17 @@ def main() -> None:
    api_base = sys.argv[3] if len(sys.argv) > 3 else ""

    try:
        if api_base:
        if api_base and provider_id == "minimax":
            result = check_minimax(api_key, api_base)
        elif api_base and provider_id == "kimi":
            # Kimi uses an Anthropic-compatible endpoint; check via /v1/messages
            result = check_anthropic_compatible(
                api_key, api_base.rstrip("/") + "/v1/messages", "Kimi"
            )
        elif api_base:
            # Custom API base (ZAI or other OpenAI-compatible)
            endpoint = api_base.rstrip("/") + "/models"
            name = {"zai": "ZAI", "minimax": "MiniMax"}.get(
                provider_id, "Custom provider"
            )
            name = {"zai": "ZAI"}.get(provider_id, "Custom provider")
            result = check_openai_compatible(api_key, endpoint, name)
        elif provider_id in PROVIDERS:
            result = PROVIDERS[provider_id](api_key)
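Note the design choice shared by both checkers: an empty messages array deliberately provokes a 400 (or 422/429) from an authenticated endpoint, which proves the key works without spending tokens, while a 401 pins the failure on the key itself. A hedged sketch of how the installers consume the script's verdict (the argument order matches this diff; the printed JSON shape is inferred from the ConvertFrom-Json usage, not verified):

```python
# Invoke the key checker the way the installers do and parse its JSON verdict.
import json
import os
import subprocess

proc = subprocess.run(
    ["uv", "run", "python", "scripts/check_llm_key.py",
     "kimi", os.environ["KIMI_API_KEY"], "https://api.kimi.com/coding"],
    capture_output=True, text=True,
)
verdict = json.loads(proc.stdout)  # e.g. {"valid": true, "message": "Kimi API key valid"}
if not verdict["valid"]:
    print("Key rejected:", verdict["message"])
```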
+208
-30
@@ -334,8 +334,10 @@ def undo_changes(path: str = "") -> str:
@mcp.tool()
def list_agent_tools(
    server_config_path: str = "",
    output_schema: str = "simple",
    output_schema: str = "summary",
    group: str = "all",
    credentials: str = "all",
    service: str = "",
) -> str:
    """Discover tools available for agent building, grouped by provider.

@@ -343,22 +345,52 @@ def list_agent_tools(
    BEFORE designing an agent to know exactly which tools exist. Only use
    tools from this list in node definitions — never guess or fabricate.

    Progressive disclosure workflow (start narrow, drill in):
        list_agent_tools()  # provider summary: counts + credential status
        list_agent_tools(group="google", output_schema="summary")  # service breakdown within google
        list_agent_tools(group="google", service="gmail")  # tool names for just gmail
        list_agent_tools(group="google", service="gmail", output_schema="full")  # full detail

    Args:
        server_config_path: Path to mcp_servers.json. Default: tools/mcp_servers.json
            (the standard hive-tools server). Can also point to an agent's config
            to see what tools that specific agent has access to.
        output_schema: "simple" (default) returns name and description per tool.
            "full" also includes server and input_schema.
        output_schema: Controls verbosity of the response.
            "summary" (default) — provider list with tool counts + credential status. Very compact.
                When group is specified, shows service-level breakdown within that provider.
            "names" — tool names only (no descriptions), grouped by provider.
            "simple" — names + truncated descriptions.
            "full" — names + descriptions + server + input_schema.
        group: "all" (default) returns all providers. A provider like "google"
            returns only that provider's tools. Legacy prefix filters (e.g. "gmail")
            are still supported.
        credentials: Filter by credential availability.
            "all" (default) — show every tool regardless of credential status.
            "available" — only tools whose credentials are already configured.
            "unavailable" — only tools that still need credential setup.
        service: Filter to a specific service within a provider (e.g. service="gmail"
            when group="google"). Matches tools whose name starts with "<service>_".

    Returns:
        JSON with tools grouped by provider.
    """
    if output_schema not in ("simple", "full"):
    if output_schema not in ("summary", "names", "simple", "full"):
        return json.dumps(
            {"error": f"Invalid output_schema: {output_schema!r}. Use 'simple' or 'full'."}
            {
                "error": (
                    f"Invalid output_schema: {output_schema!r}. "
                    "Use 'summary', 'names', 'simple', or 'full'."
                )
            }
        )
    if credentials not in ("all", "available", "unavailable"):
        return json.dumps(
            {
                "error": (
                    f"Invalid credentials: {credentials!r}. "
                    "Use 'all', 'available', or 'unavailable'."
                )
            }
        )

    # Resolve config path
|
||||
|
||||
tool_provider_auth, tool_providers = _build_provider_metadata()
|
||||
|
||||
def _get_available_credential_names() -> set[str]:
|
||||
"""Return set of credential spec keys whose env_var is set in the environment."""
|
||||
try:
|
||||
from framework.credentials.validation import ensure_credential_key_env
|
||||
|
||||
ensure_credential_key_env()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
except ImportError:
|
||||
return set()
|
||||
return {
|
||||
cred_name
|
||||
for cred_name, spec in CREDENTIAL_SPECS.items()
|
||||
if spec.env_var and os.environ.get(spec.env_var)
|
||||
}
|
||||
|
||||
def _tool_credentials_available(tool_name: str, available_creds: set[str]) -> bool:
|
||||
"""True if all credentials required by tool_name are available (or tool needs none)."""
|
||||
required = set()
|
||||
for provider_creds in tool_provider_auth.get(tool_name, {}).values():
|
||||
required.update(provider_creds.keys())
|
||||
if not required:
|
||||
return True # no credentials needed
|
||||
return required.issubset(available_creds)
|
||||
|
||||
def _group_by_provider(tools: list[dict]) -> dict[str, dict]:
|
||||
"""Group tools by provider, including auth metadata and providerless tools."""
|
||||
groups: dict[str, dict] = {}
|
||||
@@ -481,16 +540,20 @@ def list_agent_tools(
|
||||
if not providers:
|
||||
providers = ["no_provider"]
|
||||
|
||||
desc = t["description"]
|
||||
if output_schema == "simple" and desc and len(desc) > 200:
|
||||
desc = desc[:200].rsplit(" ", 1)[0] + "..."
|
||||
tool_payload = {
|
||||
"name": t["name"],
|
||||
"description": desc,
|
||||
}
|
||||
if output_schema == "full":
|
||||
tool_payload["server"] = t["server"]
|
||||
tool_payload["input_schema"] = t["input_schema"]
|
||||
if output_schema == "names":
|
||||
# Store just the name string — will be collapsed to flat list below
|
||||
tool_payload: dict | str = t["name"]
|
||||
else:
|
||||
desc = t["description"]
|
||||
if output_schema == "simple" and desc and len(desc) > 200:
|
||||
desc = desc[:200].rsplit(" ", 1)[0] + "..."
|
||||
tool_payload = {
|
||||
"name": t["name"],
|
||||
"description": desc,
|
||||
}
|
||||
if output_schema == "full":
|
||||
tool_payload["server"] = t["server"]
|
||||
tool_payload["input_schema"] = t["input_schema"]
|
||||
|
||||
for provider in providers:
|
||||
bucket = groups.setdefault(
|
||||
@@ -502,17 +565,48 @@ def list_agent_tools(
|
||||
)
|
||||
bucket["tools"].append(tool_payload)
|
||||
|
||||
provider_auth = tool_provider_auth.get(t["name"], {}).get(provider, {})
|
||||
for cred_name, auth in provider_auth.items():
|
||||
bucket["authorization"][cred_name] = auth
|
||||
# Only accumulate full auth metadata for simple/full schemas.
|
||||
# summary/names use compact representations.
|
||||
if output_schema not in ("summary", "names"):
|
||||
provider_auth = tool_provider_auth.get(t["name"], {}).get(provider, {})
|
||||
for cred_name, auth in provider_auth.items():
|
||||
bucket["authorization"][cred_name] = auth
|
||||
|
||||
for _provider, bucket in groups.items():
|
||||
bucket["tools"] = sorted(bucket["tools"], key=lambda x: x["name"])
|
||||
bucket["authorization"] = dict(sorted(bucket["authorization"].items()))
|
||||
for provider, bucket in groups.items():
|
||||
if output_schema == "names":
|
||||
# Collapse to compact structure: flat sorted name list + credential keys only
|
||||
tool_names = sorted(set(bucket["tools"]))
|
||||
cred_keys: set[str] = set()
|
||||
for tn in tool_names:
|
||||
for prov_creds in tool_provider_auth.get(tn, {}).values():
|
||||
cred_keys.update(prov_creds.keys())
|
||||
groups[provider] = {
|
||||
"tool_count": len(tool_names),
|
||||
"credentials_required": sorted(cred_keys),
|
||||
"tool_names": tool_names,
|
||||
}
|
||||
else:
|
||||
bucket["tools"] = sorted(bucket["tools"], key=lambda x: x["name"])
|
||||
bucket["authorization"] = dict(sorted(bucket["authorization"].items()))
|
||||
|
||||
return dict(sorted(groups.items()))
|
||||
|
||||
provider_groups = _group_by_provider(all_tools)
|
||||
# Compute credential availability once (used for filtering and summary)
|
||||
available_creds: set[str] = (
|
||||
_get_available_credential_names() if credentials != "all" or output_schema == "summary"
|
||||
else set()
|
||||
)
|
||||
|
||||
# Apply credentials filter before grouping (filter tool list)
|
||||
filtered_tools = all_tools
|
||||
if credentials != "all":
|
||||
filtered_tools = [
|
||||
t
|
||||
for t in all_tools
|
||||
if (credentials == "available") == _tool_credentials_available(t["name"], available_creds)
|
||||
]
|
||||
|
||||
provider_groups = _group_by_provider(filtered_tools)
|
||||
|
||||
# Filter to a specific provider (preferred) or legacy prefix (fallback)
|
||||
if group != "all":
|
||||
@@ -520,20 +614,104 @@ def list_agent_tools(
|
||||
provider_groups = {group: provider_groups[group]}
|
||||
else:
|
||||
prefixed_tools = []
|
||||
for t in all_tools:
|
||||
for t in filtered_tools:
|
||||
parts = t["name"].split("_", 1)
|
||||
prefix = parts[0] if len(parts) > 1 else "general"
|
||||
if prefix == group:
|
||||
prefixed_tools.append(t)
|
||||
provider_groups = _group_by_provider(prefixed_tools)
|
||||
|
||||
all_names = sorted({t["name"] for p in provider_groups.values() for t in p["tools"]})
|
||||
result: dict = {
|
||||
"total": len(all_names),
|
||||
"tools_by_provider": provider_groups,
|
||||
"tools_by_category": provider_groups, # backward-compat alias
|
||||
"all_tool_names": all_names,
|
||||
}
|
||||
# Apply service filter (tool name prefix within a provider, e.g. service="gmail")
|
||||
if service:
|
||||
service_prefix = service.rstrip("_") + "_"
|
||||
service_filtered: list[dict] = []
for t in filtered_tools:
    # Only include tools from the already-filtered provider set
    tool_name = t["name"]
    in_provider = any(tool_name in p.get("tool_names", [tool_entry.get("name") for tool_entry in p.get("tools", [])]) for p in provider_groups.values())
    if in_provider and tool_name.startswith(service_prefix):
        service_filtered.append(t)
provider_groups = _group_by_provider(service_filtered)

def _infer_service(tool_name: str) -> str:
    """Infer service name from tool name prefix (e.g. 'gmail' from 'gmail_send_message')."""
    return tool_name.split("_", 1)[0]

# Summary mode: compact overview with counts + credential status
if output_schema == "summary":
    if group == "all":
        # Provider-level summary (default first call)
        full_groups = _group_by_provider(all_tools) if credentials != "all" else provider_groups
        summary_providers: dict = {}
        for prov, bucket in full_groups.items():
            cred_names = bucket.get("credentials_required", sorted(bucket.get("authorization", {}).keys()))
            creds_ok = all(c in available_creds for c in cred_names) if cred_names else True
            summary_providers[prov] = {
                "tool_count": len(bucket.get("tool_names", bucket.get("tools", []))),
                "credentials_required": cred_names,
                "credentials_available": creds_ok,
            }
        result: dict = {
            "total_tools": sum(v["tool_count"] for v in summary_providers.values()),
            "providers": summary_providers,
            "hint": (
                "Use list_agent_tools(group='<provider>', output_schema='summary') for service breakdown, "
                "list_agent_tools(group='<provider>', service='<service>') for tool names. "
                "Filter by credentials='available' to see only ready-to-use tools."
            ),
        }
    else:
        # Service-level breakdown within a specific provider
        # Re-build from all filtered tools for this provider (ignore service filter for summary)
        provider_tool_names: list[str] = []
        for bucket in provider_groups.values():
            provider_tool_names.extend(
                bucket.get("tool_names", [e.get("name") for e in bucket.get("tools", [])])
            )

        services: dict = {}
        for tn in sorted(set(provider_tool_names)):
            svc = _infer_service(tn)
            if svc not in services:
                svc_creds: set[str] = set()
                for prov_creds in tool_provider_auth.get(tn, {}).values():
                    svc_creds.update(prov_creds.keys())
                services[svc] = {"tool_count": 0, "credentials_required": sorted(svc_creds)}
            services[svc]["tool_count"] += 1
            # Accumulate credentials for other tools in this service
            for prov_creds in tool_provider_auth.get(tn, {}).values():
                existing = set(services[svc]["credentials_required"])
                existing.update(prov_creds.keys())
                services[svc]["credentials_required"] = sorted(existing)

        result = {
            "provider": group,
            "total_tools": len(provider_tool_names),
            "services": services,
            "hint": (
                f"Use list_agent_tools(group='{group}', service='<service>') "
                "for tool names within a service."
            ),
        }
    if errors:
        result["errors"] = errors
    return json.dumps(result, indent=2, default=str)

if output_schema == "names":
    # Compact result: no duplication, no all_tool_names list
    total = sum(p["tool_count"] for p in provider_groups.values())
    result = {
        "total": total,
        "tools_by_provider": provider_groups,
    }
else:
    all_names = sorted({t["name"] for p in provider_groups.values() for t in p["tools"]})
    result = {
        "total": len(all_names),
        "tools_by_provider": provider_groups,
        "tools_by_category": provider_groups,  # backward-compat alias
        "all_tool_names": all_names,
    }
if errors:
    result["errors"] = errors

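For orientation, a first summary-mode call produces a payload shaped roughly like the sketch below. This is a hypothetical two-provider result mirroring the dict built above; the provider names and counts are illustrative, not taken from a real run:

```python
# Hypothetical shape of a summary-mode result (illustrative values only).
example_summary = {
    "total_tools": 14,
    "providers": {
        "google": {
            "tool_count": 11,
            "credentials_required": ["google"],
            "credentials_available": True,
        },
        "github": {
            "tool_count": 3,
            "credentials_required": ["github"],
            "credentials_available": False,
        },
    },
    "hint": "Use list_agent_tools(group='<provider>', output_schema='summary') ...",
}
```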
@@ -0,0 +1,108 @@
"""Windows atomic file replacement with DACL preservation.

Uses ReplaceFileW for atomic replacement, then SetFileSecurityW to
restore the exact original DACL. ReplaceFileW merges ACEs from the
temp file, which can duplicate inherited entries. SetFileSecurityW
restores the security descriptor as-is without re-evaluating
inheritance (unlike SetNamedSecurityInfoW).

On non-NTFS volumes (e.g. FAT32), DACL snapshot/restore is skipped
gracefully and only the atomic replacement is performed.
"""

import ctypes
import ctypes.wintypes

_DACL_SECURITY_INFORMATION = 0x00000004
_REPLACEFILE_IGNORE_MERGE_ERRORS = 0x00000002

_advapi32 = None
_kernel32 = None

if hasattr(ctypes, "windll"):
    _advapi32 = ctypes.windll.advapi32
    _kernel32 = ctypes.windll.kernel32

    _advapi32.GetFileSecurityW.argtypes = [
        ctypes.wintypes.LPCWSTR,  # lpFileName
        ctypes.wintypes.DWORD,  # RequestedInformation
        ctypes.c_void_p,  # pSecurityDescriptor
        ctypes.wintypes.DWORD,  # nLength
        ctypes.POINTER(ctypes.wintypes.DWORD),  # lpnLengthNeeded
    ]
    _advapi32.GetFileSecurityW.restype = ctypes.wintypes.BOOL

    _advapi32.SetFileSecurityW.argtypes = [
        ctypes.wintypes.LPCWSTR,  # lpFileName
        ctypes.wintypes.DWORD,  # SecurityInformation
        ctypes.c_void_p,  # pSecurityDescriptor
    ]
    _advapi32.SetFileSecurityW.restype = ctypes.wintypes.BOOL

    _kernel32.ReplaceFileW.argtypes = [
        ctypes.wintypes.LPCWSTR,  # lpReplacedFileName
        ctypes.wintypes.LPCWSTR,  # lpReplacementFileName
        ctypes.wintypes.LPCWSTR,  # lpBackupFileName
        ctypes.wintypes.DWORD,  # dwReplaceFlags
        ctypes.c_void_p,  # lpExclude (reserved)
        ctypes.c_void_p,  # lpReserved
    ]
    _kernel32.ReplaceFileW.restype = ctypes.wintypes.BOOL


def snapshot_dacl(path: str) -> ctypes.Array | None:
    """Save a file's DACL as raw bytes. Returns None on non-NTFS."""
    if _advapi32 is None:
        return None

    needed = ctypes.wintypes.DWORD()
    _advapi32.GetFileSecurityW(
        path,
        _DACL_SECURITY_INFORMATION,
        None,
        0,
        ctypes.byref(needed),
    )
    if needed.value == 0:
        return None
    sd_buf = ctypes.create_string_buffer(needed.value)
    if not _advapi32.GetFileSecurityW(
        path,
        _DACL_SECURITY_INFORMATION,
        sd_buf,
        needed.value,
        ctypes.byref(needed),
    ):
        return None
    return sd_buf


def atomic_replace(target: str, replacement: str) -> None:
    """Atomically replace *target* with *replacement*, preserving the DACL.

    Uses ReplaceFileW for the atomic swap, then restores the original
    DACL via SetFileSecurityW (best-effort).
    """
    if _kernel32 is None or _advapi32 is None:
        raise OSError("atomic_replace is only available on Windows")

    sd_buf = snapshot_dacl(target)

    if not _kernel32.ReplaceFileW(
        target,
        replacement,
        None,
        _REPLACEFILE_IGNORE_MERGE_ERRORS,
        None,
        None,
    ):
        raise ctypes.WinError()

    # Best-effort: content is already saved, don't fail the whole edit
    # over a DACL restore failure.
    if sd_buf is not None:
        _advapi32.SetFileSecurityW(
            target,
            _DACL_SECURITY_INFORMATION,
            sd_buf,
        )
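A minimal sketch of how this module is meant to be driven, assuming a Windows host and a temp file written next to the target first (the paths are illustrative, and error cleanup is omitted):

```python
# Usage sketch for _win32_atomic (Windows only; paths are placeholders).
import os
import tempfile

from aden_tools._win32_atomic import atomic_replace

target = r"C:\work\target.txt"

# Write the new content to a sibling temp file first.
fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(target))
with os.fdopen(fd, "w", encoding="utf-8", newline="") as f:
    f.write("new contents\n")

# Swap it in atomically; the original DACL is restored best-effort.
atomic_replace(target, tmp_path)
```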
@@ -40,7 +40,6 @@ Credential categories:
- discord.py: Discord bot credentials
- github.py: GitHub API credentials
- google_analytics.py: Google Analytics 4 Data API credentials
- google_docs.py: Google Docs API credentials
- google_maps.py: Google Maps Platform credentials
- hubspot.py: HubSpot CRM credentials
- intercom.py: Intercom customer messaging credentials
@@ -81,7 +80,6 @@ from .gcp_vision import GCP_VISION_CREDENTIALS
from .github import GITHUB_CREDENTIALS
from .gitlab import GITLAB_CREDENTIALS
from .google_analytics import GOOGLE_ANALYTICS_CREDENTIALS
from .google_docs import GOOGLE_DOCS_CREDENTIALS
from .google_maps import GOOGLE_MAPS_CREDENTIALS
from .google_search_console import GOOGLE_SEARCH_CONSOLE_CREDENTIALS
from .greenhouse import GREENHOUSE_CREDENTIALS
@@ -171,7 +169,6 @@ CREDENTIAL_SPECS = {
    **GREENHOUSE_CREDENTIALS,
    **GITLAB_CREDENTIALS,
    **GOOGLE_ANALYTICS_CREDENTIALS,
    **GOOGLE_DOCS_CREDENTIALS,
    **GOOGLE_MAPS_CREDENTIALS,
    **GOOGLE_SEARCH_CONSOLE_CREDENTIALS,
    **HUBSPOT_CREDENTIALS,
@@ -264,7 +261,6 @@ __all__ = [
    "GREENHOUSE_CREDENTIALS",
    "GITLAB_CREDENTIALS",
    "GOOGLE_ANALYTICS_CREDENTIALS",
    "GOOGLE_DOCS_CREDENTIALS",
    "GOOGLE_MAPS_CREDENTIALS",
    "GOOGLE_SEARCH_CONSOLE_CREDENTIALS",
    "HUBSPOT_CREDENTIALS",

@@ -69,12 +69,26 @@ EMAIL_CREDENTIALS = {
            "google_sheets_batch_clear_values",
            "google_sheets_add_sheet",
            "google_sheets_delete_sheet",
            # Google Docs tools
            "google_docs_create_document",
            "google_docs_get_document",
            "google_docs_insert_text",
            "google_docs_replace_all_text",
            "google_docs_insert_image",
            "google_docs_format_text",
            "google_docs_batch_update",
            "google_docs_create_list",
            "google_docs_add_comment",
            "google_docs_list_comments",
            "google_docs_export_content",
        ],
        node_types=[],
        required=True,
        startup_required=False,
        help_url="https://hive.adenhq.com",
        description="Google OAuth2 access token (via Aden) - used for Gmail, Calendar, and Sheets",
        description=(
            "Google OAuth2 access token (via Aden) - used for Gmail, Calendar, Sheets, and Docs"
        ),
        aden_supported=True,
        aden_provider_name="google",
        direct_api_key_supported=False,

@@ -1,51 +0,0 @@
"""
Google Docs tool credentials.

Contains credentials for Google Docs API integration.
"""

from .base import CredentialSpec

GOOGLE_DOCS_CREDENTIALS = {
    "google_docs": CredentialSpec(
        env_var="GOOGLE_DOCS_ACCESS_TOKEN",
        tools=[
            "google_docs_create_document",
            "google_docs_get_document",
            "google_docs_insert_text",
            "google_docs_replace_all_text",
            "google_docs_insert_image",
            "google_docs_format_text",
            "google_docs_batch_update",
            "google_docs_create_list",
            "google_docs_add_comment",
            "google_docs_list_comments",
            "google_docs_export_content",
        ],
        required=True,
        startup_required=False,
        help_url="https://console.cloud.google.com/apis/credentials",
        description="Google Docs OAuth2 access token",
        # Auth method support
        aden_supported=True,
        aden_provider_name="google",
        direct_api_key_supported=True,
        api_key_instructions="""To get a Google Docs access token:
1. Go to Google Cloud Console: https://console.cloud.google.com/
2. Create a new project or select an existing one
3. Enable the Google Docs API and Google Drive API
4. Go to APIs & Services > Credentials
5. Create OAuth 2.0 credentials (Web application or Desktop app)
6. Use the OAuth 2.0 Playground or your app to get an access token
7. Required scopes:
   - https://www.googleapis.com/auth/documents
   - https://www.googleapis.com/auth/drive.file
   - https://www.googleapis.com/auth/drive (for export/comments)""",
        # Health check configuration
        health_check_endpoint="https://docs.googleapis.com/v1/documents/1",
        health_check_method="GET",
        # Credential store mapping
        credential_id="google_docs",
        credential_key="access_token",
    ),
}
@@ -1068,16 +1068,6 @@ class ExaSearchHealthChecker(BaseHttpHealthChecker):
        return {"query": "test", "numResults": 1}


class GoogleDocsHealthChecker(OAuthBearerHealthChecker):
    """Health checker for Google Docs OAuth tokens."""

    def __init__(self):
        super().__init__(
            endpoint="https://docs.googleapis.com/v1/documents/1",
            service_name="Google Docs",
        )


class CalcomHealthChecker(BaseHttpHealthChecker):
    """Health checker for Cal.com API key."""

@@ -1334,7 +1324,6 @@ HEALTH_CHECKERS: dict[str, CredentialHealthChecker] = {
    "github": GitHubHealthChecker(),
    "gitlab_token": GitLabHealthChecker(),
    "google": GoogleHealthChecker(),
    "google_docs": GoogleDocsHealthChecker(),
    "google_maps": GoogleMapsHealthChecker(),
    "google_search": GoogleSearchHealthChecker(),
    "google_search_console": GoogleSearchConsoleHealthChecker(),

@@ -23,6 +23,7 @@ import json
import os
import re
import subprocess
import sys
import tempfile
from collections.abc import Callable
from pathlib import Path
@@ -965,16 +966,25 @@ def register_file_tools(
        try:
            if before_write:
                before_write()
            original_mode = os.stat(resolved).st_mode
            fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(resolved))
            fd_open = True
            try:
                if hasattr(os, "fchmod"):
                    os.fchmod(fd, original_mode)
                match sys.platform:
                    case "win32":
                        pass  # ACL preservation handled by atomic_replace below
                    case _:
                        original_mode = os.stat(resolved).st_mode
                        os.fchmod(fd, original_mode)
                with os.fdopen(fd, "w", encoding=encoding, newline="") as f:
                    fd_open = False
                    f.write(joined)
                os.replace(tmp_path, resolved)
                match sys.platform:
                    case "win32":
                        from aden_tools._win32_atomic import atomic_replace

                        atomic_replace(resolved, tmp_path)
                    case _:
                        os.replace(tmp_path, resolved)
            except BaseException:
                if fd_open:
                    os.close(fd)

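Stripped of its diff context, the new write path reduces to the following cross-platform pattern (a condensed sketch, not the full tool code; `path` and `data` are placeholders and the exception cleanup is omitted):

```python
# Condensed sketch of the atomic-write pattern used in the hunks above.
import os
import sys
import tempfile


def atomic_write(path: str, data: str, encoding: str = "utf-8") -> None:
    # Write to a sibling temp file, then swap it into place atomically.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path))
    match sys.platform:
        case "win32":
            pass  # ACLs are restored by atomic_replace after the swap
        case _:
            os.fchmod(fd, os.stat(path).st_mode)  # keep POSIX permissions
    with os.fdopen(fd, "w", encoding=encoding, newline="") as f:
        f.write(data)
    match sys.platform:
        case "win32":
            from aden_tools._win32_atomic import atomic_replace

            atomic_replace(path, tmp_path)
        case _:
            os.replace(tmp_path, path)  # atomic on POSIX filesystems
```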
@@ -2,6 +2,7 @@ import contextlib
import json
import os
import re
import sys
import tempfile

from mcp.server.fastmcp import FastMCP
@@ -380,16 +381,25 @@ def register_tools(mcp: FastMCP) -> None:

        # 9. Atomic write (write-to-tmp + os.replace)
        try:
            original_mode = os.stat(secure_path).st_mode
            fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(secure_path))
            fd_open = True
            try:
                if hasattr(os, "fchmod"):
                    os.fchmod(fd, original_mode)
                match sys.platform:
                    case "win32":
                        pass  # ACL preservation handled by atomic_replace below
                    case _:
                        original_mode = os.stat(secure_path).st_mode
                        os.fchmod(fd, original_mode)
                with os.fdopen(fd, "w", encoding=encoding, newline="") as f:
                    fd_open = False
                    f.write(joined)
                os.replace(tmp_path, secure_path)
                match sys.platform:
                    case "win32":
                        from aden_tools._win32_atomic import atomic_replace

                        atomic_replace(secure_path, tmp_path)
                    case _:
                        os.replace(tmp_path, secure_path)
            except BaseException:
                if fd_open:
                    os.close(fd)

@@ -26,14 +26,13 @@ Create and manage Google Docs documents via the Google Docs API v1.
6. Set the environment variable:

```bash
export GOOGLE_DOCS_ACCESS_TOKEN="your-access-token"
export GOOGLE_ACCESS_TOKEN="your-access-token"
```

### Required OAuth Scopes

- `https://www.googleapis.com/auth/documents` - Full access to Google Docs
- `https://www.googleapis.com/auth/drive.file` - Access to files created/opened by the app
- `https://www.googleapis.com/auth/drive` - Required for document export and comment functionality
- `https://www.googleapis.com/auth/documents` - Google Docs API (create, read, edit documents)
- `https://www.googleapis.com/auth/drive.file` - Google Drive API (export, comments)

## Available Tools

@@ -144,4 +143,4 @@ All tools return a dict. On error, the dict contains an `"error"` key with a des

| Variable | Required | Description |
|----------|----------|-------------|
| `GOOGLE_DOCS_ACCESS_TOKEN` | Yes | OAuth2 access token |
| `GOOGLE_ACCESS_TOKEN` | Yes | OAuth2 access token (shared with Gmail, Calendar, Sheets) |

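To sanity-check a token outside the tool, a raw request against the Docs API can look like the sketch below (assuming `httpx` is installed and `DOC_ID` is a placeholder for a document you can read; the error-dict convention above applies to the tools, not to this raw call):

```python
# Quick token sanity check against the Docs API (DOC_ID is a placeholder).
import os

import httpx

token = os.environ["GOOGLE_ACCESS_TOKEN"]
resp = httpx.get(
    "https://docs.googleapis.com/v1/documents/DOC_ID",
    headers={"Authorization": f"Bearer {token}"},
    timeout=30.0,
)
print(resp.status_code)  # 200 means the token and scopes are usable
```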
@@ -3,7 +3,7 @@ Google Docs Tool - Create and manage Google Docs documents via Google Docs API v

Supports:
- OAuth2 tokens via the credential store
- Direct access token (GOOGLE_DOCS_ACCESS_TOKEN)
- Direct access token (GOOGLE_ACCESS_TOKEN)

API Reference: https://developers.google.com/docs/api/reference/rest

@@ -18,7 +18,6 @@ import base64
import json
import os
import re
import time
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

@@ -30,8 +29,6 @@ if TYPE_CHECKING:

GOOGLE_DOCS_API_BASE = "https://docs.googleapis.com/v1"
GOOGLE_DRIVE_API_BASE = "https://www.googleapis.com/drive/v3"
GOOGLE_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"

# Allowed URL schemes for image insertion
ALLOWED_IMAGE_SCHEMES = {"https", "http"}
# Regex pattern for valid URLs
@@ -99,105 +96,6 @@ def _get_document_end_index(doc: dict[str, Any]) -> int:
    return 1


def _create_service_account_token(service_account_json: str) -> str | None:
    """Create an access token from a service account JSON using JWT.

    This implements the OAuth 2.0 service account flow:
    1. Create a signed JWT
    2. Exchange it for an access token

    Args:
        service_account_json: The service account JSON string

    Returns:
        Access token string, or None if token creation failed
    """
    try:
        sa_data = json.loads(service_account_json)
    except json.JSONDecodeError:
        return None

    # Check if this is actually a service account
    if sa_data.get("type") != "service_account":
        # Not a service account, check for direct access token
        return sa_data.get("access_token")

    # Required fields for service account
    private_key = sa_data.get("private_key")
    client_email = sa_data.get("client_email")
    token_uri = sa_data.get("token_uri", GOOGLE_OAUTH_TOKEN_URL)

    if not private_key or not client_email:
        return None

    # Create JWT header and claims
    now = int(time.time())
    header = {"alg": "RS256", "typ": "JWT"}
    claims = {
        "iss": client_email,
        "sub": client_email,
        "aud": token_uri,
        "iat": now,
        "exp": now + 3600,  # 1 hour expiry
        "scope": (
            "https://www.googleapis.com/auth/documents "
            "https://www.googleapis.com/auth/drive.file "
            "https://www.googleapis.com/auth/drive"
        ),
    }

    try:
        # Try using cryptography library for RSA signing
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import hashes, serialization
        from cryptography.hazmat.primitives.asymmetric import padding

        # Encode header and claims
        def _b64url_encode(data: bytes) -> str:
            return base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8")

        header_b64 = _b64url_encode(json.dumps(header).encode())
        claims_b64 = _b64url_encode(json.dumps(claims).encode())
        signing_input = f"{header_b64}.{claims_b64}"

        # Load private key and sign
        private_key_obj = serialization.load_pem_private_key(
            private_key.encode(), password=None, backend=default_backend()
        )
        signature = private_key_obj.sign(
            signing_input.encode(),
            padding.PKCS1v15(),
            hashes.SHA256(),
        )
        signature_b64 = _b64url_encode(signature)

        jwt_token = f"{signing_input}.{signature_b64}"

        # Exchange JWT for access token
        response = httpx.post(
            token_uri,
            data={
                "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
                "assertion": jwt_token,
            },
            timeout=30.0,
        )

        if response.status_code == 200:
            token_data = response.json()
            return token_data.get("access_token")

        return None

    except ImportError:
        # cryptography not available, cannot sign JWT
        # Fall back to checking for pre-exchanged token
        return sa_data.get("access_token")
    except Exception:
        # Any signing/exchange error
        return None


class _GoogleDocsClient:
    """Internal client wrapping Google Docs API v1 calls."""

@@ -486,25 +384,16 @@ def register_tools(
        if credentials is not None:
            if account:
                return credentials.get_by_alias(
                    "google_docs",
                    "google",
                    account,
                )
            token = credentials.get("google_docs")
            token = credentials.get("google")
            if token is not None and not isinstance(token, str):
                raise TypeError(
                    f"Expected string from credentials.get('google_docs'), "
                    f"got {type(token).__name__}"
                    f"Expected string from credentials.get('google'), got {type(token).__name__}"
                )
            return token
        # Try environment variables - direct access token first
        token = os.getenv("GOOGLE_DOCS_ACCESS_TOKEN")
        if token:
            return token
        # Try service account JSON with proper JWT token exchange
        service_account = os.getenv("GOOGLE_SERVICE_ACCOUNT_JSON")
        if service_account:
            return _create_service_account_token(service_account)
        return None
        return os.getenv("GOOGLE_ACCESS_TOKEN")

    def _get_client(account: str = "") -> _GoogleDocsClient | dict[str, str]:
        """Get a Google Docs client, or return an error dict if no credentials."""
@@ -513,9 +402,8 @@ def register_tools(
            return {
                "error": "Google Docs credentials not configured",
                "help": (
                    "Set GOOGLE_DOCS_ACCESS_TOKEN environment variable "
                    "or configure via credential store. "
                    "Get credentials at: https://console.cloud.google.com/apis/credentials"
                    "Set GOOGLE_ACCESS_TOKEN environment variable "
                    "or configure 'google' via credential store"
                ),
            }
        return _GoogleDocsClient(token)

@@ -49,16 +49,6 @@ class TestGoogleDocsCreateDocument:
        assert "not configured" in result["error"]
        assert "help" in result

    def test_service_account_json_without_access_token_is_not_used(self, mcp):
        """Test that service account JSON alone is not treated as an access token."""
        with patch.dict(
            "os.environ", {"GOOGLE_SERVICE_ACCOUNT_JSON": '{"type":"service_account"}'}
        ):
            tool_fn = get_tool_fn(mcp, "google_docs_create_document")
            result = tool_fn(title="Test Document")
            assert "error" in result
            assert "not configured" in result["error"]

    @patch("httpx.post")
    def test_create_document_success(self, mock_post, mcp_with_credentials):
        """Test successful document creation."""
@@ -444,34 +434,6 @@ class TestReplaceAllTextValidation:
        assert "empty" in result["error"].lower()


class TestServiceAccountTokenExchange:
    """Tests for service account JWT token exchange."""

    @patch("httpx.post")
    @patch.dict(
        "os.environ",
        {"GOOGLE_SERVICE_ACCOUNT_JSON": '{"access_token": "pre-exchanged-token"}'},
    )
    def test_fallback_to_pre_exchanged_token(self, mock_post):
        """Test that pre-exchanged tokens in JSON are used as fallback."""
        server = FastMCP("test")
        register_tools(server)

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "documentId": "doc123",
            "title": "Test",
        }
        mock_post.return_value = mock_response

        tool_fn = get_tool_fn(server, "google_docs_create_document")
        result = tool_fn(title="Test")

        # Should use the pre-exchanged token and make the API call
        assert "error" not in result or "not configured" not in result.get("error", "")


class TestGoogleDocsListComments:
    """Tests for google_docs_list_comments tool."""


@@ -73,7 +73,6 @@ class TestHealthCheckerRegistry:
            "github",
            "gitlab_token",
            "google",
            "google_docs",
            "google_maps",
            "google_search",
            "google_search_console",

@@ -0,0 +1,306 @@
"""Tests for DNS Security Scanner tool."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest
from fastmcp import FastMCP

from aden_tools.tools.dns_security_scanner import register_tools


@pytest.fixture
def dns_tools(mcp: FastMCP):
    """Register DNS security tools and return tool functions."""
    register_tools(mcp)
    tools = mcp._tool_manager._tools
    return {name: tools[name].fn for name in tools}


@pytest.fixture
def scan_fn(dns_tools):
    return dns_tools["dns_security_scan"]


# ---------------------------------------------------------------------------
# Input Validation & Cleaning
# ---------------------------------------------------------------------------


class TestInputValidation:
    """Test domain input cleaning and validation."""

    def test_strips_https_prefix(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", True
        ):
            with patch(
                "aden_tools.tools.dns_security_scanner.dns_security_scanner.dns.resolver.Resolver"
            ) as MockResolver:
                import dns.resolver

                mock = MagicMock()
                mock.resolve.side_effect = dns.resolver.NXDOMAIN()
                mock.timeout = 10
                mock.lifetime = 10
                MockResolver.return_value = mock

                result = scan_fn("https://example.com")
                assert result["domain"] == "example.com"

    def test_strips_http_prefix(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", True
        ):
            with patch(
                "aden_tools.tools.dns_security_scanner.dns_security_scanner.dns.resolver.Resolver"
            ) as MockResolver:
                import dns.resolver

                mock = MagicMock()
                mock.resolve.side_effect = dns.resolver.NXDOMAIN()
                mock.timeout = 10
                mock.lifetime = 10
                MockResolver.return_value = mock

                result = scan_fn("http://example.com")
                assert result["domain"] == "example.com"

    def test_strips_trailing_slash(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", True
        ):
            with patch(
                "aden_tools.tools.dns_security_scanner.dns_security_scanner.dns.resolver.Resolver"
            ) as MockResolver:
                import dns.resolver

                mock = MagicMock()
                mock.resolve.side_effect = dns.resolver.NXDOMAIN()
                mock.timeout = 10
                mock.lifetime = 10
                MockResolver.return_value = mock

                result = scan_fn("example.com/")
                assert result["domain"] == "example.com"

    def test_strips_path(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", True
        ):
            with patch(
                "aden_tools.tools.dns_security_scanner.dns_security_scanner.dns.resolver.Resolver"
            ) as MockResolver:
                import dns.resolver

                mock = MagicMock()
                mock.resolve.side_effect = dns.resolver.NXDOMAIN()
                mock.timeout = 10
                mock.lifetime = 10
                MockResolver.return_value = mock

                result = scan_fn("example.com/path/to/page")
                assert result["domain"] == "example.com"

    def test_strips_port(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", True
        ):
            with patch(
                "aden_tools.tools.dns_security_scanner.dns_security_scanner.dns.resolver.Resolver"
            ) as MockResolver:
                import dns.resolver

                mock = MagicMock()
                mock.resolve.side_effect = dns.resolver.NXDOMAIN()
                mock.timeout = 10
                mock.lifetime = 10
                MockResolver.return_value = mock

                result = scan_fn("example.com:8080")
                assert result["domain"] == "example.com"


# ---------------------------------------------------------------------------
# DNS Library Availability
# ---------------------------------------------------------------------------


class TestDnsAvailability:
    """Test behavior when dnspython is not installed."""

    def test_dns_not_available(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", False
        ):
            result = scan_fn("example.com")
            assert "error" in result
            assert "dnspython" in result["error"]


# ---------------------------------------------------------------------------
# SPF Record Checks
# ---------------------------------------------------------------------------


class TestSpfChecks:
    """Test SPF record detection and policy analysis."""

    def test_spf_hardfail_detected(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", True
        ):
            with patch(
                "aden_tools.tools.dns_security_scanner.dns_security_scanner.dns.resolver.Resolver"
            ) as MockResolver:
                mock = MagicMock()
                mock_rdata = MagicMock()
                mock_rdata.to_text.return_value = '"v=spf1 include:_spf.google.com -all"'
                mock.resolve.return_value = [mock_rdata]
                mock.timeout = 10
                mock.lifetime = 10
                MockResolver.return_value = mock

                result = scan_fn("example.com")
                assert result["spf"]["present"] is True
                assert result["spf"]["policy"] == "hardfail"
                assert result["grade_input"]["spf_strict"] is True

    def test_spf_softfail_detected(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", True
        ):
            with patch(
                "aden_tools.tools.dns_security_scanner.dns_security_scanner.dns.resolver.Resolver"
            ) as MockResolver:
                mock = MagicMock()
                mock_rdata = MagicMock()
                mock_rdata.to_text.return_value = '"v=spf1 include:_spf.google.com ~all"'
                mock.resolve.return_value = [mock_rdata]
                mock.timeout = 10
                mock.lifetime = 10
                MockResolver.return_value = mock

                result = scan_fn("example.com")
                assert result["spf"]["present"] is True
                assert result["spf"]["policy"] == "softfail"
                assert result["grade_input"]["spf_strict"] is False

    def test_spf_pass_all_dangerous(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", True
        ):
            with patch(
                "aden_tools.tools.dns_security_scanner.dns_security_scanner.dns.resolver.Resolver"
            ) as MockResolver:
                mock = MagicMock()
                mock_rdata = MagicMock()
                mock_rdata.to_text.return_value = '"v=spf1 +all"'
                mock.resolve.return_value = [mock_rdata]
                mock.timeout = 10
                mock.lifetime = 10
                MockResolver.return_value = mock

                result = scan_fn("example.com")
                assert result["spf"]["policy"] == "pass_all"
                assert len(result["spf"]["issues"]) > 0


# ---------------------------------------------------------------------------
# DMARC Record Checks
# ---------------------------------------------------------------------------


class TestDmarcChecks:
    """Test DMARC record detection and policy analysis."""

    def test_dmarc_reject_policy(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", True
        ):
            with patch(
                "aden_tools.tools.dns_security_scanner.dns_security_scanner.dns.resolver.Resolver"
            ) as MockResolver:
                mock = MagicMock()

                def mock_resolve(domain, record_type):
                    import dns.resolver

                    if record_type == "TXT" and "_dmarc" in domain:
                        rdata = MagicMock()
                        rdata.to_text.return_value = '"v=DMARC1; p=reject"'
                        return [rdata]
                    raise dns.resolver.NXDOMAIN()

                mock.resolve = mock_resolve
                mock.timeout = 10
                mock.lifetime = 10
                MockResolver.return_value = mock

                result = scan_fn("example.com")
                assert result["dmarc"]["present"] is True
                assert result["dmarc"]["policy"] == "reject"
                assert result["grade_input"]["dmarc_enforcing"] is True

    def test_dmarc_none_policy(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", True
        ):
            with patch(
                "aden_tools.tools.dns_security_scanner.dns_security_scanner.dns.resolver.Resolver"
            ) as MockResolver:
                mock = MagicMock()

                def mock_resolve(domain, record_type):
                    if record_type == "TXT" and "_dmarc" in domain:
                        rdata = MagicMock()
                        rdata.to_text.return_value = '"v=DMARC1; p=none"'
                        return [rdata]
                    import dns.resolver

                    raise dns.resolver.NXDOMAIN()

                mock.resolve = mock_resolve
                mock.timeout = 10
                mock.lifetime = 10
                MockResolver.return_value = mock

                result = scan_fn("example.com")
                assert result["dmarc"]["policy"] == "none"
                assert result["grade_input"]["dmarc_enforcing"] is False


# ---------------------------------------------------------------------------
# Grade Input
# ---------------------------------------------------------------------------


class TestGradeInput:
    """Test grade_input dict is properly constructed."""

    def test_grade_input_keys_present(self, scan_fn):
        with patch(
            "aden_tools.tools.dns_security_scanner.dns_security_scanner._DNS_AVAILABLE", True
        ):
            with patch(
                "aden_tools.tools.dns_security_scanner.dns_security_scanner.dns.resolver.Resolver"
            ) as MockResolver:
                mock = MagicMock()
                import dns.resolver

                mock.resolve.side_effect = dns.resolver.NXDOMAIN()
                mock.timeout = 10
                mock.lifetime = 10
                MockResolver.return_value = mock

                result = scan_fn("example.com")
                assert "grade_input" in result
                grade = result["grade_input"]
                assert "spf_present" in grade
                assert "spf_strict" in grade
                assert "dmarc_present" in grade
                assert "dmarc_enforcing" in grade
                assert "dkim_found" in grade
                assert "dnssec_enabled" in grade
                assert "zone_transfer_blocked" in grade
@@ -3,6 +3,7 @@
import json
import os
import sys
from unittest.mock import patch

import pytest
from fastmcp import FastMCP
@@ -401,6 +402,95 @@ class TestHashlineEditAtomicWrite:
        hashline_edit(path="f.txt", edits=edits)
        assert os.stat(f).st_mode & 0o777 == 0o755

    @pytest.mark.skipif(sys.platform != "win32", reason="Windows-only ACL test")
    def test_acl_preserved_after_edit_windows(self, tools, tmp_path):
        """Atomic replace preserves the target file's DACL on Windows."""
        import ctypes
        import ctypes.wintypes  # required for the ctypes.wintypes attribute access below

        advapi32 = ctypes.windll.advapi32
        kernel32 = ctypes.windll.kernel32
        SE_FILE_OBJECT = 1
        DACL_SECURITY_INFORMATION = 0x00000004

        advapi32.GetNamedSecurityInfoW.argtypes = [
            ctypes.wintypes.LPCWSTR,  # pObjectName
            ctypes.c_uint,  # ObjectType (SE_OBJECT_TYPE enum)
            ctypes.wintypes.DWORD,  # SecurityInfo
            ctypes.c_void_p,  # ppsidOwner
            ctypes.c_void_p,  # ppsidGroup
            ctypes.c_void_p,  # ppDacl
            ctypes.c_void_p,  # ppSacl
            ctypes.c_void_p,  # ppSecurityDescriptor
        ]
        advapi32.GetNamedSecurityInfoW.restype = ctypes.wintypes.DWORD

        advapi32.ConvertSecurityDescriptorToStringSecurityDescriptorW.argtypes = [
            ctypes.c_void_p,  # SecurityDescriptor
            ctypes.wintypes.DWORD,  # RequestedStringSDRevision
            ctypes.wintypes.DWORD,  # SecurityInformation
            ctypes.c_void_p,  # StringSecurityDescriptor (out)
            ctypes.c_void_p,  # StringSecurityDescriptorLen (out, optional)
        ]
        advapi32.ConvertSecurityDescriptorToStringSecurityDescriptorW.restype = ctypes.wintypes.BOOL

        kernel32.LocalFree.argtypes = [ctypes.c_void_p]
        kernel32.LocalFree.restype = ctypes.c_void_p

        hashline_edit = tools[0]["hashline_edit"]
        f = tmp_path / "f.txt"
        f.write_text("aaa\nbbb\n")

        def _read_dacl_sddl(path):
            sd = ctypes.c_void_p()
            dacl = ctypes.c_void_p()
            rc = advapi32.GetNamedSecurityInfoW(
                str(path),
                SE_FILE_OBJECT,
                DACL_SECURITY_INFORMATION,
                None,
                None,
                ctypes.byref(dacl),
                None,
                ctypes.byref(sd),
            )
            assert rc == 0, f"GetNamedSecurityInfoW failed: {rc}"
            sddl = ctypes.c_wchar_p()
            assert advapi32.ConvertSecurityDescriptorToStringSecurityDescriptorW(
                sd,
                1,
                DACL_SECURITY_INFORMATION,
                ctypes.byref(sddl),
                None,
            )
            value = sddl.value
            kernel32.LocalFree(sddl)
            kernel32.LocalFree(sd)
            return value

        acl_before = _read_dacl_sddl(f)

        edits = json.dumps([{"op": "set_line", "anchor": _anchor(1, "aaa"), "content": "AAA"}])
        hashline_edit(path="f.txt", edits=edits)

        acl_after = _read_dacl_sddl(f)

        assert acl_before == acl_after, f"ACL changed after edit: {acl_before} -> {acl_after}"

    @pytest.mark.skipif(sys.platform != "win32", reason="Windows-only ACL test")
    def test_edit_succeeds_when_dacl_unavailable_windows(self, tools, tmp_path):
        """Edit still works on volumes without ACL support (e.g. FAT32)."""
        from aden_tools import _win32_atomic

        hashline_edit = tools[0]["hashline_edit"]
        f = tmp_path / "f.txt"
        f.write_text("aaa\nbbb\n")

        with patch.object(_win32_atomic, "snapshot_dacl", return_value=None):
            edits = json.dumps([{"op": "set_line", "anchor": _anchor(1, "aaa"), "content": "AAA"}])
            hashline_edit(path="f.txt", edits=edits)

        assert f.read_text().splitlines()[0].endswith("AAA")

    def test_preserves_trailing_newline(self, tools, tmp_path):
        """Files with trailing newline keep it after edit."""
        hashline_edit = tools[0]["hashline_edit"]

@@ -0,0 +1,599 @@
|
||||
"""Tests for Google Docs tool with FastMCP.
|
||||
|
||||
Covers:
|
||||
- Credential handling (credential store, env var, service account, missing)
|
||||
- _GoogleDocsClient methods (create, get, insert, replace, image, format, list, batch, export)
|
||||
- HTTP error handling (401, 403, 404, 429, 500, timeout)
|
||||
- All MCP tool functions via register_tools
|
||||
- Input validation (image URI, JSON parsing, list types, format types)
|
||||
- Helper functions (_validate_image_uri, _get_document_end_index)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from aden_tools.tools.google_docs_tool.google_docs_tool import (
|
||||
GOOGLE_DOCS_API_BASE,
|
||||
_get_document_end_index,
|
||||
_GoogleDocsClient,
|
||||
_validate_image_uri,
|
||||
register_tools,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp():
|
||||
"""Create a FastMCP instance for testing."""
|
||||
return FastMCP("test-server")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create a _GoogleDocsClient with a test token."""
|
||||
return _GoogleDocsClient("test-token")
|
||||
|
||||
|
||||
def _register(mcp, credentials=None):
|
||||
"""Helper to register tools and return the tool lookup dict."""
|
||||
register_tools(mcp, credentials=credentials)
|
||||
return mcp._tool_manager._tools
|
||||
|
||||
|
||||
def _tool_fn(mcp, name, credentials=None):
|
||||
"""Register tools and return a single tool function by name."""
|
||||
tools = _register(mcp, credentials)
|
||||
return tools[name].fn
|
||||
|
||||
|
||||
def _mock_response(status_code=200, json_data=None, text="", content=b""):
|
||||
"""Create a mock httpx.Response."""
|
||||
resp = MagicMock(spec=httpx.Response)
|
||||
resp.status_code = status_code
|
||||
resp.text = text
|
||||
resp.content = content
|
||||
if json_data is not None:
|
||||
resp.json.return_value = json_data
|
||||
else:
|
||||
resp.json.return_value = {}
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper function tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateImageUri:
|
||||
"""Tests for _validate_image_uri."""
|
||||
|
||||
def test_valid_https_url(self):
|
||||
assert _validate_image_uri("https://example.com/image.png") is None
|
||||
|
||||
def test_valid_http_url(self):
|
||||
assert _validate_image_uri("http://example.com/image.jpg") is None
|
||||
|
||||
def test_empty_uri(self):
|
||||
result = _validate_image_uri("")
|
||||
assert result is not None
|
||||
assert "error" in result
|
||||
|
||||
def test_whitespace_uri(self):
|
||||
result = _validate_image_uri(" ")
|
||||
assert result is not None
|
||||
assert "error" in result
|
||||
|
||||
def test_missing_scheme(self):
|
||||
result = _validate_image_uri("example.com/image.png")
|
||||
assert result is not None
|
||||
assert "missing scheme" in result["error"]
|
||||
|
||||
def test_disallowed_scheme_ftp(self):
|
||||
result = _validate_image_uri("ftp://example.com/image.png")
|
||||
assert result is not None
|
||||
assert "Only" in result["error"]
|
||||
|
||||
def test_disallowed_scheme_javascript(self):
|
||||
result = _validate_image_uri("javascript:alert(1)")
|
||||
assert result is not None
|
||||
assert "error" in result
|
||||
|
||||
def test_missing_domain(self):
|
||||
result = _validate_image_uri("https://")
|
||||
assert result is not None
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestGetDocumentEndIndex:
|
||||
"""Tests for _get_document_end_index."""
|
||||
|
||||
def test_returns_end_index_minus_one(self):
|
||||
doc = {
|
||||
"body": {
|
||||
"content": [
|
||||
{"startIndex": 1, "endIndex": 50},
|
||||
]
|
||||
}
|
||||
}
|
||||
assert _get_document_end_index(doc) == 49
|
||||
|
||||
def test_empty_content_returns_one(self):
|
||||
doc = {"body": {"content": []}}
|
||||
assert _get_document_end_index(doc) == 1
|
||||
|
||||
def test_no_body_returns_one(self):
|
||||
doc = {}
|
||||
assert _get_document_end_index(doc) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _GoogleDocsClient unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGoogleDocsClientHeaders:
|
||||
def test_headers_contain_bearer_token(self, client):
|
||||
headers = client._headers
|
||||
assert headers["Authorization"] == "Bearer test-token"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
|
||||
class TestGoogleDocsClientHandleResponse:
|
||||
@pytest.mark.parametrize(
|
||||
"status_code,expected_substr",
|
||||
[
|
||||
(401, "Invalid or expired"),
|
||||
(403, "Insufficient permissions"),
|
||||
(404, "not found"),
|
||||
(429, "rate limit"),
|
||||
],
|
||||
)
|
||||
def test_known_error_codes(self, client, status_code, expected_substr):
|
||||
resp = _mock_response(status_code=status_code)
|
||||
result = client._handle_response(resp)
|
||||
assert "error" in result
|
||||
assert expected_substr in result["error"]
|
||||
|
||||
def test_generic_error_with_nested_message(self, client):
|
||||
resp = _mock_response(
|
||||
status_code=400,
|
||||
json_data={"error": {"message": "Invalid request"}},
|
||||
)
|
||||
result = client._handle_response(resp)
|
||||
assert "Invalid request" in result["error"]
|
||||
|
||||
def test_success_returns_json(self, client):
|
||||
resp = _mock_response(200, {"documentId": "doc-1"})
|
||||
assert client._handle_response(resp) == {"documentId": "doc-1"}
|
||||
|
||||
|
||||
class TestGoogleDocsClientCreateDocument:
|
||||
def test_posts_title(self, client):
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"documentId": "doc-1", "title": "My Doc"})
|
||||
result = client.create_document("My Doc")
|
||||
body = mock_post.call_args.kwargs["json"]
|
||||
assert body == {"title": "My Doc"}
|
||||
assert result["documentId"] == "doc-1"
|
||||
|
||||
|
||||
class TestGoogleDocsClientGetDocument:
|
||||
def test_gets_correct_url(self, client):
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value = _mock_response(200, {"documentId": "doc-1"})
|
||||
client.get_document("doc-1")
|
||||
args, _ = mock_get.call_args
|
||||
assert args[0] == f"{GOOGLE_DOCS_API_BASE}/documents/doc-1"
|
||||
|
||||
|
||||
class TestGoogleDocsClientBatchUpdate:
|
||||
def test_batch_update_sends_requests(self, client):
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
requests = [{"insertText": {"text": "hello", "location": {"index": 1}}}]
|
||||
client.batch_update("doc-1", requests)
|
||||
body = mock_post.call_args.kwargs["json"]
|
||||
assert body["requests"] == requests
|
||||
|
||||
|
||||
class TestGoogleDocsClientInsertText:
|
||||
def test_insert_at_index(self, client):
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
client.insert_text("doc-1", "Hello", index=5)
|
||||
body = mock_post.call_args.kwargs["json"]
|
||||
req = body["requests"][0]["insertText"]
|
||||
assert req["text"] == "Hello"
|
||||
assert req["location"]["index"] == 5
|
||||
|
||||
def test_insert_at_end_fetches_doc(self, client):
|
||||
with patch("httpx.get") as mock_get, patch("httpx.post") as mock_post:
|
||||
mock_get.return_value = _mock_response(
|
||||
200,
|
||||
{"body": {"content": [{"startIndex": 1, "endIndex": 20}]}},
|
||||
)
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
client.insert_text("doc-1", "Appended text")
|
||||
# Should have fetched doc to determine end index
|
||||
mock_get.assert_called_once()
|
||||
|
||||
|
||||
class TestGoogleDocsClientReplaceAllText:
|
||||
def test_replace_sends_correct_request(self, client):
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
client.replace_all_text("doc-1", "{{NAME}}", "Alice")
|
||||
body = mock_post.call_args.kwargs["json"]
|
||||
req = body["requests"][0]["replaceAllText"]
|
||||
assert req["containsText"]["text"] == "{{NAME}}"
|
||||
assert req["replaceText"] == "Alice"
|
||||
|
||||
def test_empty_find_text_returns_error(self, client):
|
||||
result = client.replace_all_text("doc-1", "", "Alice")
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestGoogleDocsClientInsertImage:
|
||||
def test_valid_image_insertion(self, client):
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
client.insert_image("doc-1", "https://example.com/img.png", index=1)
|
||||
body = mock_post.call_args.kwargs["json"]
|
||||
req = body["requests"][0]["insertInlineImage"]
|
||||
assert req["uri"] == "https://example.com/img.png"
|
||||
|
||||
def test_invalid_uri_returns_error(self, client):
|
||||
result = client.insert_image("doc-1", "ftp://bad.com/img.png", index=1)
|
||||
assert "error" in result
|
||||
|
||||
def test_image_with_dimensions(self, client):
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
client.insert_image(
|
||||
"doc-1",
|
||||
"https://example.com/img.png",
|
||||
index=1,
|
||||
width_pt=200.0,
|
||||
height_pt=100.0,
|
||||
)
|
||||
body = mock_post.call_args.kwargs["json"]
|
||||
req = body["requests"][0]["insertInlineImage"]
|
||||
assert req["objectSize"]["width"]["magnitude"] == 200.0
|
||||
|
||||
|
||||
class TestGoogleDocsClientFormatText:
|
||||
def test_bold_formatting(self, client):
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
client.format_text("doc-1", 1, 10, bold=True)
|
||||
body = mock_post.call_args.kwargs["json"]
|
||||
req = body["requests"][0]["updateTextStyle"]
|
||||
assert req["textStyle"]["bold"] is True
|
||||
assert "bold" in req["fields"]
|
||||
|
||||
def test_no_options_returns_error(self, client):
|
||||
result = client.format_text("doc-1", 1, 10)
|
||||
assert "error" in result
|
||||
assert "No formatting" in result["error"]
|
||||
|
||||
|
||||
class TestGoogleDocsClientExportDocument:
|
||||
def test_export_pdf(self, client):
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value = _mock_response(200, content=b"%PDF-1.4 content")
|
||||
result = client.export_document("doc-1", "application/pdf")
|
||||
assert result["mime_type"] == "application/pdf"
|
||||
assert result["size_bytes"] == len(b"%PDF-1.4 content")
|
||||
assert "content_base64" in result
|
||||
|
||||
|
||||
class TestGoogleDocsClientComments:
|
||||
def test_add_comment(self, client):
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(
|
||||
200, {"id": "comment-1", "content": "Nice work"}
|
||||
)
|
||||
result = client.add_comment("doc-1", "Nice work")
|
||||
assert result["id"] == "comment-1"
|
||||
|
||||
def test_add_comment_with_quoted_text(self, client):
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"id": "comment-1"})
|
||||
client.add_comment("doc-1", "Fix this", quoted_text="typo here")
|
||||
body = mock_post.call_args.kwargs["json"]
|
||||
assert body["quotedFileContent"]["value"] == "typo here"
|
||||
|
||||
def test_list_comments(self, client):
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value = _mock_response(
|
||||
200, {"comments": [{"id": "c1"}], "nextPageToken": "tok2"}
|
||||
)
|
||||
result = client.list_comments("doc-1", page_size=10)
|
||||
assert len(result["comments"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credential handling via register_tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGoogleDocsCredentials:
|
||||
def test_no_credentials_returns_error(self, mcp, monkeypatch):
|
||||
monkeypatch.delenv("GOOGLE_ACCESS_TOKEN", raising=False)
|
||||
fn = _tool_fn(mcp, "google_docs_get_document")
|
||||
result = fn(document_id="doc-1")
|
||||
assert "error" in result
|
||||
assert "not configured" in result["error"]
|
||||
|
||||
def test_env_var_credential(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "env-tok")
|
||||
fn = _tool_fn(mcp, "google_docs_get_document")
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value = _mock_response(200, {"documentId": "doc-1"})
|
||||
fn(document_id="doc-1")
|
||||
headers = mock_get.call_args.kwargs["headers"]
|
||||
assert headers["Authorization"] == "Bearer env-tok"
|
||||
|
||||
def test_credential_store_used(self, mcp):
|
||||
creds = MagicMock()
|
||||
creds.get.return_value = "store-tok"
|
||||
fn = _tool_fn(mcp, "google_docs_get_document", credentials=creds)
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value = _mock_response(200, {"documentId": "doc-1"})
|
||||
fn(document_id="doc-1")
|
||||
creds.get.assert_called_once_with("google")
|
||||
|
||||
def test_credential_store_non_string_raises(self, mcp):
|
||||
creds = MagicMock()
|
||||
creds.get.return_value = {"key": "value"}
|
||||
fn = _tool_fn(mcp, "google_docs_get_document", credentials=creds)
|
||||
with pytest.raises(TypeError, match="Expected string"):
|
||||
fn(document_id="doc-1")
|
||||
|
||||
def test_credential_store_account_alias(self, mcp):
|
||||
creds = MagicMock()
|
||||
creds.get_by_alias.return_value = "alias-tok"
|
||||
fn = _tool_fn(mcp, "google_docs_get_document", credentials=creds)
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value = _mock_response(200, {"documentId": "doc-1"})
|
||||
fn(document_id="doc-1", account="my-account")
|
||||
creds.get_by_alias.assert_called_once_with("google", "my-account")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP tool function tests — Document Management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGoogleDocsCreateDocument:
|
||||
def test_success_returns_url(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_create_document")
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(
|
||||
200, {"documentId": "new-doc", "title": "My Doc"}
|
||||
)
|
||||
result = fn(title="My Doc")
|
||||
assert result["document_id"] == "new-doc"
|
||||
assert "document_url" in result
|
||||
assert "new-doc" in result["document_url"]
|
||||
|
||||
def test_timeout(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_create_document")
|
||||
with patch("httpx.post", side_effect=httpx.TimeoutException("t")):
|
||||
result = fn(title="Doc")
|
||||
assert result == {"error": "Request timed out"}
|
||||
|
||||
|
||||
class TestGoogleDocsGetDocument:
|
||||
def test_success(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_get_document")
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value = _mock_response(200, {"documentId": "doc-1", "title": "Test"})
|
||||
result = fn(document_id="doc-1")
|
||||
assert result["documentId"] == "doc-1"
|
||||
|
||||
|
||||
class TestGoogleDocsInsertText:
|
||||
def test_success(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_insert_text")
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
result = fn(document_id="doc-1", text="Hello", index=1)
|
||||
assert "error" not in result
|
||||
|
||||
|
||||
class TestGoogleDocsReplaceAllText:
|
||||
def test_success_with_count(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_replace_all_text")
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(
|
||||
200,
|
||||
{"replies": [{"replaceAllText": {"occurrencesChanged": 3}}]},
|
||||
)
|
||||
result = fn(
|
||||
document_id="doc-1",
|
||||
find_text="{{NAME}}",
|
||||
replace_text="Alice",
|
||||
)
|
||||
assert result["occurrences_replaced"] == 3
|
||||
|
||||
|
||||
class TestGoogleDocsInsertImage:
|
||||
def test_success(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_insert_image")
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
result = fn(
|
||||
document_id="doc-1",
|
||||
image_uri="https://example.com/img.png",
|
||||
index=1,
|
||||
)
|
||||
assert "error" not in result
|
||||
|
||||
def test_invalid_uri(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_insert_image")
|
||||
# This gets caught by the client-level validation
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
result = fn(
|
||||
document_id="doc-1",
|
||||
image_uri="ftp://bad.com/img.png",
|
||||
index=1,
|
||||
)
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestGoogleDocsFormatText:
|
||||
def test_success(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_format_text")
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
result = fn(
|
||||
document_id="doc-1",
|
||||
start_index=1,
|
||||
end_index=10,
|
||||
bold=True,
|
||||
)
|
||||
assert "error" not in result
|
||||
|
||||
|
||||
class TestGoogleDocsBatchUpdate:
|
||||
def test_success(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_batch_update")
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
requests = [{"insertText": {"text": "Hi", "location": {"index": 1}}}]
|
||||
result = fn(
|
||||
document_id="doc-1",
|
||||
requests_json=json.dumps(requests),
|
||||
)
|
||||
assert "error" not in result
|
||||
|
||||
def test_invalid_json(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_batch_update")
|
||||
result = fn(document_id="doc-1", requests_json="not json")
|
||||
assert "error" in result
|
||||
assert "Invalid JSON" in result["error"]
|
||||
|
||||
def test_non_array_json(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_batch_update")
|
||||
result = fn(document_id="doc-1", requests_json='{"key": "value"}')
|
||||
assert "error" in result
|
||||
assert "JSON array" in result["error"]
|
||||
|
||||
|
||||
class TestGoogleDocsCreateList:
|
||||
def test_bullet_list(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_create_list")
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
result = fn(
|
||||
document_id="doc-1",
|
||||
start_index=1,
|
||||
end_index=20,
|
||||
list_type="bullet",
|
||||
)
|
||||
assert "error" not in result
|
||||
|
||||
def test_numbered_list(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_create_list")
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"replies": []})
|
||||
result = fn(
|
||||
document_id="doc-1",
|
||||
start_index=1,
|
||||
end_index=20,
|
||||
list_type="numbered",
|
||||
)
|
||||
assert "error" not in result
|
||||
|
||||
|
||||
class TestGoogleDocsAddComment:
|
||||
def test_success(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_add_comment")
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"id": "comment-1", "content": "Fix this"})
|
||||
result = fn(document_id="doc-1", content="Fix this")
|
||||
assert result["id"] == "comment-1"
|
||||
|
||||
|
||||
class TestGoogleDocsListComments:
|
||||
def test_success_returns_structured(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_list_comments")
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value = _mock_response(
|
||||
200,
|
||||
{"comments": [{"id": "c1"}], "nextPageToken": "tok2"},
|
||||
)
|
||||
result = fn(document_id="doc-1")
|
||||
assert result["document_id"] == "doc-1"
|
||||
assert len(result["comments"]) == 1
|
||||
assert result["next_page_token"] == "tok2"
|
||||
|
||||
|
||||
class TestGoogleDocsExportContent:
|
||||
def test_export_pdf(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "google_docs_export_content")
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value = _mock_response(200, content=b"PDF data here")
|
||||
result = fn(document_id="doc-1", format="pdf")
|
||||
assert result["mime_type"] == "application/pdf"
|
||||
assert "content_base64" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolRegistration:
|
||||
"""Verify all Google Docs tools are registered."""
|
||||
|
||||
EXPECTED_TOOLS = [
|
||||
"google_docs_create_document",
|
||||
"google_docs_get_document",
|
||||
"google_docs_insert_text",
|
||||
"google_docs_replace_all_text",
|
||||
"google_docs_insert_image",
|
||||
"google_docs_format_text",
|
||||
"google_docs_batch_update",
|
||||
"google_docs_create_list",
|
||||
"google_docs_add_comment",
|
||||
"google_docs_list_comments",
|
||||
"google_docs_export_content",
|
||||
]
|
||||
|
||||
def test_all_tools_registered(self, mcp):
|
||||
tools = _register(mcp)
|
||||
for name in self.EXPECTED_TOOLS:
|
||||
assert name in tools, f"Tool {name} not registered"
|
||||
|
||||
def test_tool_count(self, mcp):
|
||||
tools = _register(mcp)
|
||||
gdocs_tools = [k for k in tools if k.startswith("google_docs_")]
|
||||
assert len(gdocs_tools) == len(self.EXPECTED_TOOLS)
|
||||
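The registration tests above lean on `_register`, `_tool_fn`, and `_mock_response` helpers defined earlier in this file, outside this excerpt. For orientation, a minimal sketch of that pattern, assuming the same shape the HubSpot and Intercom suites below define explicitly (`_tool_manager` is private FastMCP API, used purely as a test convenience):

```python
# Sketch only: mirrors the helper pattern used across these test suites.
def _register(mcp):
    """Register the tools and return FastMCP's internal tool lookup."""
    register_tools(mcp)
    return mcp._tool_manager._tools


def _tool_fn(mcp, name):
    """Return the plain function behind a registered tool."""
    return _register(mcp)[name].fn
```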
@@ -1304,3 +1304,96 @@ class TestPermissionsPreservation:

        assert result["success"] is True
        assert f.stat().st_mode & 0o777 == mode

    @pytest.mark.skipif(sys.platform != "win32", reason="Windows-only ACL test")
    def test_acl_preserved_after_edit_windows(
        self, hashline_edit_fn, mock_workspace, mock_secure_path, tmp_path
    ):
        """Atomic replace preserves the target file's DACL on Windows."""
        import ctypes
        import ctypes.wintypes  # wintypes is a submodule; `import ctypes` alone does not expose it

        advapi32 = ctypes.windll.advapi32
        kernel32 = ctypes.windll.kernel32
        SE_FILE_OBJECT = 1
        DACL_SECURITY_INFORMATION = 0x00000004

        advapi32.GetNamedSecurityInfoW.argtypes = [
            ctypes.wintypes.LPCWSTR,  # pObjectName
            ctypes.c_uint,  # ObjectType (SE_OBJECT_TYPE enum)
            ctypes.wintypes.DWORD,  # SecurityInfo
            ctypes.c_void_p,  # ppsidOwner
            ctypes.c_void_p,  # ppsidGroup
            ctypes.c_void_p,  # ppDacl
            ctypes.c_void_p,  # ppSacl
            ctypes.c_void_p,  # ppSecurityDescriptor
        ]
        advapi32.GetNamedSecurityInfoW.restype = ctypes.wintypes.DWORD

        advapi32.ConvertSecurityDescriptorToStringSecurityDescriptorW.argtypes = [
            ctypes.c_void_p,  # SecurityDescriptor
            ctypes.wintypes.DWORD,  # RequestedStringSDRevision
            ctypes.wintypes.DWORD,  # SecurityInformation
            ctypes.c_void_p,  # StringSecurityDescriptor (out)
            ctypes.c_void_p,  # StringSecurityDescriptorLen (out, optional)
        ]
        advapi32.ConvertSecurityDescriptorToStringSecurityDescriptorW.restype = ctypes.wintypes.BOOL

        kernel32.LocalFree.argtypes = [ctypes.c_void_p]
        kernel32.LocalFree.restype = ctypes.c_void_p

        f = tmp_path / "test.txt"
        f.write_text("aaa\nbbb\n")

        def _read_dacl_sddl(path):
            sd = ctypes.c_void_p()
            dacl = ctypes.c_void_p()
            rc = advapi32.GetNamedSecurityInfoW(
                str(path),
                SE_FILE_OBJECT,
                DACL_SECURITY_INFORMATION,
                None,
                None,
                ctypes.byref(dacl),
                None,
                ctypes.byref(sd),
            )
            assert rc == 0, f"GetNamedSecurityInfoW failed: {rc}"
            sddl = ctypes.c_wchar_p()
            assert advapi32.ConvertSecurityDescriptorToStringSecurityDescriptorW(
                sd,
                1,  # SDDL_REVISION_1
                DACL_SECURITY_INFORMATION,
                ctypes.byref(sddl),
                None,
            )
            value = sddl.value
            kernel32.LocalFree(sddl)
            kernel32.LocalFree(sd)
            return value

        acl_before = _read_dacl_sddl(f)

        edits = json.dumps([{"op": "set_line", "anchor": _anchor(1, "aaa"), "content": "AAA"}])
        result = hashline_edit_fn(path="test.txt", edits=edits, **mock_workspace)
        assert result["success"] is True

        acl_after = _read_dacl_sddl(f)

        assert acl_before == acl_after, f"ACL changed after edit: {acl_before} -> {acl_after}"

    @pytest.mark.skipif(sys.platform != "win32", reason="Windows-only ACL test")
    def test_edit_succeeds_when_dacl_unavailable_windows(
        self, hashline_edit_fn, mock_workspace, mock_secure_path, tmp_path
    ):
        """Edit still works on volumes without ACL support (e.g. FAT32)."""
        from aden_tools import _win32_atomic

        f = tmp_path / "test.txt"
        f.write_text("aaa\nbbb\n")

        with patch.object(_win32_atomic, "snapshot_dacl", return_value=None):
            edits = json.dumps([{"op": "set_line", "anchor": _anchor(1, "aaa"), "content": "AAA"}])
            result = hashline_edit_fn(path="test.txt", edits=edits, **mock_workspace)

        assert result["success"] is True
        assert f.read_text().splitlines()[0].endswith("AAA")
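The ctypes plumbing above is the precise way to read a DACL as an SDDL string. For quick manual verification of the same invariant, `icacls` (which ships with Windows) can serve as a rough cross-check; a sketch, with a placeholder path, output formatting differing from SDDL, so it complements rather than replaces the comparison above:

```python
import subprocess


def icacls_snapshot(path: str) -> str:
    """Capture `icacls` output for a path (Windows only)."""
    return subprocess.run(
        ["icacls", path], capture_output=True, text=True, check=True
    ).stdout


before = icacls_snapshot(r"C:\work\test.txt")  # placeholder path
# ... run the edit under test ...
after = icacls_snapshot(r"C:\work\test.txt")
assert before == after, "DACL changed"
```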
@@ -0,0 +1,315 @@
"""Tests for HTTP Headers Scanner tool."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from fastmcp import FastMCP

from aden_tools.tools.http_headers_scanner import register_tools


@pytest.fixture
def headers_tools(mcp: FastMCP):
    """Register HTTP headers tools and return tool functions."""
    register_tools(mcp)
    tools = mcp._tool_manager._tools
    return {name: tools[name].fn for name in tools}


@pytest.fixture
def scan_fn(headers_tools):
    return headers_tools["http_headers_scan"]


def _mock_response(
    status_code: int = 200,
    headers: dict | None = None,
    url: str = "https://example.com",
) -> MagicMock:
    """Create a mock httpx.Response."""
    resp = MagicMock()
    resp.status_code = status_code
    resp.url = url
    resp.headers = httpx.Headers(headers or {})
    return resp
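

def test_mock_response_headers_are_case_insensitive():
    # Illustrative check (not part of the original suite): httpx.Headers
    # normalizes header names, which is why the fixtures below can use
    # lowercase keys while the scanner looks headers up by canonical casing.
    resp = _mock_response(headers={"strict-transport-security": "max-age=1"})
    assert resp.headers["Strict-Transport-Security"] == "max-age=1"
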
# ---------------------------------------------------------------------------
# Input Validation
# ---------------------------------------------------------------------------


class TestInputValidation:
    """Test URL input cleaning and validation."""

    @pytest.mark.asyncio
    async def test_auto_prefix_https(self, scan_fn):
        mock_resp = _mock_response(headers={"strict-transport-security": "max-age=31536000"})
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = mock_resp
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("example.com")
            assert "error" not in result
            # Verify https was prefixed
            mock_client.get.assert_called_once()
            call_url = mock_client.get.call_args[0][0]
            assert call_url.startswith("https://")


# ---------------------------------------------------------------------------
# Connection Errors
# ---------------------------------------------------------------------------


class TestConnectionErrors:
    """Test error handling for connection failures."""

    @pytest.mark.asyncio
    async def test_connection_error(self, scan_fn):
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.side_effect = httpx.ConnectError("Connection refused")
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            assert "error" in result
            assert "Connection failed" in result["error"]

    @pytest.mark.asyncio
    async def test_timeout_error(self, scan_fn):
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.side_effect = httpx.TimeoutException("Request timed out")
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            assert "error" in result
            assert "timed out" in result["error"]


# ---------------------------------------------------------------------------
# Security Headers Detection
# ---------------------------------------------------------------------------


class TestSecurityHeaders:
    """Test detection of OWASP security headers."""

    @pytest.mark.asyncio
    async def test_all_headers_present(self, scan_fn):
        headers = {
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Content-Security-Policy": "default-src 'self'",
            "X-Frame-Options": "DENY",
            "X-Content-Type-Options": "nosniff",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Permissions-Policy": "camera=(), microphone=()",
        }
        mock_resp = _mock_response(headers=headers)
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = mock_resp
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            assert len(result["headers_present"]) == 6
            assert len(result["headers_missing"]) == 0
            assert result["grade_input"]["hsts"] is True
            assert result["grade_input"]["csp"] is True

    @pytest.mark.asyncio
    async def test_missing_hsts(self, scan_fn):
        headers = {
            "Content-Security-Policy": "default-src 'self'",
            "X-Frame-Options": "DENY",
        }
        mock_resp = _mock_response(headers=headers)
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = mock_resp
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            assert result["grade_input"]["hsts"] is False
            missing_names = [h["header"] for h in result["headers_missing"]]
            assert "Strict-Transport-Security" in missing_names

    @pytest.mark.asyncio
    async def test_missing_csp(self, scan_fn):
        headers = {
            "Strict-Transport-Security": "max-age=31536000",
        }
        mock_resp = _mock_response(headers=headers)
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = mock_resp
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            assert result["grade_input"]["csp"] is False
            missing_names = [h["header"] for h in result["headers_missing"]]
            assert "Content-Security-Policy" in missing_names


# ---------------------------------------------------------------------------
# Leaky Headers Detection
# ---------------------------------------------------------------------------


class TestLeakyHeaders:
    """Test detection of information-leaking headers."""

    @pytest.mark.asyncio
    async def test_server_header_leaked(self, scan_fn):
        headers = {"Server": "Apache/2.4.41 (Ubuntu)"}
        mock_resp = _mock_response(headers=headers)
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = mock_resp
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            assert len(result["leaky_headers"]) > 0
            leaky_names = [h["header"] for h in result["leaky_headers"]]
            assert "Server" in leaky_names
            assert result["grade_input"]["no_leaky_headers"] is False

    @pytest.mark.asyncio
    async def test_x_powered_by_leaked(self, scan_fn):
        headers = {"X-Powered-By": "PHP/8.1.0"}
        mock_resp = _mock_response(headers=headers)
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = mock_resp
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            leaky_names = [h["header"] for h in result["leaky_headers"]]
            assert "X-Powered-By" in leaky_names

    @pytest.mark.asyncio
    async def test_no_leaky_headers(self, scan_fn):
        headers = {
            "Strict-Transport-Security": "max-age=31536000",
            "Content-Type": "text/html",
        }
        mock_resp = _mock_response(headers=headers)
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = mock_resp
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            assert len(result["leaky_headers"]) == 0
            assert result["grade_input"]["no_leaky_headers"] is True


# ---------------------------------------------------------------------------
# Deprecated Headers
# ---------------------------------------------------------------------------


class TestDeprecatedHeaders:
    """Test detection of deprecated headers."""

    @pytest.mark.asyncio
    async def test_xss_protection_deprecated(self, scan_fn):
        headers = {"X-XSS-Protection": "1; mode=block"}
        mock_resp = _mock_response(headers=headers)
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = mock_resp
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            assert "X-XSS-Protection (deprecated)" in result["headers_present"]


# ---------------------------------------------------------------------------
# Grade Input
# ---------------------------------------------------------------------------


class TestGradeInput:
    """Test grade_input dict is properly constructed."""

    @pytest.mark.asyncio
    async def test_grade_input_keys_present(self, scan_fn):
        mock_resp = _mock_response(headers={})
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = mock_resp
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            assert "grade_input" in result
            grade = result["grade_input"]
            assert "hsts" in grade
            assert "csp" in grade
            assert "x_frame_options" in grade
            assert "x_content_type_options" in grade
            assert "referrer_policy" in grade
            assert "permissions_policy" in grade
            assert "no_leaky_headers" in grade


# ---------------------------------------------------------------------------
# Response Metadata
# ---------------------------------------------------------------------------


class TestResponseMetadata:
    """Test response metadata is captured."""

    @pytest.mark.asyncio
    async def test_status_code_captured(self, scan_fn):
        mock_resp = _mock_response(status_code=200, headers={})
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = mock_resp
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            assert result["status_code"] == 200

    @pytest.mark.asyncio
    async def test_final_url_captured(self, scan_fn):
        mock_resp = _mock_response(status_code=200, headers={}, url="https://www.example.com/")
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = mock_resp
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await scan_fn("https://example.com")
            assert result["url"] == "https://www.example.com/"
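Every test in this file repeats the same six lines of `AsyncClient` mocking. If the suite grows, that boilerplate could be factored into a helper along these lines (a sketch, not part of the commit; `mock_async_client` is an illustrative name):

```python
from contextlib import contextmanager
from unittest.mock import AsyncMock, patch


@contextmanager
def mock_async_client(response=None, side_effect=None):
    """Patch httpx.AsyncClient and yield the mocked client instance."""
    with patch("httpx.AsyncClient") as MockClient:
        mock_client = AsyncMock()
        if side_effect is not None:
            mock_client.get.side_effect = side_effect
        else:
            mock_client.get.return_value = response
        # Make the mock usable as an async context manager.
        mock_client.__aenter__.return_value = mock_client
        mock_client.__aexit__.return_value = None
        MockClient.return_value = mock_client
        yield mock_client
```

A test body then shrinks to `with mock_async_client(_mock_response(headers={})): result = await scan_fn("https://example.com")`.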
@@ -0,0 +1,596 @@
"""Tests for HubSpot CRM tool with FastMCP.

Covers:
- Credential handling (credential store, env var, missing)
- _HubSpotClient methods (search, get, create, update, delete, associations)
- HTTP error handling (401, 403, 404, 429, 500, timeout)
- All 15 MCP tool functions via register_tools
- Input validation (delete_object object_type whitelist)
"""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import httpx
import pytest
from fastmcp import FastMCP

from aden_tools.tools.hubspot_tool.hubspot_tool import (
    HUBSPOT_API_BASE,
    _HubSpotClient,
    register_tools,
)

# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def mcp():
    """Create a FastMCP instance for testing."""
    return FastMCP("test-server")


@pytest.fixture
def client():
    """Create a _HubSpotClient with a test token."""
    return _HubSpotClient("test-token")


def _register(mcp, credentials=None):
    """Helper to register tools and return the tool lookup dict."""
    register_tools(mcp, credentials=credentials)
    return mcp._tool_manager._tools


def _tool_fn(mcp, name, credentials=None):
    """Register tools and return a single tool function by name."""
    tools = _register(mcp, credentials)
    return tools[name].fn


def _mock_response(status_code=200, json_data=None, text=""):
    """Create a mock httpx.Response."""
    resp = MagicMock(spec=httpx.Response)
    resp.status_code = status_code
    resp.text = text
    if json_data is not None:
        resp.json.return_value = json_data
    else:
        resp.json.return_value = {}
    return resp
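

# For orientation while reading the assertions below: the status-code handling
# they exercise implies a mapping roughly like this (inferred from the tests,
# not copied from hubspot_tool itself):
#
#   401 -> "Invalid or expired ..."      403 -> "Insufficient permissions ..."
#   404 -> "... not found"               429 -> "... rate limit ..."
#   other non-2xx -> f"HTTP {status}: {json message or raw text}"
#   2xx -> parsed JSON returned unchanged
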
# ---------------------------------------------------------------------------
# _HubSpotClient unit tests
# ---------------------------------------------------------------------------


class TestHubSpotClientHeaders:
    """Verify client sends correct auth headers."""

    def test_headers_contain_bearer_token(self, client):
        headers = client._headers
        assert headers["Authorization"] == "Bearer test-token"
        assert headers["Content-Type"] == "application/json"
        assert headers["Accept"] == "application/json"


class TestHubSpotClientHandleResponse:
    """Verify _handle_response maps HTTP codes to error dicts."""

    @pytest.mark.parametrize(
        "status_code,expected_substr",
        [
            (401, "Invalid or expired"),
            (403, "Insufficient permissions"),
            (404, "not found"),
            (429, "rate limit"),
        ],
    )
    def test_known_error_codes(self, client, status_code, expected_substr):
        resp = _mock_response(status_code=status_code)
        result = client._handle_response(resp)
        assert "error" in result
        assert expected_substr in result["error"]

    def test_generic_4xx_with_json_message(self, client):
        resp = _mock_response(
            status_code=422,
            json_data={"message": "Property not found"},
        )
        result = client._handle_response(resp)
        assert "error" in result
        assert "422" in result["error"]
        assert "Property not found" in result["error"]

    def test_generic_5xx_fallback_to_text(self, client):
        resp = _mock_response(status_code=500, text="Internal Server Error")
        resp.json.side_effect = Exception("not json")
        result = client._handle_response(resp)
        assert "error" in result
        assert "500" in result["error"]

    def test_success_returns_json(self, client):
        resp = _mock_response(status_code=200, json_data={"id": "123"})
        result = client._handle_response(resp)
        assert result == {"id": "123"}


class TestHubSpotClientSearchObjects:
    """Tests for _HubSpotClient.search_objects."""

    def test_search_posts_correct_url(self, client):
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"results": [], "total": 0})
            client.search_objects("contacts", query="test@example.com")
            mock_post.assert_called_once()
            args, kwargs = mock_post.call_args
            assert args[0] == f"{HUBSPOT_API_BASE}/crm/v3/objects/contacts/search"

    def test_search_sends_query_and_properties(self, client):
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"results": []})
            client.search_objects(
                "contacts",
                query="jane",
                properties=["email", "firstname"],
                limit=5,
            )
            body = mock_post.call_args.kwargs["json"]
            assert body["query"] == "jane"
            assert body["properties"] == ["email", "firstname"]
            assert body["limit"] == 5

    def test_search_clamps_limit_to_100(self, client):
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"results": []})
            client.search_objects("contacts", limit=999)
            body = mock_post.call_args.kwargs["json"]
            assert body["limit"] == 100


class TestHubSpotClientGetObject:
    """Tests for _HubSpotClient.get_object."""

    def test_get_object_url(self, client):
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"id": "42"})
            client.get_object("contacts", "42")
            args, _ = mock_get.call_args
            assert args[0] == f"{HUBSPOT_API_BASE}/crm/v3/objects/contacts/42"

    def test_get_object_passes_properties(self, client):
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"id": "42"})
            client.get_object("contacts", "42", properties=["email", "phone"])
            params = mock_get.call_args.kwargs["params"]
            assert params["properties"] == "email,phone"


class TestHubSpotClientCreateObject:
    """Tests for _HubSpotClient.create_object."""

    def test_create_object_posts_properties(self, client):
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(
                200, {"id": "99", "properties": {"email": "a@b.com"}}
            )
            result = client.create_object("contacts", {"email": "a@b.com", "firstname": "Alice"})
            body = mock_post.call_args.kwargs["json"]
            assert body == {"properties": {"email": "a@b.com", "firstname": "Alice"}}
            assert result["id"] == "99"


class TestHubSpotClientUpdateObject:
    """Tests for _HubSpotClient.update_object."""

    def test_update_object_uses_patch(self, client):
        with patch("httpx.patch") as mock_patch:
            mock_patch.return_value = _mock_response(200, {"id": "42"})
            client.update_object("contacts", "42", {"phone": "+1234567890"})
            mock_patch.assert_called_once()
            args, kwargs = mock_patch.call_args
            assert "/contacts/42" in args[0]
            assert kwargs["json"] == {"properties": {"phone": "+1234567890"}}


class TestHubSpotClientDeleteObject:
    """Tests for _HubSpotClient.delete_object."""

    def test_delete_returns_status_on_204(self, client):
        with patch("httpx.delete") as mock_delete:
            mock_delete.return_value = _mock_response(status_code=204)
            result = client.delete_object("contacts", "42")
            assert result["status"] == "deleted"
            assert result["object_id"] == "42"

    def test_delete_non_204_delegates_to_handle_response(self, client):
        with patch("httpx.delete") as mock_delete:
            mock_delete.return_value = _mock_response(
                status_code=404, json_data={"message": "Not found"}
            )
            result = client.delete_object("contacts", "42")
            assert "error" in result


class TestHubSpotClientAssociations:
    """Tests for association-related client methods."""

    def test_list_associations_url(self, client):
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"results": []})
            client.list_associations("contacts", "1", "companies")
            args, _ = mock_get.call_args
            assert "/crm/v4/objects/contacts/1/associations/companies" in args[0]

    def test_list_associations_clamps_limit(self, client):
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"results": []})
            client.list_associations("contacts", "1", "companies", limit=999)
            params = mock_get.call_args.kwargs["params"]
            assert params["limit"] == 500

    def test_create_association_uses_put(self, client):
        with patch("httpx.put") as mock_put:
            mock_put.return_value = _mock_response(200, {"status": "ok"})
            client.create_association("contacts", "1", "companies", "2")
            mock_put.assert_called_once()
            body = mock_put.call_args.kwargs["json"]
            assert body[0]["associationCategory"] == "HUBSPOT_DEFINED"


# ---------------------------------------------------------------------------
# Credential handling via register_tools
# ---------------------------------------------------------------------------


class TestHubSpotCredentials:
    """Tests for credential resolution in MCP tool functions."""

    def test_no_credentials_returns_error(self, mcp, monkeypatch):
        monkeypatch.delenv("HUBSPOT_ACCESS_TOKEN", raising=False)
        fn = _tool_fn(mcp, "hubspot_search_contacts")
        result = fn()
        assert "error" in result
        assert "not configured" in result["error"]
        assert "help" in result

    def test_env_var_credential(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "env-token")
        fn = _tool_fn(mcp, "hubspot_search_contacts")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"results": []})
            fn(query="test")
            headers = mock_post.call_args.kwargs["headers"]
            assert headers["Authorization"] == "Bearer env-token"

    def test_credential_store_used_when_provided(self, mcp):
        creds = MagicMock()
        creds.get.return_value = "store-token"
        fn = _tool_fn(mcp, "hubspot_search_contacts", credentials=creds)
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"results": []})
            fn(query="test")
            creds.get.assert_called_once_with("hubspot")
            headers = mock_post.call_args.kwargs["headers"]
            assert headers["Authorization"] == "Bearer store-token"

    def test_credential_store_non_string_raises(self, mcp):
        creds = MagicMock()
        creds.get.return_value = {"access_token": "bad"}
        fn = _tool_fn(mcp, "hubspot_search_contacts", credentials=creds)
        with pytest.raises(TypeError, match="Expected string"):
            fn(query="test")

    def test_credential_store_account_alias(self, mcp):
        creds = MagicMock()
        creds.get_by_alias.return_value = "alias-token"
        fn = _tool_fn(mcp, "hubspot_search_contacts", credentials=creds)
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"results": []})
            fn(query="test", account="my-account")
            creds.get_by_alias.assert_called_once_with("hubspot", "my-account")


# ---------------------------------------------------------------------------
# MCP tool function tests — Contacts
# ---------------------------------------------------------------------------


class TestHubSpotSearchContacts:
    """Tests for hubspot_search_contacts tool."""

    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_search_contacts")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"results": [{"id": "1"}], "total": 1})
            result = fn(query="jane")
            assert result["total"] == 1

    def test_timeout(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_search_contacts")
        with patch("httpx.post", side_effect=httpx.TimeoutException("timeout")):
            result = fn(query="jane")
            assert result == {"error": "Request timed out"}

    def test_network_error(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_search_contacts")
        with patch("httpx.post", side_effect=httpx.RequestError("dns fail")):
            result = fn(query="jane")
            assert "Network error" in result["error"]


class TestHubSpotGetContact:
    """Tests for hubspot_get_contact tool."""

    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_get_contact")
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(
                200, {"id": "42", "properties": {"email": "a@b.com"}}
            )
            result = fn(contact_id="42")
            assert result["id"] == "42"

    def test_404(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_get_contact")
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(status_code=404)
            result = fn(contact_id="999")
            assert "error" in result
            assert "not found" in result["error"]


class TestHubSpotCreateContact:
    """Tests for hubspot_create_contact tool."""

    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_create_contact")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(
                200, {"id": "99", "properties": {"email": "new@example.com"}}
            )
            result = fn(properties={"email": "new@example.com"})
            assert result["id"] == "99"


class TestHubSpotUpdateContact:
    """Tests for hubspot_update_contact tool."""

    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_update_contact")
        with patch("httpx.patch") as mock_patch:
            mock_patch.return_value = _mock_response(200, {"id": "42"})
            result = fn(contact_id="42", properties={"phone": "+1234567890"})
            assert result["id"] == "42"


# ---------------------------------------------------------------------------
# MCP tool function tests — Companies
# ---------------------------------------------------------------------------


class TestHubSpotSearchCompanies:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_search_companies")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"results": [{"id": "c1"}], "total": 1})
            result = fn(query="Acme")
            assert result["total"] == 1


class TestHubSpotGetCompany:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_get_company")
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(
                200, {"id": "c1", "properties": {"name": "Acme"}}
            )
            result = fn(company_id="c1")
            assert result["id"] == "c1"


class TestHubSpotCreateCompany:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_create_company")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(
                200, {"id": "c2", "properties": {"name": "NewCo"}}
            )
            result = fn(properties={"name": "NewCo"})
            assert result["id"] == "c2"


class TestHubSpotUpdateCompany:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_update_company")
        with patch("httpx.patch") as mock_patch:
            mock_patch.return_value = _mock_response(200, {"id": "c1"})
            result = fn(company_id="c1", properties={"industry": "Finance"})
            assert result["id"] == "c1"


# ---------------------------------------------------------------------------
# MCP tool function tests — Deals
# ---------------------------------------------------------------------------


class TestHubSpotSearchDeals:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_search_deals")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"results": [{"id": "d1"}], "total": 1})
            result = fn(query="big deal")
            assert result["total"] == 1


class TestHubSpotGetDeal:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_get_deal")
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(
                200, {"id": "d1", "properties": {"dealname": "Big Deal"}}
            )
            result = fn(deal_id="d1")
            assert result["id"] == "d1"


class TestHubSpotCreateDeal:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_create_deal")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(
                200, {"id": "d2", "properties": {"dealname": "New Deal"}}
            )
            result = fn(properties={"dealname": "New Deal", "amount": "10000"})
            assert result["id"] == "d2"


class TestHubSpotUpdateDeal:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_update_deal")
        with patch("httpx.patch") as mock_patch:
            mock_patch.return_value = _mock_response(200, {"id": "d1"})
            result = fn(deal_id="d1", properties={"amount": "15000"})
            assert result["id"] == "d1"


# ---------------------------------------------------------------------------
# MCP tool function tests — Delete
# ---------------------------------------------------------------------------


class TestHubSpotDeleteObject:
    """Tests for hubspot_delete_object tool."""

    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_delete_object")
        with patch("httpx.delete") as mock_delete:
            mock_delete.return_value = _mock_response(status_code=204)
            result = fn(object_type="contacts", object_id="42")
            assert result["status"] == "deleted"

    def test_invalid_object_type(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_delete_object")
        result = fn(object_type="tickets", object_id="1")
        assert "error" in result
        assert "Unsupported object_type" in result["error"]

    @pytest.mark.parametrize("valid_type", ["contacts", "companies", "deals"])
    def test_all_valid_object_types(self, mcp, monkeypatch, valid_type):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_delete_object")
        with patch("httpx.delete") as mock_delete:
            mock_delete.return_value = _mock_response(status_code=204)
            result = fn(object_type=valid_type, object_id="1")
            assert result["status"] == "deleted"

    def test_timeout(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_delete_object")
        with patch("httpx.delete", side_effect=httpx.TimeoutException("t")):
            result = fn(object_type="contacts", object_id="1")
            assert result == {"error": "Request timed out"}


# ---------------------------------------------------------------------------
# MCP tool function tests — Associations
# ---------------------------------------------------------------------------


class TestHubSpotListAssociations:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_list_associations")
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"results": [{"toObjectId": "c1"}]})
            result = fn(
                from_object_type="contacts",
                from_object_id="1",
                to_object_type="companies",
            )
            assert "results" in result

    def test_timeout(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_list_associations")
        with patch("httpx.get", side_effect=httpx.TimeoutException("t")):
            result = fn(
                from_object_type="contacts",
                from_object_id="1",
                to_object_type="companies",
            )
            assert result == {"error": "Request timed out"}


class TestHubSpotCreateAssociation:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("HUBSPOT_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "hubspot_create_association")
        with patch("httpx.put") as mock_put:
            mock_put.return_value = _mock_response(200, {"status": "ok"})
            result = fn(
                from_object_type="contacts",
                from_object_id="1",
                to_object_type="companies",
                to_object_id="2",
            )
            assert result == {"status": "ok"}


# ---------------------------------------------------------------------------
# Tool registration
# ---------------------------------------------------------------------------


class TestToolRegistration:
    """Verify all 15 HubSpot tools are registered."""

    EXPECTED_TOOLS = [
        "hubspot_search_contacts",
        "hubspot_get_contact",
        "hubspot_create_contact",
        "hubspot_update_contact",
        "hubspot_search_companies",
        "hubspot_get_company",
        "hubspot_create_company",
        "hubspot_update_company",
        "hubspot_search_deals",
        "hubspot_get_deal",
        "hubspot_create_deal",
        "hubspot_update_deal",
        "hubspot_delete_object",
        "hubspot_list_associations",
        "hubspot_create_association",
    ]

    def test_all_tools_registered(self, mcp):
        tools = _register(mcp)
        for name in self.EXPECTED_TOOLS:
            assert name in tools, f"Tool {name} not registered"

    def test_tool_count(self, mcp):
        tools = _register(mcp)
        # Filter to only hubspot tools
        hubspot_tools = [k for k in tools if k.startswith("hubspot_")]
        assert len(hubspot_tools) == len(self.EXPECTED_TOOLS)
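For reference, the request shape the search tests pin down: `search_objects` POSTs a JSON body to the CRM v3 search endpoint, clamping `limit` to HubSpot's maximum page size of 100. A sketch of the equivalent raw call (the token is a placeholder; `api.hubapi.com` is assumed as the public API base):

```python
import httpx

HUBSPOT_API_BASE = "https://api.hubapi.com"  # assumed base URL

body = {
    "query": "jane",
    "properties": ["email", "firstname"],
    "limit": min(999, 100),  # the client clamps oversized limits to 100
}
resp = httpx.post(
    f"{HUBSPOT_API_BASE}/crm/v3/objects/contacts/search",
    headers={
        "Authorization": "Bearer <token>",
        "Content-Type": "application/json",
        "Accept": "application/json",
    },
    json=body,
)
```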
@@ -0,0 +1,543 @@
"""Tests for Intercom tool with FastMCP.

Covers:
- Credential handling (credential store, env var, missing)
- _IntercomClient methods (search, get, reply, assign, tag, close, create)
- HTTP error handling (401, 403, 404, 429, 500, timeout)
- All MCP tool functions via register_tools
- Input validation (status, assignee_type, limit, role, tag exclusivity)
- Admin ID lazy-fetch via /me
"""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import httpx
import pytest
from fastmcp import FastMCP

from aden_tools.tools.intercom_tool.intercom_tool import (
    INTERCOM_API_BASE,
    _IntercomClient,
    register_tools,
)

# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def mcp():
    """Create a FastMCP instance for testing."""
    return FastMCP("test-server")


@pytest.fixture
def client():
    """Create an _IntercomClient with a test token."""
    return _IntercomClient("test-token")


def _register(mcp, credentials=None):
    """Helper to register tools and return the tool lookup dict."""
    register_tools(mcp, credentials=credentials)
    return mcp._tool_manager._tools


def _tool_fn(mcp, name, credentials=None):
    """Register tools and return a single tool function by name."""
    tools = _register(mcp, credentials)
    return tools[name].fn


def _mock_response(status_code=200, json_data=None, text=""):
    """Create a mock httpx.Response."""
    resp = MagicMock(spec=httpx.Response)
    resp.status_code = status_code
    resp.text = text
    if json_data is not None:
        resp.json.return_value = json_data
    else:
        resp.json.return_value = {}
    return resp


# ---------------------------------------------------------------------------
# _IntercomClient unit tests
# ---------------------------------------------------------------------------


class TestIntercomClientHeaders:
    """Verify client sends correct auth and version headers."""

    def test_headers_contain_bearer_token(self, client):
        headers = client._headers
        assert headers["Authorization"] == "Bearer test-token"
        assert headers["Intercom-Version"] == "2.11"
        assert headers["Content-Type"] == "application/json"


class TestIntercomClientHandleResponse:
    """Verify _handle_response maps HTTP codes to error dicts."""

    @pytest.mark.parametrize(
        "status_code,expected_substr",
        [
            (401, "Invalid or expired"),
            (403, "Insufficient permissions"),
            (404, "not found"),
            (429, "rate limit"),
        ],
    )
    def test_known_error_codes(self, client, status_code, expected_substr):
        resp = _mock_response(status_code=status_code)
        result = client._handle_response(resp)
        assert "error" in result
        assert expected_substr in result["error"]

    def test_intercom_error_list_format(self, client):
        resp = _mock_response(
            status_code=422,
            json_data={
                "type": "error.list",
                "errors": [{"message": "Field is required"}],
            },
        )
        result = client._handle_response(resp)
        assert "Field is required" in result["error"]

    def test_generic_error_fallback_to_text(self, client):
        resp = _mock_response(status_code=500, text="Server Error")
        resp.json.side_effect = Exception("not json")
        result = client._handle_response(resp)
        assert "500" in result["error"]

    def test_success_returns_json(self, client):
        resp = _mock_response(200, {"id": "abc"})
        assert client._handle_response(resp) == {"id": "abc"}


class TestIntercomClientAdminId:
    """Tests for lazy admin ID fetching via /me."""

    def test_fetches_admin_id_on_first_call(self, client):
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"id": "admin-123"})
            result = client._get_admin_id()
            assert result == "admin-123"
            mock_get.assert_called_once()
            assert INTERCOM_API_BASE + "/me" in mock_get.call_args[0][0]

    def test_caches_admin_id(self, client):
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"id": "admin-123"})
            client._get_admin_id()
            client._get_admin_id()
            # Only called once due to caching
            assert mock_get.call_count == 1

    def test_returns_error_on_failure(self, client):
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(401)
            result = client._get_admin_id()
            assert isinstance(result, dict)
            assert "error" in result


class TestIntercomClientSearchConversations:
    def test_posts_to_correct_url(self, client):
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"conversations": []})
            client.search_conversations({"field": "state", "operator": "=", "value": "open"})
            args, _ = mock_post.call_args
            assert args[0] == f"{INTERCOM_API_BASE}/conversations/search"

    def test_clamps_limit(self, client):
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"conversations": []})
            client.search_conversations({}, limit=999)
            body = mock_post.call_args.kwargs["json"]
            assert body["pagination"]["per_page"] == 150


class TestIntercomClientGetConversation:
    def test_url_and_plaintext_param(self, client):
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"id": "conv-1"})
            client.get_conversation("conv-1")
            args, kwargs = mock_get.call_args
            assert "/conversations/conv-1" in args[0]
            assert kwargs["params"]["display_as"] == "plaintext"


class TestIntercomClientReplyToConversation:
    def test_reply_sends_admin_id(self, client):
        client._admin_id = "admin-1"
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"type": "conversation_part"})
            client.reply_to_conversation("conv-1", body="Hello", message_type="comment")
            body = mock_post.call_args.kwargs["json"]
            assert body["admin_id"] == "admin-1"
            assert body["message_type"] == "comment"
            assert body["body"] == "Hello"


class TestIntercomClientCreateContact:
    def test_creates_with_role_and_email(self, client):
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"id": "contact-1", "role": "user"})
            client.create_contact(role="user", email="test@example.com")
            body = mock_post.call_args.kwargs["json"]
            assert body["role"] == "user"
            assert body["email"] == "test@example.com"

    def test_omits_none_fields(self, client):
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"id": "contact-1"})
            client.create_contact(role="lead")
            body = mock_post.call_args.kwargs["json"]
            assert "email" not in body
            assert "name" not in body


class TestIntercomClientListConversations:
    def test_passes_pagination_params(self, client):
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"conversations": []})
            client.list_conversations(limit=10, starting_after="cursor-abc")
            params = mock_get.call_args.kwargs["params"]
            assert params["per_page"] == 10
            assert params["starting_after"] == "cursor-abc"
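

# Payload shape pinned down by the client tests above (a sketch inferred from
# the assertions, not the tool's actual code): search_conversations POSTs a
# query object plus pagination, clamping per_page to 150.
_EXAMPLE_SEARCH_PAYLOAD = {
    "query": {"field": "state", "operator": "=", "value": "open"},
    "pagination": {"per_page": min(999, 150)},
}
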
# ---------------------------------------------------------------------------
|
||||
# Credential handling via register_tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIntercomCredentials:
|
||||
"""Tests for credential resolution in MCP tool functions."""
|
||||
|
||||
def test_no_credentials_returns_error(self, mcp, monkeypatch):
|
||||
monkeypatch.delenv("INTERCOM_ACCESS_TOKEN", raising=False)
|
||||
fn = _tool_fn(mcp, "intercom_search_conversations")
|
||||
result = fn()
|
||||
assert "error" in result
|
||||
assert "not configured" in result["error"]
|
||||
|
||||
def test_env_var_credential(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "env-tok")
|
||||
fn = _tool_fn(mcp, "intercom_list_teams")
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value = _mock_response(200, {"teams": []})
|
||||
fn()
|
||||
headers = mock_get.call_args.kwargs["headers"]
|
||||
assert headers["Authorization"] == "Bearer env-tok"
|
||||
|
||||
def test_credential_store_used(self, mcp):
|
||||
creds = MagicMock()
|
||||
creds.get.return_value = "store-tok"
|
||||
fn = _tool_fn(mcp, "intercom_list_teams", credentials=creds)
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value = _mock_response(200, {"teams": []})
|
||||
fn()
|
||||
creds.get.assert_called_once_with("intercom")
|
||||
|
||||
def test_credential_store_non_string_raises(self, mcp):
|
||||
creds = MagicMock()
|
||||
creds.get.return_value = 12345
|
||||
fn = _tool_fn(mcp, "intercom_list_teams", credentials=creds)
|
||||
with pytest.raises(TypeError, match="Expected string"):
|
||||
fn()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP tool function tests — Conversations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIntercomSearchConversations:
|
||||
def test_no_filters_returns_recent(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
|
||||
fn = _tool_fn(mcp, "intercom_search_conversations")
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_post.return_value = _mock_response(200, {"conversations": [{"id": "1"}]})
|
||||
result = fn()
|
||||
assert "conversations" in result
|
||||
|
||||
def test_invalid_status(self, mcp, monkeypatch):
|
||||
monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
|
||||
        fn = _tool_fn(mcp, "intercom_search_conversations")
        result = fn(status="invalid")
        assert "error" in result
        assert "status" in result["error"]

    def test_invalid_limit_too_high(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_search_conversations")
        result = fn(limit=200)
        assert "error" in result
        assert "limit" in result["error"]

    def test_invalid_limit_too_low(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_search_conversations")
        result = fn(limit=0)
        assert "error" in result

    def test_status_filter_applied(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_search_conversations")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"conversations": []})
            fn(status="open")
            body = mock_post.call_args.kwargs["json"]
        query = body["query"]
        assert query["field"] == "state"
        assert query["value"] == "open"

    def test_invalid_created_after(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_search_conversations")
        result = fn(created_after="not-a-date")
        assert "error" in result
        assert "ISO date" in result["error"]

    def test_timeout(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_search_conversations")
        with patch("httpx.post", side_effect=httpx.TimeoutException("t")):
            result = fn()
        assert result == {"error": "Request timed out"}


class TestIntercomGetConversation:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_get_conversation")
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"id": "conv-1", "state": "open"})
            result = fn(conversation_id="conv-1")
        assert result["id"] == "conv-1"


# ---------------------------------------------------------------------------
# MCP tool function tests — Contacts
# ---------------------------------------------------------------------------


class TestIntercomGetContact:
    def test_by_id(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_get_contact")
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"id": "c1", "email": "a@b.com"})
            result = fn(contact_id="c1")
        assert result["id"] == "c1"

    def test_by_email_fallback(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_get_contact")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(
                200, {"data": [{"id": "c1", "email": "a@b.com"}]}
            )
            result = fn(email="a@b.com")
        assert result["id"] == "c1"

    def test_no_id_or_email(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_get_contact")
        result = fn()
        assert "error" in result
        assert "contact_id or email" in result["error"]

    def test_email_not_found(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_get_contact")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"data": []})
            result = fn(email="missing@example.com")
        assert "error" in result
        assert "No contact found" in result["error"]


class TestIntercomSearchContacts:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_search_contacts")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"data": [{"id": "c1"}]})
            result = fn(query="jane")
        assert "data" in result

    def test_invalid_limit(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_search_contacts")
        result = fn(query="test", limit=200)
        assert "error" in result
        assert "limit" in result["error"]


class TestIntercomCreateContact:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_create_contact")
        with patch("httpx.post") as mock_post:
            mock_post.return_value = _mock_response(200, {"id": "new-c", "role": "user"})
            result = fn(email="new@example.com")
        assert result["id"] == "new-c"

    def test_invalid_role(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_create_contact")
        result = fn(role="admin")
        assert "error" in result
        assert "role" in result["error"]


# ---------------------------------------------------------------------------
# MCP tool function tests — Notes, Tags, Assignment
# ---------------------------------------------------------------------------


class TestIntercomAddNote:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_add_note")
        with patch("httpx.get") as mock_get, patch("httpx.post") as mock_post:
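            # httpx.get is mocked to return an admin id (the note presumably
            # needs an author); httpx.post then creates the note itself.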
            mock_get.return_value = _mock_response(200, {"id": "admin-1"})
            mock_post.return_value = _mock_response(200, {"type": "conversation_part"})
            result = fn(conversation_id="conv-1", body="Internal note")
        assert result["type"] == "conversation_part"


class TestIntercomAddTag:
    def test_must_provide_target(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_add_tag")
        result = fn(name="vip")
        assert "error" in result
        assert "conversation_id or contact_id" in result["error"]

    def test_cannot_provide_both_targets(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_add_tag")
        result = fn(name="vip", conversation_id="c1", contact_id="ct1")
        assert "error" in result
        assert "not both" in result["error"]

    def test_tag_conversation_success(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_add_tag")
        with patch("httpx.get") as mock_get, patch("httpx.post") as mock_post:
            mock_get.return_value = _mock_response(200, {"id": "admin-1"})
            # First post: create_or_get_tag, second: tag_conversation
            mock_post.side_effect = [
                _mock_response(200, {"id": "tag-1", "name": "vip"}),
                _mock_response(200, {"tags": {"tags": [{"id": "tag-1"}]}}),
            ]
            result = fn(name="vip", conversation_id="conv-1")
        assert "error" not in result


class TestIntercomAssignConversation:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_assign_conversation")
        with patch("httpx.get") as mock_get, patch("httpx.post") as mock_post:
            mock_get.return_value = _mock_response(200, {"id": "admin-1"})
            mock_post.return_value = _mock_response(
                200, {"id": "conv-1", "assignee": {"id": "admin-2"}}
            )
            result = fn(
                conversation_id="conv-1",
                assignee_id="admin-2",
            )
        assert "error" not in result

    def test_invalid_assignee_type(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_assign_conversation")
        result = fn(
            conversation_id="conv-1",
            assignee_id="1",
            assignee_type="bot",
        )
        assert "error" in result
        assert "assignee_type" in result["error"]


class TestIntercomCloseConversation:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_close_conversation")
        with patch("httpx.get") as mock_get, patch("httpx.post") as mock_post:
            mock_get.return_value = _mock_response(200, {"id": "admin-1"})
            mock_post.return_value = _mock_response(200, {"state": "closed"})
            result = fn(conversation_id="conv-1")
        assert "error" not in result

    def test_empty_conversation_id(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_close_conversation")
        result = fn(conversation_id="")
        assert "error" in result
        assert "required" in result["error"]


class TestIntercomListTeams:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_list_teams")
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(
                200, {"teams": [{"id": "t1", "name": "Support"}]}
            )
            result = fn()
        assert "teams" in result


class TestIntercomListConversations:
    def test_success(self, mcp, monkeypatch):
        monkeypatch.setenv("INTERCOM_ACCESS_TOKEN", "tok")
        fn = _tool_fn(mcp, "intercom_list_conversations")
        with patch("httpx.get") as mock_get:
            mock_get.return_value = _mock_response(200, {"conversations": [{"id": "conv-1"}]})
            result = fn(limit=5)
        assert "conversations" in result


# ---------------------------------------------------------------------------
# Tool registration
# ---------------------------------------------------------------------------


class TestToolRegistration:
    """Verify all Intercom tools are registered."""

    EXPECTED_TOOLS = [
        "intercom_search_conversations",
        "intercom_get_conversation",
        "intercom_get_contact",
        "intercom_search_contacts",
        "intercom_add_note",
        "intercom_add_tag",
        "intercom_assign_conversation",
        "intercom_list_teams",
        "intercom_close_conversation",
        "intercom_create_contact",
        "intercom_list_conversations",
    ]

    def test_all_tools_registered(self, mcp):
        tools = _register(mcp)
        for name in self.EXPECTED_TOOLS:
            assert name in tools, f"Tool {name} not registered"

    def test_tool_count(self, mcp):
        tools = _register(mcp)
        intercom_tools = [k for k in tools if k.startswith("intercom_")]
        assert len(intercom_tools) == len(self.EXPECTED_TOOLS)

@@ -0,0 +1,282 @@

"""Tests for Port Scanner tool."""

from __future__ import annotations

import socket
from unittest.mock import AsyncMock, patch

import pytest
from fastmcp import FastMCP

from aden_tools.tools.port_scanner import register_tools


@pytest.fixture
def port_tools(mcp: FastMCP):
    """Register port scanner tools and return tool functions."""
    register_tools(mcp)
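    # NOTE: this reaches into FastMCP's private tool manager to grab the raw
    # functions; the attribute path is an implementation detail of fastmcp.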
    tools = mcp._tool_manager._tools
    return {name: tools[name].fn for name in tools}


@pytest.fixture
def scan_fn(port_tools):
    return port_tools["port_scan"]


# ---------------------------------------------------------------------------
# Input Validation
# ---------------------------------------------------------------------------


class TestInputValidation:
    """Test hostname and port input validation."""

    @pytest.mark.asyncio
    async def test_strips_https_prefix(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": False}
                result = await scan_fn("https://example.com", ports="80")
                assert result["hostname"] == "example.com"

    @pytest.mark.asyncio
    async def test_strips_path(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": False}
                result = await scan_fn("example.com/path", ports="80")
                assert result["hostname"] == "example.com"

    @pytest.mark.asyncio
    async def test_invalid_port_list(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            result = await scan_fn("example.com", ports="invalid,ports")
            assert "error" in result
            assert "Invalid port list" in result["error"]

    @pytest.mark.asyncio
    async def test_custom_port_list(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": False}
                result = await scan_fn("example.com", ports="22,80,443")
                assert result["ports_scanned"] == 3

    @pytest.mark.asyncio
    async def test_timeout_clamped(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": False}
                # Timeout > 10 should be clamped
                result = await scan_fn("example.com", ports="80", timeout=100.0)
                assert "error" not in result
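                # _check_port(ip, port, timeout): the third positional
                # argument carries the clamped timeout value.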
                assert mock_check.call_args[0][2] <= 10.0


# ---------------------------------------------------------------------------
# DNS Resolution Errors
# ---------------------------------------------------------------------------


class TestDnsResolution:
    """Test DNS resolution error handling."""

    @pytest.mark.asyncio
    async def test_hostname_not_found(self, scan_fn):
        with patch("socket.gethostbyname", side_effect=socket.gaierror("not found")):
            result = await scan_fn("nonexistent.invalid")
            assert "error" in result
            assert "resolve hostname" in result["error"]


# ---------------------------------------------------------------------------
# Port Scanning
# ---------------------------------------------------------------------------


class TestPortScanning:
    """Test port scanning functionality."""

    @pytest.mark.asyncio
    async def test_open_port_detected(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": True, "banner": ""}
                result = await scan_fn("example.com", ports="80")
                assert len(result["open_ports"]) == 1
                assert result["open_ports"][0]["port"] == 80

    @pytest.mark.asyncio
    async def test_closed_port_detected(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": False}
                result = await scan_fn("example.com", ports="12345")
                assert len(result["open_ports"]) == 0
                assert 12345 in result["closed_ports"]

    @pytest.mark.asyncio
    async def test_banner_captured(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": True, "banner": "SSH-2.0-OpenSSH_8.9"}
                result = await scan_fn("example.com", ports="22")
                assert result["open_ports"][0]["banner"] == "SSH-2.0-OpenSSH_8.9"


# ---------------------------------------------------------------------------
# Risky Port Detection
# ---------------------------------------------------------------------------


class TestRiskyPorts:
    """Test detection of risky exposed ports."""

    @pytest.mark.asyncio
    async def test_database_port_flagged(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": True, "banner": ""}
                result = await scan_fn("example.com", ports="3306")  # MySQL
                assert result["open_ports"][0]["severity"] == "high"
                assert "MySQL" in result["open_ports"][0]["finding"]
                assert result["grade_input"]["no_database_ports_exposed"] is False

    @pytest.mark.asyncio
    async def test_admin_port_flagged(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": True, "banner": ""}
                result = await scan_fn("example.com", ports="3389")  # RDP
                assert result["open_ports"][0]["severity"] == "high"
                assert result["grade_input"]["no_admin_ports_exposed"] is False

    @pytest.mark.asyncio
    async def test_legacy_port_flagged(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": True, "banner": ""}
                result = await scan_fn("example.com", ports="23")  # Telnet
                assert result["open_ports"][0]["severity"] == "medium"
                assert result["grade_input"]["no_legacy_ports_exposed"] is False


# ---------------------------------------------------------------------------
# Grade Input
# ---------------------------------------------------------------------------


class TestGradeInput:
    """Test grade_input dict is properly constructed."""

    @pytest.mark.asyncio
    async def test_grade_input_keys_present(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": False}
                result = await scan_fn("example.com", ports="80")
                assert "grade_input" in result
                grade = result["grade_input"]
                assert "no_database_ports_exposed" in grade
                assert "no_admin_ports_exposed" in grade
                assert "no_legacy_ports_exposed" in grade
                assert "only_web_ports" in grade

    @pytest.mark.asyncio
    async def test_only_web_ports_true(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                # Only 80 and 443 open
                async def check_port(ip, port, timeout):
                    if port in (80, 443):
                        return {"open": True, "banner": ""}
                    return {"open": False}

                mock_check.side_effect = check_port
                result = await scan_fn("example.com", ports="22,80,443")
                assert result["grade_input"]["only_web_ports"] is True

    @pytest.mark.asyncio
    async def test_only_web_ports_false(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                # SSH port also open
                async def check_port(ip, port, timeout):
                    if port in (22, 80, 443):
                        return {"open": True, "banner": ""}
                    return {"open": False}

                mock_check.side_effect = check_port
                result = await scan_fn("example.com", ports="22,80,443")
                assert result["grade_input"]["only_web_ports"] is False


# ---------------------------------------------------------------------------
# Top20/Top100 Port Lists
# ---------------------------------------------------------------------------


class TestPortLists:
    """Test predefined port lists."""

    @pytest.mark.asyncio
    async def test_top20_ports(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": False}
                result = await scan_fn("example.com", ports="top20")
                assert result["ports_scanned"] == 20

    @pytest.mark.asyncio
    async def test_top100_ports(self, scan_fn):
        with patch("socket.gethostbyname", return_value="93.184.216.34"):
            with patch(
                "aden_tools.tools.port_scanner.port_scanner._check_port",
                new_callable=AsyncMock,
            ) as mock_check:
                mock_check.return_value = {"open": False}
                result = await scan_fn("example.com", ports="top100")
                assert result["ports_scanned"] > 20

@@ -0,0 +1,316 @@

"""Tests for Risk Scorer tool."""

from __future__ import annotations

import json

import pytest
from fastmcp import FastMCP

from aden_tools.tools.risk_scorer import register_tools
from aden_tools.tools.risk_scorer.risk_scorer import (
    SSL_CHECKS,
    _parse_json,
    _score_category,
    _score_to_grade,
)


@pytest.fixture
def risk_tools(mcp: FastMCP):
    """Register risk scorer tools and return tool functions."""
    register_tools(mcp)
    tools = mcp._tool_manager._tools
    return {name: tools[name].fn for name in tools}


@pytest.fixture
def score_fn(risk_tools):
    return risk_tools["risk_score"]


# ---------------------------------------------------------------------------
# Helper Function Tests
# ---------------------------------------------------------------------------


class TestScoreToGrade:
    """Test _score_to_grade helper."""

    def test_grade_a(self):
        assert _score_to_grade(95) == "A"
        assert _score_to_grade(90) == "A"

    def test_grade_b(self):
        assert _score_to_grade(85) == "B"
        assert _score_to_grade(75) == "B"

    def test_grade_c(self):
        assert _score_to_grade(70) == "C"
        assert _score_to_grade(60) == "C"

    def test_grade_d(self):
        assert _score_to_grade(55) == "D"
        assert _score_to_grade(40) == "D"

    def test_grade_f(self):
        assert _score_to_grade(39) == "F"
        assert _score_to_grade(0) == "F"


class TestParseJson:
    """Test _parse_json helper."""

    def test_valid_json(self):
        result = _parse_json('{"key": "value"}')
        assert result == {"key": "value"}

    def test_invalid_json(self):
        result = _parse_json("not json")
        assert result is None

    def test_empty_string(self):
        result = _parse_json("")
        assert result is None

    def test_whitespace_only(self):
        result = _parse_json(" ")
        assert result is None

    def test_non_dict_json(self):
        result = _parse_json("[1, 2, 3]")
        assert result is None


class TestScoreCategory:
    """Test _score_category helper."""

    def test_perfect_ssl_score(self):
        grade_input = {
            "tls_version_ok": True,
            "cert_valid": True,
            "cert_expiring_soon": False,  # inverted - False is good
            "strong_cipher": True,
            "self_signed": False,  # inverted - False is good
        }
        score, findings = _score_category(grade_input, SSL_CHECKS)
        assert score == 100
        assert len(findings) == 0

    def test_failing_ssl_score(self):
        grade_input = {
            "tls_version_ok": False,
            "cert_valid": False,
            "cert_expiring_soon": True,  # inverted - True is bad
            "strong_cipher": False,
            "self_signed": True,  # inverted - True is bad
        }
        score, findings = _score_category(grade_input, SSL_CHECKS)
        assert score == 0
        assert len(findings) == 5

    def test_missing_values_half_credit(self):
        grade_input = {}  # All values missing
        score, findings = _score_category(grade_input, SSL_CHECKS)
        # Should get half credit for missing values
        assert 45 <= score <= 55


# ---------------------------------------------------------------------------
# Full Scoring Flow
# ---------------------------------------------------------------------------


class TestFullScoring:
    """Test full risk scoring."""

    def test_empty_inputs_returns_zero(self, score_fn):
        result = score_fn()
        assert result["overall_score"] == 0
        assert result["overall_grade"] == "F"

    def test_all_categories_skipped(self, score_fn):
        result = score_fn()
        for cat in result["categories"].values():
            assert cat["skipped"] is True

    def test_ssl_results_only(self, score_fn):
        ssl_data = {
            "grade_input": {
                "tls_version_ok": True,
                "cert_valid": True,
                "cert_expiring_soon": False,
                "strong_cipher": True,
                "self_signed": False,
            }
        }
        result = score_fn(ssl_results=json.dumps(ssl_data))
        assert result["categories"]["ssl_tls"]["score"] == 100
        assert result["categories"]["ssl_tls"]["grade"] == "A"
        assert result["categories"]["ssl_tls"]["skipped"] is False

    def test_headers_results_only(self, score_fn):
        headers_data = {
            "grade_input": {
                "hsts": True,
                "csp": True,
                "x_frame_options": True,
                "x_content_type_options": True,
                "referrer_policy": True,
                "permissions_policy": True,
                "no_leaky_headers": True,
            }
        }
        result = score_fn(headers_results=json.dumps(headers_data))
        assert result["categories"]["http_headers"]["score"] == 100
        assert result["categories"]["http_headers"]["grade"] == "A"

    def test_combined_results(self, score_fn):
        ssl_data = {
            "grade_input": {
                "tls_version_ok": True,
                "cert_valid": True,
                "cert_expiring_soon": False,
                "strong_cipher": True,
                "self_signed": False,
            }
        }
        headers_data = {
            "grade_input": {
                "hsts": True,
                "csp": True,
                "x_frame_options": True,
                "x_content_type_options": True,
                "referrer_policy": True,
                "permissions_policy": True,
                "no_leaky_headers": True,
            }
        }
        result = score_fn(
            ssl_results=json.dumps(ssl_data),
            headers_results=json.dumps(headers_data),
        )
        # Both categories have perfect scores
        assert result["categories"]["ssl_tls"]["score"] == 100
        assert result["categories"]["http_headers"]["score"] == 100
        # Overall should be 100 (weighted average of two 100s)
        assert result["overall_score"] == 100
        assert result["overall_grade"] == "A"


# ---------------------------------------------------------------------------
# Top Risks
# ---------------------------------------------------------------------------


class TestTopRisks:
    """Test top_risks list generation."""

    def test_top_risks_generated(self, score_fn):
        ssl_data = {
            "grade_input": {
                "tls_version_ok": False,  # Failing
                "cert_valid": True,
                "cert_expiring_soon": False,
                "strong_cipher": False,  # Failing
                "self_signed": False,
            }
        }
        result = score_fn(ssl_results=json.dumps(ssl_data))
        assert len(result["top_risks"]) > 0
        # Should mention TLS version and cipher issues
        risks_text = " ".join(result["top_risks"])
        assert "TLS" in risks_text or "cipher" in risks_text.lower()

    def test_top_risks_limited_to_10(self, score_fn):
        # Create data with many failures
        ssl_data = {
            "grade_input": {
                "tls_version_ok": False,
                "cert_valid": False,
                "cert_expiring_soon": True,
                "strong_cipher": False,
                "self_signed": True,
            }
        }
        headers_data = {
            "grade_input": {
                "hsts": False,
                "csp": False,
                "x_frame_options": False,
                "x_content_type_options": False,
                "referrer_policy": False,
                "permissions_policy": False,
                "no_leaky_headers": False,
            }
        }
        dns_data = {
            "grade_input": {
                "spf_present": False,
                "spf_strict": False,
                "dmarc_present": False,
                "dmarc_enforcing": False,
                "dkim_found": False,
                "dnssec_enabled": False,
                "zone_transfer_blocked": False,
            }
        }
        result = score_fn(
            ssl_results=json.dumps(ssl_data),
            headers_results=json.dumps(headers_data),
            dns_results=json.dumps(dns_data),
        )
        # Should be capped at 10
        assert len(result["top_risks"]) <= 10


# ---------------------------------------------------------------------------
# Grade Scale
# ---------------------------------------------------------------------------


class TestGradeScale:
    """Test grade_scale is included in output."""

    def test_grade_scale_present(self, score_fn):
        result = score_fn()
        assert "grade_scale" in result
        assert "A" in result["grade_scale"]
        assert "B" in result["grade_scale"]
        assert "C" in result["grade_scale"]
        assert "D" in result["grade_scale"]
        assert "F" in result["grade_scale"]


# ---------------------------------------------------------------------------
# Category Weights
# ---------------------------------------------------------------------------


class TestCategoryWeights:
    """Test category weights are applied correctly."""

    def test_weights_included_in_output(self, score_fn):
        ssl_data = {"grade_input": {"tls_version_ok": True}}
        result = score_fn(ssl_results=json.dumps(ssl_data))
        assert result["categories"]["ssl_tls"]["weight"] == 0.20


# ---------------------------------------------------------------------------
# Edge Cases
# ---------------------------------------------------------------------------


class TestEdgeCases:
    """Test edge cases and error handling."""

    def test_invalid_json_ignored(self, score_fn):
        result = score_fn(ssl_results="not valid json")
        assert result["categories"]["ssl_tls"]["skipped"] is True

    def test_missing_grade_input_key(self, score_fn):
        # JSON without grade_input - should use the dict itself
        data = {"tls_version_ok": True}
        result = score_fn(ssl_results=json.dumps(data))
        # Should not error
        assert "overall_score" in result

@@ -0,0 +1,277 @@

"""Tests for SSL/TLS Scanner tool."""

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch

import pytest
from fastmcp import FastMCP

from aden_tools.tools.ssl_tls_scanner import register_tools


@pytest.fixture
def ssl_tools(mcp: FastMCP):
    """Register SSL/TLS tools and return tool functions."""
    register_tools(mcp)
    tools = mcp._tool_manager._tools
    return {name: tools[name].fn for name in tools}


@pytest.fixture
def scan_fn(ssl_tools):
    return ssl_tools["ssl_tls_scan"]


def _mock_cert_dict(
    days_until_expiry: int = 365,
    subject: str = "example.com",
    issuer: str = "Let's Encrypt",
    san: list[str] | None = None,
):
    """Create a mock certificate dict."""
    now = datetime.now(UTC)
    not_before = now - timedelta(days=30)
    not_after = now + timedelta(days=days_until_expiry)

    return {
        "subject": ((("commonName", subject),),),
        "issuer": ((("commonName", issuer),),),
        "notBefore": not_before.strftime("%b %d %H:%M:%S %Y GMT"),
        "notAfter": not_after.strftime("%b %d %H:%M:%S %Y GMT"),
        "subjectAltName": tuple(("DNS", s) for s in (san or [subject])),
    }


# ---------------------------------------------------------------------------
# Input Validation
# ---------------------------------------------------------------------------


class TestInputValidation:
    """Test hostname input cleaning."""

    def test_strips_https_prefix(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_ctx.return_value.wrap_socket.side_effect = TimeoutError()
            result = scan_fn("https://example.com")
            assert "example.com" in result["error"]
            assert "https://" not in result["error"]

    def test_strips_http_prefix(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_ctx.return_value.wrap_socket.side_effect = TimeoutError()
            result = scan_fn("http://example.com")
            assert "example.com" in result["error"]
            assert "http://" not in result["error"]

    def test_strips_path(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_ctx.return_value.wrap_socket.side_effect = TimeoutError()
            result = scan_fn("example.com/path/to/page")
            assert "example.com" in result["error"]
            assert "/path" not in result["error"]

    def test_strips_port_from_hostname(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_ctx.return_value.wrap_socket.side_effect = TimeoutError()
            result = scan_fn("example.com:8443")
            assert "example.com:443" in result["error"]


# ---------------------------------------------------------------------------
# Connection Errors
# ---------------------------------------------------------------------------


class TestConnectionErrors:
    """Test error handling for connection failures."""

    def test_timeout_error(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_conn = MagicMock()
            mock_conn.connect.side_effect = TimeoutError()
            mock_ctx.return_value.wrap_socket.return_value = mock_conn

            result = scan_fn("example.com")
            assert "error" in result
            assert "timed out" in result["error"]

    def test_connection_refused(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_conn = MagicMock()
            mock_conn.connect.side_effect = ConnectionRefusedError()
            mock_ctx.return_value.wrap_socket.return_value = mock_conn

            result = scan_fn("example.com")
            assert "error" in result
            assert "refused" in result["error"]


# ---------------------------------------------------------------------------
# TLS Version Detection
# ---------------------------------------------------------------------------


class TestTlsVersion:
    """Test TLS version detection and validation."""

    def test_tls13_ok(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_conn = MagicMock()
            mock_conn.version.return_value = "TLSv1.3"
            mock_conn.cipher.return_value = ("TLS_AES_256_GCM_SHA384", "TLSv1.3", 256)
            mock_conn.getpeercert.return_value = _mock_cert_dict()
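            # side_effect takes precedence over return_value: the scanner
            # presumably calls getpeercert(binary_form=True) first (DER bytes)
            # and then getpeercert() for the parsed dict, consumed in order.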
            mock_conn.getpeercert.side_effect = [
                b"fake_der_cert",
                _mock_cert_dict(),
            ]
            mock_ctx.return_value.wrap_socket.return_value = mock_conn

            result = scan_fn("example.com")
            assert result["tls_version"] == "TLSv1.3"
            assert result["grade_input"]["tls_version_ok"] is True

    def test_tls10_insecure(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_conn = MagicMock()
            mock_conn.version.return_value = "TLSv1"
            mock_conn.cipher.return_value = ("AES256-SHA", "TLSv1", 256)
            mock_conn.getpeercert.return_value = _mock_cert_dict()
            mock_conn.getpeercert.side_effect = [
                b"fake_der_cert",
                _mock_cert_dict(),
            ]
            mock_ctx.return_value.wrap_socket.return_value = mock_conn

            result = scan_fn("example.com")
            assert result["grade_input"]["tls_version_ok"] is False
            issues = [i["finding"] for i in result.get("issues", [])]
            assert any("TLS version" in i for i in issues)


# ---------------------------------------------------------------------------
# Cipher Suite Detection
# ---------------------------------------------------------------------------


class TestCipherSuite:
    """Test cipher suite detection and validation."""

    def test_strong_cipher(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_conn = MagicMock()
            mock_conn.version.return_value = "TLSv1.3"
            mock_conn.cipher.return_value = ("TLS_AES_256_GCM_SHA384", "TLSv1.3", 256)
            mock_conn.getpeercert.return_value = _mock_cert_dict()
            mock_conn.getpeercert.side_effect = [
                b"fake_der_cert",
                _mock_cert_dict(),
            ]
            mock_ctx.return_value.wrap_socket.return_value = mock_conn

            result = scan_fn("example.com")
            assert result["grade_input"]["strong_cipher"] is True

    def test_weak_cipher_rc4(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_conn = MagicMock()
            mock_conn.version.return_value = "TLSv1.2"
            mock_conn.cipher.return_value = ("RC4-SHA", "TLSv1.2", 128)
            mock_conn.getpeercert.return_value = _mock_cert_dict()
            mock_conn.getpeercert.side_effect = [
                b"fake_der_cert",
                _mock_cert_dict(),
            ]
            mock_ctx.return_value.wrap_socket.return_value = mock_conn

            result = scan_fn("example.com")
            assert result["grade_input"]["strong_cipher"] is False


# ---------------------------------------------------------------------------
# Certificate Validation
# ---------------------------------------------------------------------------


class TestCertificateValidation:
    """Test certificate validation checks."""

    def test_valid_certificate(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_conn = MagicMock()
            mock_conn.version.return_value = "TLSv1.3"
            mock_conn.cipher.return_value = ("TLS_AES_256_GCM_SHA384", "TLSv1.3", 256)
            mock_conn.getpeercert.return_value = _mock_cert_dict(days_until_expiry=365)
            mock_conn.getpeercert.side_effect = [
                b"fake_der_cert",
                _mock_cert_dict(days_until_expiry=365),
            ]
            mock_ctx.return_value.wrap_socket.return_value = mock_conn

            result = scan_fn("example.com")
            assert result["grade_input"]["cert_valid"] is True

    def test_expiring_soon(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_conn = MagicMock()
            mock_conn.version.return_value = "TLSv1.3"
            mock_conn.cipher.return_value = ("TLS_AES_256_GCM_SHA384", "TLSv1.3", 256)
            mock_conn.getpeercert.return_value = _mock_cert_dict(days_until_expiry=15)
            mock_conn.getpeercert.side_effect = [
                b"fake_der_cert",
                _mock_cert_dict(days_until_expiry=15),
            ]
            mock_ctx.return_value.wrap_socket.return_value = mock_conn

            result = scan_fn("example.com")
            assert result["grade_input"]["cert_expiring_soon"] is True

    def test_self_signed_detected(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_conn = MagicMock()
            mock_conn.version.return_value = "TLSv1.3"
            mock_conn.cipher.return_value = ("TLS_AES_256_GCM_SHA384", "TLSv1.3", 256)
            # Self-signed: subject == issuer
            mock_conn.getpeercert.return_value = _mock_cert_dict(
                subject="example.com", issuer="example.com"
            )
            mock_conn.getpeercert.side_effect = [
                b"fake_der_cert",
                _mock_cert_dict(subject="example.com", issuer="example.com"),
            ]
            mock_ctx.return_value.wrap_socket.return_value = mock_conn

            result = scan_fn("example.com")
            assert result["grade_input"]["self_signed"] is True


# ---------------------------------------------------------------------------
# Grade Input
# ---------------------------------------------------------------------------


class TestGradeInput:
    """Test grade_input dict is properly constructed."""

    def test_grade_input_keys_present(self, scan_fn):
        with patch("ssl.create_default_context") as mock_ctx:
            mock_conn = MagicMock()
            mock_conn.version.return_value = "TLSv1.3"
            mock_conn.cipher.return_value = ("TLS_AES_256_GCM_SHA384", "TLSv1.3", 256)
            mock_conn.getpeercert.return_value = _mock_cert_dict()
            mock_conn.getpeercert.side_effect = [
                b"fake_der_cert",
                _mock_cert_dict(),
            ]
            mock_ctx.return_value.wrap_socket.return_value = mock_conn

            result = scan_fn("example.com")
            assert "grade_input" in result
            grade = result["grade_input"]
            assert "tls_version_ok" in grade
            assert "cert_valid" in grade
            assert "cert_expiring_soon" in grade
            assert "strong_cipher" in grade
            assert "self_signed" in grade

@@ -0,0 +1,294 @@

"""Tests for Subdomain Enumerator tool."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from fastmcp import FastMCP

from aden_tools.tools.subdomain_enumerator import register_tools


@pytest.fixture
def subdomain_tools(mcp: FastMCP):
    """Register subdomain enumeration tools and return tool functions."""
    register_tools(mcp)
    tools = mcp._tool_manager._tools
    return {name: tools[name].fn for name in tools}


@pytest.fixture
def enumerate_fn(subdomain_tools):
    return subdomain_tools["subdomain_enumerate"]


def _mock_crtsh_response(subdomains: list[str], status_code: int = 200) -> MagicMock:
    """Create a mock crt.sh response."""
    resp = MagicMock()
    resp.status_code = status_code
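    # Real crt.sh rows can pack several newline-separated names into one
    # name_value field; one name per entry keeps these fixtures simple.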
resp.json.return_value = [{"name_value": sub} for sub in subdomains]
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input Validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInputValidation:
|
||||
"""Test domain input cleaning."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strips_https_prefix(self, enumerate_fn):
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response([])
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("https://example.com")
|
||||
assert result["domain"] == "example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strips_http_prefix(self, enumerate_fn):
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response([])
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("http://example.com")
|
||||
assert result["domain"] == "example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strips_path(self, enumerate_fn):
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response([])
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com/path")
|
||||
assert result["domain"] == "example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_results_clamped(self, enumerate_fn):
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response([])
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
# max_results should be clamped to 200
|
||||
result = await enumerate_fn("example.com", max_results=500)
|
||||
# Result should not error
|
||||
assert "error" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Connection Errors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConnectionErrors:
|
||||
"""Test error handling for crt.sh failures."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_error(self, enumerate_fn):
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.side_effect = httpx.TimeoutException("timeout")
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com")
|
||||
assert "error" in result
|
||||
assert "timed out" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_error(self, enumerate_fn):
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response([], status_code=500)
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com")
|
||||
assert "error" in result
|
||||
assert "500" in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subdomain Discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubdomainDiscovery:
|
||||
"""Test subdomain extraction from CT logs."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subdomains_extracted(self, enumerate_fn):
|
||||
subdomains = [
|
||||
"www.example.com",
|
||||
"api.example.com",
|
||||
"mail.example.com",
|
||||
]
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response(subdomains)
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com")
|
||||
assert result["total_found"] == 3
|
||||
assert "www.example.com" in result["subdomains"]
|
||||
assert "api.example.com" in result["subdomains"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcards_filtered(self, enumerate_fn):
|
||||
subdomains = [
|
||||
"*.example.com",
|
||||
"www.example.com",
|
||||
"*.api.example.com",
|
||||
]
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response(subdomains)
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com")
|
||||
# Wildcards should be filtered out
|
||||
assert "*.example.com" not in result["subdomains"]
|
||||
assert "www.example.com" in result["subdomains"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicates_removed(self, enumerate_fn):
|
||||
subdomains = [
|
||||
"www.example.com",
|
||||
"www.example.com",
|
||||
"www.example.com",
|
||||
]
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response(subdomains)
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com")
|
||||
assert result["total_found"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Interesting Subdomain Detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInterestingSubdomains:
|
||||
"""Test detection of security-relevant subdomains."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_staging_flagged(self, enumerate_fn):
|
||||
subdomains = ["staging.example.com", "www.example.com"]
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response(subdomains)
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com")
|
||||
assert len(result["interesting"]) > 0
|
||||
interesting_subs = [i["subdomain"] for i in result["interesting"]]
|
||||
assert "staging.example.com" in interesting_subs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_flagged(self, enumerate_fn):
|
||||
subdomains = ["admin.example.com", "www.example.com"]
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response(subdomains)
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com")
|
||||
interesting_subs = [i["subdomain"] for i in result["interesting"]]
|
||||
assert "admin.example.com" in interesting_subs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dev_flagged(self, enumerate_fn):
|
||||
subdomains = ["dev.example.com", "www.example.com"]
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response(subdomains)
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com")
|
||||
interesting_subs = [i["subdomain"] for i in result["interesting"]]
|
||||
assert "dev.example.com" in interesting_subs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Grade Input
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGradeInput:
|
||||
"""Test grade_input dict is properly constructed."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grade_input_keys_present(self, enumerate_fn):
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response([])
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com")
|
||||
assert "grade_input" in result
|
||||
grade = result["grade_input"]
|
||||
assert "no_dev_staging_exposed" in grade
|
||||
assert "no_admin_exposed" in grade
|
||||
assert "reasonable_surface_area" in grade
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_dev_staging_true_when_clean(self, enumerate_fn):
|
||||
subdomains = ["www.example.com", "api.example.com"]
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response(subdomains)
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com")
|
||||
assert result["grade_input"]["no_dev_staging_exposed"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasonable_surface_area(self, enumerate_fn):
|
||||
# Less than 50 subdomains = reasonable
|
||||
subdomains = [f"sub{i}.example.com" for i in range(30)]
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = _mock_crtsh_response(subdomains)
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
result = await enumerate_fn("example.com")
|
||||
assert result["grade_input"]["reasonable_surface_area"] is True
|
||||
@@ -0,0 +1,269 @@
|
||||
"""Tests for Tech Stack Detector tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from aden_tools.tools.tech_stack_detector import register_tools
|
||||
from aden_tools.tools.tech_stack_detector.tech_stack_detector import (
|
||||
_detect_cdn,
|
||||
_detect_cms_from_html,
|
||||
_detect_js_libraries,
|
||||
_detect_server,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tech_tools(mcp: FastMCP):
|
||||
"""Register tech stack tools and return tool functions."""
|
||||
register_tools(mcp)
|
||||
tools = mcp._tool_manager._tools
|
||||
return {name: tools[name].fn for name in tools}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def detect_fn(tech_tools):
|
||||
return tech_tools["tech_stack_detect"]
|
||||
|
||||
|
||||
class FakeHeaders:
|
||||
"""Minimal stand-in for httpx.Headers."""
|
||||
|
||||
def __init__(self, headers: dict):
|
||||
self._headers = {k.lower(): v for k, v in headers.items()}
|
||||
|
||||
def get(self, name: str, default=None):
|
||||
return self._headers.get(name.lower(), default)
|
||||
|
||||
def get_list(self, name: str) -> list[str]:
|
||||
val = self._headers.get(name.lower())
|
||||
if val is None:
|
||||
return []
|
||||
if isinstance(val, list):
|
||||
return val
|
||||
return [val]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper Function Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDetectServer:
|
||||
"""Test _detect_server helper."""
|
||||
|
||||
def test_server_with_version(self):
|
||||
headers = FakeHeaders({"server": "nginx/1.21.0"})
|
||||
result = _detect_server(headers)
|
||||
assert result["name"] == "nginx"
|
||||
assert result["version"] == "1.21.0"
|
||||
|
||||
def test_server_without_version(self):
|
||||
headers = FakeHeaders({"server": "cloudflare"})
|
||||
result = _detect_server(headers)
|
||||
assert result["name"] == "cloudflare"
|
||||
assert result["version"] is None
|
||||
|
||||
def test_no_server_header(self):
|
||||
headers = FakeHeaders({})
|
||||
result = _detect_server(headers)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestDetectCdn:
|
||||
"""Test _detect_cdn helper."""
|
||||
|
||||
def test_cloudflare_detected(self):
|
||||
headers = FakeHeaders({"cf-ray": "123abc"})
|
||||
result = _detect_cdn(headers)
|
||||
assert result == "Cloudflare"
|
||||
|
||||
def test_vercel_detected(self):
|
||||
headers = FakeHeaders({"x-vercel-id": "abc123"})
|
||||
result = _detect_cdn(headers)
|
||||
assert result == "Vercel"
|
||||
|
||||
def test_no_cdn(self):
|
||||
headers = FakeHeaders({"content-type": "text/html"})
|
||||
result = _detect_cdn(headers)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestDetectJsLibraries:
|
||||
"""Test _detect_js_libraries helper."""
|
||||
|
||||
def test_react_detected(self):
|
||||
html = '<script src="/static/react.min.js"></script>'
|
||||
result = _detect_js_libraries(html)
|
||||
assert "React" in result
|
||||
|
||||
def test_jquery_detected(self):
|
||||
html = '<script src="https://cdn.example.com/jquery-3.6.0.min.js"></script>'
|
||||
result = _detect_js_libraries(html)
|
||||
assert any("jQuery" in lib for lib in result)
|
||||
|
||||
def test_nextjs_detected(self):
|
||||
html = '<script id="__NEXT_DATA__" type="application/json">{}</script>'
|
||||
result = _detect_js_libraries(html)
|
||||
assert "Next.js" in result
|
||||
|
||||
def test_no_libraries(self):
|
||||
html = "<html><body>Simple page</body></html>"
|
||||
result = _detect_js_libraries(html)
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestDetectCms:
|
||||
"""Test _detect_cms_from_html helper."""
|
||||
|
||||
def test_wordpress_detected(self):
|
||||
html = '<link href="/wp-content/themes/theme/style.css">'
|
||||
result = _detect_cms_from_html(html)
|
||||
assert result == "WordPress"
|
||||
|
||||
def test_shopify_detected(self):
|
||||
html = '<script src="https://cdn.shopify.com/s/files/1/theme.js"></script>'
|
||||
result = _detect_cms_from_html(html)
|
||||
assert result == "Shopify"
|
||||
|
||||
def test_drupal_detected(self):
|
||||
html = '<script src="/core/misc/drupal.js"></script>'
|
||||
result = _detect_cms_from_html(html)
|
||||
assert result == "Drupal"
|
||||
|
||||
def test_no_cms(self):
|
||||
html = "<html><body>Custom site</body></html>"
|
||||
result = _detect_cms_from_html(html)
|
||||
assert result is None
|
||||
|
||||
|
```python
# ---------------------------------------------------------------------------
# Connection Errors
# ---------------------------------------------------------------------------


class TestConnectionErrors:
    """Test error handling for connection failures."""

    @pytest.mark.asyncio
    async def test_connection_error(self, detect_fn):
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.side_effect = httpx.ConnectError("Connection refused")
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await detect_fn("https://example.com")
            assert "error" in result
            assert "Connection failed" in result["error"]

    @pytest.mark.asyncio
    async def test_timeout_error(self, detect_fn):
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.side_effect = httpx.TimeoutException("timeout")
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await detect_fn("https://example.com")
            assert "error" in result
            assert "timed out" in result["error"]
```
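Both tests repeat the same five-line dance to fake `httpx.AsyncClient` as an async context manager. If that pattern keeps spreading, a small factory like this hypothetical sketch could centralize it:

```python
from unittest.mock import AsyncMock


def make_mock_client(*, response=None, side_effect=None) -> AsyncMock:
    """Hypothetical helper: an AsyncMock usable as `async with httpx.AsyncClient() as c`."""
    mock_client = AsyncMock()
    if side_effect is not None:
        mock_client.get.side_effect = side_effect
    else:
        mock_client.get.return_value = response
    # `async with` resolves __aenter__/__aexit__; hand back the mock itself.
    mock_client.__aenter__.return_value = mock_client
    mock_client.__aexit__.return_value = None
    return mock_client
```

Each test body would then shrink to `MockClient.return_value = make_mock_client(side_effect=httpx.ConnectError("Connection refused"))` plus its assertions.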
```python
# ---------------------------------------------------------------------------
# Full Detection Flow
# ---------------------------------------------------------------------------


class TestFullDetection:
    """Test full tech stack detection."""

    def _mock_response(
        self,
        html: str = "<html></html>",
        headers: dict | None = None,
        cookies: dict | None = None,
    ):
        resp = MagicMock()
        resp.text = html
        resp.url = "https://example.com"
        resp.headers = httpx.Headers(headers or {})
        resp.cookies = httpx.Cookies(cookies or {})
        return resp

    @pytest.mark.asyncio
    async def test_detects_server(self, detect_fn):
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = self._mock_response(headers={"server": "nginx/1.21.0"})
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await detect_fn("https://example.com")
            assert result["server"]["name"] == "nginx"

    @pytest.mark.asyncio
    async def test_detects_framework(self, detect_fn):
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = self._mock_response(headers={"x-powered-by": "Express"})
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await detect_fn("https://example.com")
            assert result["framework"] == "Express"
```
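`test_detects_framework` implies the detector reads the `x-powered-by` header. A plausible reduction of that logic, assumed from the test alone and not shown in this hunk:

```python
def _detect_framework(headers) -> str | None:
    """Return the framework advertised by x-powered-by, e.g. 'Express' or 'Express/4.18.2'."""
    powered_by = headers.get("x-powered-by")
    if not powered_by:
        return None
    # Drop any trailing version so "Express/4.18.2" still yields "Express".
    return powered_by.split("/", 1)[0].strip()
```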
```python
# ---------------------------------------------------------------------------
# Grade Input
# ---------------------------------------------------------------------------


class TestGradeInput:
    """Test grade_input dict is properly constructed."""

    def _mock_response(self, html: str = "<html></html>", headers: dict | None = None):
        resp = MagicMock()
        resp.text = html
        resp.url = "https://example.com"
        resp.headers = httpx.Headers(headers or {})
        resp.cookies = httpx.Cookies()
        return resp

    @pytest.mark.asyncio
    async def test_grade_input_keys_present(self, detect_fn):
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = self._mock_response()
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await detect_fn("https://example.com")
            assert "grade_input" in result
            grade = result["grade_input"]
            assert "server_version_hidden" in grade
            assert "framework_version_hidden" in grade
            assert "security_txt_present" in grade
            assert "cookies_secure" in grade
            assert "cookies_httponly" in grade

    @pytest.mark.asyncio
    async def test_server_version_exposed(self, detect_fn):
        with patch("httpx.AsyncClient") as MockClient:
            mock_client = AsyncMock()
            mock_client.get.return_value = self._mock_response(headers={"server": "Apache/2.4.41"})
            mock_client.__aenter__.return_value = mock_client
            mock_client.__aexit__.return_value = None
            MockClient.return_value = mock_client

            result = await detect_fn("https://example.com")
            assert result["grade_input"]["server_version_hidden"] is False
```
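These two tests pin down the five keys of `grade_input` and the rule that an exposed `Server` version flips `server_version_hidden` to `False`. One way the dict could plausibly be assembled (names and input shapes here are assumptions grounded only in these assertions):

```python
def _build_grade_input(server, framework, has_security_txt, cookie_flags) -> dict:
    """Hypothetical assembly of the security-posture booleans the tests assert on.

    `server` is the dict returned by _detect_server (or None); `cookie_flags` is
    assumed to be a list of (secure, httponly) tuples, one per cookie.
    """
    return {
        # False when the Server header leaks a version, e.g. "Apache/2.4.41".
        "server_version_hidden": not (server and server.get("version")),
        "framework_version_hidden": not (framework and "/" in framework),
        "security_txt_present": has_security_txt,
        "cookies_secure": all(secure for secure, _ in cookie_flags),
        "cookies_httponly": all(httponly for _, httponly in cookie_flags),
    }
```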