Compare commits
70 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a2e102fe15 | |||
| 119280da1a | |||
| 4d49f74d5a | |||
| 6a42b9c66b | |||
| fc4a39480a | |||
| b98afb01c8 | |||
| ccd6bb7656 | |||
| ea30e5c631 | |||
| d16a3c3b22 | |||
| a03bd78c2e | |||
| 3cca41aab1 | |||
| d19aaed946 | |||
| 9a7db8cf94 | |||
| f50630c551 | |||
| 0ef2e64733 | |||
| 3a8e121d43 | |||
| 51d341b88c | |||
| 7dd70b8e31 | |||
| 84b332d989 | |||
| bcc6848275 | |||
| 75dd053a40 | |||
| 20f2aa09f2 | |||
| fb8c810b3d | |||
| ad21cf4243 | |||
| 1e45cfff67 | |||
| 0280600a47 | |||
| 571ad518dc | |||
| fe37a25cf1 | |||
| e06138628c | |||
| 1ed0edd158 | |||
| 49dbc46082 | |||
| a16a4adc09 | |||
| b4ab1cbd56 | |||
| 6faa63f0d0 | |||
| 2b44af427f | |||
| 11f7401bc2 | |||
| db7b5180dd | |||
| 5b4e56252c | |||
| e3c71f77de | |||
| b09824faec | |||
| c69bc24598 | |||
| 0cf17e1c63 | |||
| feac803491 | |||
| 4aacec30d8 | |||
| b459a2f7a9 | |||
| ca8ede65f0 | |||
| 9a177c46e1 | |||
| 20bea9cd7f | |||
| a7709d489c | |||
| 18dfc997b8 | |||
| 92d0b6addf | |||
| f305745295 | |||
| fc22586752 | |||
| 646440eba3 | |||
| 53e5579326 | |||
| 29a1630d0f | |||
| 171f4ab2ae | |||
| 363a650dfa | |||
| b6e2634537 | |||
| 9f424f2fc0 | |||
| 0715fc5498 | |||
| f9fddd6663 | |||
| 63d017fc21 | |||
| bfb660275e | |||
| d6ae48bc58 | |||
| 94197cbcb9 | |||
| a96cd546c8 | |||
| eb33d4f1c2 | |||
| 4253956326 | |||
| d6b05bf337 |
@@ -1,44 +0,0 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(npm install:*)",
|
||||
"Bash(npm test:*)",
|
||||
"Skill(building-agents-construction)",
|
||||
"Skill(building-agents-construction:*)",
|
||||
"Bash(PYTHONPATH=core:exports pytest:*)",
|
||||
"mcp__agent-builder__create_session",
|
||||
"mcp__agent-builder__get_session_status",
|
||||
"mcp__agent-builder__set_goal",
|
||||
"mcp__agent-builder__list_mcp_servers",
|
||||
"mcp__agent-builder__test_node",
|
||||
"mcp__agent-builder__add_node",
|
||||
"mcp__agent-builder__add_edge",
|
||||
"mcp__agent-builder__validate_graph",
|
||||
"Bash(ruff check:*)",
|
||||
"Bash(PYTHONPATH=core:exports python:*)",
|
||||
"mcp__agent-builder__list_tests",
|
||||
"mcp__agent-builder__generate_constraint_tests",
|
||||
"Bash(python -m agent:*)",
|
||||
"Bash(python agent.py:*)",
|
||||
"Bash(python -c:*)",
|
||||
"Bash(done)",
|
||||
"Bash(xargs cat:*)",
|
||||
"mcp__agent-builder__list_mcp_tools",
|
||||
"mcp__agent-builder__add_mcp_server",
|
||||
"mcp__agent-builder__check_missing_credentials",
|
||||
"mcp__agent-builder__store_credential",
|
||||
"mcp__agent-builder__list_stored_credentials",
|
||||
"mcp__agent-builder__delete_stored_credential",
|
||||
"mcp__agent-builder__verify_credentials",
|
||||
"Bash(PYTHONPATH=/home/timothy/oss/hive/core:/home/timothy/oss/hive/exports python:*)",
|
||||
"Bash(PYTHONPATH=core:exports:tools/src python -m hubspot_input:*)",
|
||||
"mcp__agent-builder__export_graph",
|
||||
"Bash(python3:*)"
|
||||
]
|
||||
},
|
||||
"enableAllProjectMcpServers": true,
|
||||
"enabledMcpjsonServers": [
|
||||
"agent-builder",
|
||||
"tools"
|
||||
]
|
||||
}
|
||||
+2
-12
@@ -218,19 +218,9 @@ class OnlineResearchAgent:
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
# Load MCP servers (always load, needed for tool validation)
|
||||
agent_dir = Path(__file__).parent
|
||||
mcp_config_path = agent_dir / "mcp_servers.json"
|
||||
|
||||
mcp_config_path = Path(__file__).parent / "mcp_servers.json"
|
||||
if mcp_config_path.exists():
|
||||
with open(mcp_config_path) as f:
|
||||
mcp_servers = json.load(f)
|
||||
|
||||
for server_config in mcp_servers.get("servers", []):
|
||||
# Resolve relative cwd paths
|
||||
cwd = server_config.get("cwd")
|
||||
if cwd and not Path(cwd).is_absolute():
|
||||
server_config["cwd"] = str(agent_dir / cwd)
|
||||
tool_registry.register_mcp_server(server_config)
|
||||
tool_registry.load_mcp_config(mcp_config_path)
|
||||
|
||||
llm = None
|
||||
if not mock_mode:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
---
|
||||
name: setup-credentials
|
||||
description: Set up and install credentials for an agent. Detects missing credentials from agent config, collects them from the user, and stores them securely in the encrypted credential store at ~/.hive/credentials.
|
||||
description: Set up and install credentials for an agent. Detects missing credentials from agent config, collects them from the user, and stores them securely in the local encrypted store at ~/.hive/credentials.
|
||||
license: Apache-2.0
|
||||
metadata:
|
||||
author: hive
|
||||
version: "2.1"
|
||||
version: "2.2"
|
||||
type: utility
|
||||
---
|
||||
|
||||
@@ -31,48 +31,96 @@ Determine which agent needs credentials. The user will either:
|
||||
|
||||
Locate the agent's directory under `exports/{agent_name}/`.
|
||||
|
||||
### Step 2: Detect Required Credentials
|
||||
### Step 2: Detect Required Credentials (Bash-First)
|
||||
|
||||
Read the agent's configuration to determine which tools and node types it uses:
|
||||
Use bash commands to determine what the agent needs and what's already configured. This avoids Python import issues and works even when `HIVE_CREDENTIAL_KEY` is not set.
|
||||
|
||||
```python
|
||||
from core.framework.runner import AgentRunner
|
||||
#### Step 2a: Read Agent Requirements
|
||||
|
||||
runner = AgentRunner.load("exports/{agent_name}")
|
||||
validation = runner.validate()
|
||||
Extract `required_tools` and node types from the agent config:
|
||||
|
||||
# validation.missing_credentials contains env var names
|
||||
# validation.warnings contains detailed messages with help URLs
|
||||
```bash
|
||||
# Get required tools
|
||||
jq -r '.required_tools[]?' exports/{agent_name}/agent.json 2>/dev/null
|
||||
|
||||
# Get node types from graph nodes
|
||||
jq -r '.graph.nodes[]?.node_type' exports/{agent_name}/agent.json 2>/dev/null | sort -u
|
||||
```
|
||||
|
||||
Alternatively, check the credential store directly:
|
||||
Map the extracted tools and node types to credentials by reading the spec files directly:
|
||||
|
||||
```python
|
||||
from core.framework.credentials import CredentialStore
|
||||
|
||||
# Use encrypted storage (default: ~/.hive/credentials)
|
||||
store = CredentialStore.with_encrypted_storage()
|
||||
|
||||
# Check what's available
|
||||
available = store.list_credentials()
|
||||
print(f"Available credentials: {available}")
|
||||
|
||||
# Check if specific credential exists
|
||||
if store.is_available("hubspot"):
|
||||
print("HubSpot credential found")
|
||||
else:
|
||||
print("HubSpot credential missing")
|
||||
```bash
|
||||
# Read all credential specs — each file defines tools, node_types, env_var, and credential_id
|
||||
cat tools/src/aden_tools/credentials/llm.py tools/src/aden_tools/credentials/search.py tools/src/aden_tools/credentials/email.py tools/src/aden_tools/credentials/integrations.py
|
||||
```
|
||||
|
||||
To see all known credential specs (for help URLs and setup instructions):
|
||||
For each `CredentialSpec`, match its `tools` and `node_types` lists against the agent's required tools and node types. Extract the `env_var`, `credential_id`, and `credential_group` for every match. This is the list of needed credentials.
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
#### Step 2b: Check Existing Credential Sources
|
||||
|
||||
for name, spec in CREDENTIAL_SPECS.items():
|
||||
print(f"{name}: env_var={spec.env_var}, aden={spec.aden_supported}")
|
||||
For each needed credential, check three sources. A credential is "found" if it exists in ANY of them:
|
||||
|
||||
**1. Encrypted store metadata index** (unencrypted JSON — no decryption key needed):
|
||||
|
||||
```bash
|
||||
cat ~/.hive/credentials/metadata/index.json 2>/dev/null | jq -r '.credentials | keys[]'
|
||||
```
|
||||
|
||||
If a credential ID appears in this list, it is stored in the encrypted store.
|
||||
|
||||
**2. Environment variables:**
|
||||
|
||||
```bash
|
||||
# Check each needed env var, e.g.:
|
||||
printenv ANTHROPIC_API_KEY > /dev/null 2>&1 && echo "ANTHROPIC_API_KEY: set" || echo "ANTHROPIC_API_KEY: not set"
|
||||
printenv BRAVE_SEARCH_API_KEY > /dev/null 2>&1 && echo "BRAVE_SEARCH_API_KEY: set" || echo "BRAVE_SEARCH_API_KEY: not set"
|
||||
```
|
||||
|
||||
**3. Project `.env` file:**
|
||||
|
||||
```bash
|
||||
# Check each needed env var, e.g.:
|
||||
grep -q '^ANTHROPIC_API_KEY=' .env 2>/dev/null && echo "ANTHROPIC_API_KEY: in .env" || echo "ANTHROPIC_API_KEY: not in .env"
|
||||
grep -q '^BRAVE_SEARCH_API_KEY=' .env 2>/dev/null && echo "BRAVE_SEARCH_API_KEY: in .env" || echo "BRAVE_SEARCH_API_KEY: not in .env"
|
||||
```
|
||||
|
||||
#### Step 2c: HIVE_CREDENTIAL_KEY Check
|
||||
|
||||
If any credentials were found in the encrypted store metadata index, verify the encryption key is available. The key is typically persisted to shell config by a previous setup-credentials run.
|
||||
|
||||
Check both the current session AND shell config files:
|
||||
|
||||
```bash
|
||||
# Check 1: Current session
|
||||
printenv HIVE_CREDENTIAL_KEY > /dev/null 2>&1 && echo "session: set" || echo "session: not set"
|
||||
|
||||
# Check 2: Shell config files (where setup-credentials persists it)
|
||||
# Note: check each file individually so a missing file doesn't cause a grep error; no output means the key is not persisted in any shell config
|
||||
for f in ~/.zshrc ~/.bashrc ~/.profile; do [ -f "$f" ] && grep -q 'HIVE_CREDENTIAL_KEY' "$f" && echo "$f"; done
|
||||
```
|
||||
|
||||
Decision logic:
|
||||
- **In current session** — no action needed, credentials in the store are usable
|
||||
- **In shell config but NOT in current session** — the key is persisted but this shell hasn't sourced it. Run `source ~/.zshrc` (or `~/.bashrc`), then re-check. Credentials in the store are usable after sourcing.
|
||||
- **Not in session AND not in shell config** — the key was never persisted. Warn the user that credentials in the store cannot be decrypted. Help the user fix the key situation (recover or re-persist the key); do NOT re-collect credential values that are already stored.
|
||||
|
||||
#### Step 2d: Compute Missing & Group
|
||||
|
||||
Diff the "needed" credentials against the "found" credentials to get the truly missing list.
|
||||
|
||||
Group related credentials by their `credential_group` field from the spec files. Credentials that share the same non-empty `credential_group` value should be presented as a single setup step rather than asking for each one individually.
|
||||
|
||||
**If nothing is missing and there's no HIVE_CREDENTIAL_KEY issue:** Report all credentials as configured and skip Steps 3-5. Example:
|
||||
|
||||
```
|
||||
All required credentials are already configured:
|
||||
✓ anthropic (ANTHROPIC_API_KEY) — found in encrypted store
|
||||
✓ brave_search (BRAVE_SEARCH_API_KEY) — found in environment
|
||||
Your agent is ready to run!
|
||||
```
|
||||
|
||||
**If credentials are missing:** Continue to Step 3 with only the missing ones.
|
||||
|
||||
### Step 3: Present Auth Options for Each Missing Credential
|
||||
|
||||
For each missing credential, check what authentication methods are available:
|
||||
@@ -104,7 +152,7 @@ Present the available options using AskUserQuestion:
|
||||
```
|
||||
Choose how to configure HUBSPOT_ACCESS_TOKEN:
|
||||
|
||||
1) Aden Authorization Server (Recommended)
|
||||
1) Aden Platform (OAuth) (Recommended)
|
||||
Secure OAuth2 flow via integration.adenhq.com
|
||||
- Quick setup with automatic token refresh
|
||||
- No need to manage API keys manually
|
||||
@@ -114,7 +162,7 @@ Choose how to configure HUBSPOT_ACCESS_TOKEN:
|
||||
- Requires creating a HubSpot Private App
|
||||
- Full control over scopes and permissions
|
||||
|
||||
3) Custom Credential Store (Advanced)
|
||||
3) Local Credential Setup (Advanced)
|
||||
Programmatic configuration for CI/CD
|
||||
- For automated deployments
|
||||
- Requires manual API calls
|
||||
@@ -122,7 +170,7 @@ Choose how to configure HUBSPOT_ACCESS_TOKEN:
|
||||
|
||||
### Step 4: Execute Auth Flow Based on User Choice
|
||||
|
||||
#### Option 1: Aden Authorization Server
|
||||
#### Option 1: Aden Platform (OAuth)
|
||||
|
||||
This is the recommended flow for supported integrations (HubSpot, etc.).
|
||||
|
||||
@@ -174,7 +222,7 @@ shell_type = detect_shell() # 'bash', 'zsh', or 'unknown'
|
||||
success, config_path = add_env_var_to_shell_config(
|
||||
"ADEN_API_KEY",
|
||||
user_provided_key,
|
||||
comment="Aden authorization server API key"
|
||||
comment="Aden Platform (OAuth) API key"
|
||||
)
|
||||
|
||||
if success:
|
||||
@@ -313,7 +361,7 @@ if not result.valid:
|
||||
# 2. Continue anyway (not recommended)
|
||||
```
|
||||
|
||||
**4.2d. Store in Encrypted Credential Store**
|
||||
**4.2d. Store in Local Encrypted Store**
|
||||
|
||||
```python
|
||||
from core.framework.credentials import CredentialStore, CredentialObject, CredentialKey
|
||||
@@ -340,7 +388,7 @@ store.save_credential(cred)
|
||||
export HUBSPOT_ACCESS_TOKEN="the-value"
|
||||
```
|
||||
|
||||
#### Option 3: Custom Credential Store (Advanced)
|
||||
#### Option 3: Local Credential Setup (Advanced)
|
||||
|
||||
For programmatic/CI/CD setups.
|
||||
|
||||
@@ -408,10 +456,14 @@ Report the result to the user.
|
||||
|
||||
Health checks validate credentials by making lightweight API calls:
|
||||
|
||||
| Credential | Endpoint | What It Checks |
|
||||
| -------------- | --------------------------------------- | --------------------------------- |
|
||||
| `hubspot` | `GET /crm/v3/objects/contacts?limit=1` | Bearer token validity, CRM scopes |
|
||||
| `brave_search` | `GET /res/v1/web/search?q=test&count=1` | API key validity |
|
||||
| Credential | Endpoint | What It Checks |
|
||||
| --------------- | --------------------------------------- | ---------------------------------- |
|
||||
| `anthropic` | `POST /v1/messages` | API key validity |
|
||||
| `brave_search` | `GET /res/v1/web/search?q=test&count=1` | API key validity |
|
||||
| `google_search` | `GET /customsearch/v1?q=test&num=1` | API key + CSE ID validity |
|
||||
| `github` | `GET /user` | Token validity, user identity |
|
||||
| `hubspot` | `GET /crm/v3/objects/contacts?limit=1` | Bearer token validity, CRM scopes |
|
||||
| `resend` | `GET /domains` | API key validity |
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import check_credential_health, HealthCheckResult
|
||||
@@ -424,7 +476,7 @@ result: HealthCheckResult = check_credential_health("hubspot", token_value)
|
||||
|
||||
## Encryption Key (HIVE_CREDENTIAL_KEY)
|
||||
|
||||
The encrypted credential store requires `HIVE_CREDENTIAL_KEY` to encrypt/decrypt credentials.
|
||||
The local encrypted store requires `HIVE_CREDENTIAL_KEY` to encrypt/decrypt credentials.
|
||||
|
||||
- If the user doesn't have one, `EncryptedFileStorage` will auto-generate one and log it
|
||||
- The user MUST persist this key (e.g., in `~/.bashrc` or a secrets manager)
|
||||
@@ -443,7 +495,7 @@ If `HIVE_CREDENTIAL_KEY` is not set:
|
||||
- **NEVER** store credentials in plaintext files, git-tracked files, or agent configs
|
||||
- **NEVER** hardcode credentials in source code
|
||||
- **ALWAYS** use `SecretStr` from Pydantic when handling credential values in Python
|
||||
- **ALWAYS** use the encrypted credential store (`~/.hive/credentials`) for persistence
|
||||
- **ALWAYS** use the local encrypted store (`~/.hive/credentials`) for persistence
|
||||
- **ALWAYS** run health checks before storing credentials (when possible)
|
||||
- **ALWAYS** verify credentials were stored by re-running validation, not by reading them back
|
||||
- When modifying `~/.bashrc` or `~/.zshrc`, confirm with the user first
|
||||
@@ -456,7 +508,8 @@ All credential specs are defined in `tools/src/aden_tools/credentials/`:
|
||||
| ----------------- | ------------- | --------------------------------------------- | -------------- |
|
||||
| `llm.py` | LLM Providers | `anthropic` | No |
|
||||
| `search.py` | Search Tools | `brave_search`, `google_search`, `google_cse` | No |
|
||||
| `integrations.py` | Integrations | `hubspot` | Yes |
|
||||
| `email.py` | Email | `resend` | No |
|
||||
| `integrations.py` | Integrations | `github`, `hubspot` | No / Yes |
|
||||
|
||||
**Note:** Additional LLM providers (Cerebras, Groq, OpenAI) are handled by LiteLLM via environment
|
||||
variables (`CEREBRAS_API_KEY`, `GROQ_API_KEY`, `OPENAI_API_KEY`) but are not yet in CREDENTIAL_SPECS.
|
||||
@@ -507,66 +560,90 @@ token = store.get_key("hubspot", "access_token")
|
||||
## Example Session
|
||||
|
||||
```
|
||||
User: /setup-credentials for my hubspot-agent
|
||||
User: /setup-credentials for my research-agent
|
||||
|
||||
Agent: Let me check what credentials your hubspot-agent needs.
|
||||
Agent: Let me check what credentials your research-agent needs.
|
||||
|
||||
[Runs validation, finds ANTHROPIC_API_KEY and HUBSPOT_ACCESS_TOKEN missing]
|
||||
[Reads agent config]
|
||||
$ jq -r '.required_tools[]?' exports/research-agent/agent.json
|
||||
web_search
|
||||
google_search
|
||||
|
||||
Agent: Your hubspot-agent requires 2 credentials:
|
||||
$ jq -r '.graph.nodes[]?.node_type' exports/research-agent/agent.json | sort -u
|
||||
llm_tool_use
|
||||
|
||||
1. ANTHROPIC_API_KEY - Anthropic API key for LLM calls
|
||||
2. HUBSPOT_ACCESS_TOKEN - HubSpot access token
|
||||
[Maps tools/nodes to credentials using the credential spec files]
|
||||
Needed: anthropic, brave_search, google_search, google_cse
|
||||
|
||||
Let's set these up one at a time.
|
||||
[Checks encrypted store metadata index]
|
||||
$ cat ~/.hive/credentials/metadata/index.json | jq -r '.credentials | keys[]'
|
||||
anthropic
|
||||
brave_search
|
||||
|
||||
--- Setting up ANTHROPIC_API_KEY ---
|
||||
[Checks environment variables]
|
||||
$ printenv ANTHROPIC_API_KEY > /dev/null 2>&1 && echo "set" || echo "not set"
|
||||
not set
|
||||
$ printenv BRAVE_SEARCH_API_KEY > /dev/null 2>&1 && echo "set" || echo "not set"
|
||||
not set
|
||||
$ printenv GOOGLE_API_KEY > /dev/null 2>&1 && echo "set" || echo "not set"
|
||||
not set
|
||||
$ printenv GOOGLE_CSE_ID > /dev/null 2>&1 && echo "set" || echo "not set"
|
||||
not set
|
||||
|
||||
This credential only supports direct API key entry.
|
||||
[Checks HIVE_CREDENTIAL_KEY since credentials found in store]
|
||||
$ printenv HIVE_CREDENTIAL_KEY > /dev/null 2>&1 && echo "session: set" || echo "session: not set"
|
||||
session: not set
|
||||
$ for f in ~/.zshrc ~/.bashrc ~/.profile; do [ -f "$f" ] && grep -q 'HIVE_CREDENTIAL_KEY' "$f" && echo "$f"; done
|
||||
/Users/user/.zshrc
|
||||
|
||||
[Shows instructions for getting Anthropic API key]
|
||||
[AskUserQuestion: "Please provide your Anthropic API key:"]
|
||||
[User provides key]
|
||||
[Stores in credential store, exports to env]
|
||||
[Key is in shell config but not current session — sourcing it]
|
||||
$ source ~/.zshrc
|
||||
|
||||
--- Setting up HUBSPOT_ACCESS_TOKEN ---
|
||||
[Computes missing credentials]
|
||||
Found:
|
||||
✓ anthropic (ANTHROPIC_API_KEY) — in encrypted store
|
||||
✓ brave_search (BRAVE_SEARCH_API_KEY) — in encrypted store
|
||||
Missing:
|
||||
✗ google_search (GOOGLE_API_KEY)
|
||||
✗ google_cse (GOOGLE_CSE_ID)
|
||||
|
||||
[AskUserQuestion with 3 options]
|
||||
Agent: 2 of 4 required credentials are already configured. Only Google Custom
|
||||
Search needs setup (2 values as a single group).
|
||||
|
||||
Choose how to configure HUBSPOT_ACCESS_TOKEN:
|
||||
1) Aden Authorization Server (Recommended)
|
||||
2) Direct API Key
|
||||
3) Custom Credential Store (Advanced)
|
||||
--- Setting up Google Custom Search (google_search + google_cse) ---
|
||||
|
||||
User: 1 (Aden)
|
||||
This requires two values that work together.
|
||||
|
||||
Agent: Great! Let me check if you have an Aden API key configured...
|
||||
First, the Google API Key:
|
||||
1. Go to https://console.cloud.google.com/apis/credentials
|
||||
2. Create a new project (or select an existing one)
|
||||
3. Enable the "Custom Search API" from the API Library
|
||||
4. Go to Credentials > Create Credentials > API Key
|
||||
5. Copy the generated API key
|
||||
|
||||
[Checks for ADEN_API_KEY, not found]
|
||||
[Opens browser to integration.adenhq.com/setup]
|
||||
|
||||
Agent: I've opened Aden's setup page. Please:
|
||||
1. Sign in or create an account
|
||||
2. Connect your HubSpot account (OAuth happens on Aden's site)
|
||||
3. Copy the API key shown after setup
|
||||
|
||||
[AskUserQuestion: "Please provide your Aden API key:"]
|
||||
[AskUserQuestion: "Please provide your Google API key:"]
|
||||
[User provides key]
|
||||
|
||||
Agent: Would you like me to save this to your shell config (~/.zshrc) for future sessions?
|
||||
Now, the Custom Search Engine ID:
|
||||
1. Go to https://programmablesearchengine.google.com/controlpanel/all
|
||||
2. Click "Add" to create a new search engine
|
||||
3. Under "What to search", select "Search the entire web"
|
||||
4. Give your search engine a name
|
||||
5. Click "Create"
|
||||
6. Copy the Search Engine ID (cx value)
|
||||
|
||||
[User approves]
|
||||
[Saves to ~/.zshrc and ~/.hive/configuration.json]
|
||||
[AskUserQuestion: "Please provide your Google CSE ID:"]
|
||||
[User provides ID]
|
||||
|
||||
Agent: Let me sync your credentials from Aden...
|
||||
[Runs health check with both values - GET /customsearch/v1?q=test&num=1 → 200 OK]
|
||||
[Stores both in local encrypted store, exports to env]
|
||||
|
||||
[Syncs credentials from Aden server - OAuth already done on Aden's side]
|
||||
[Runs health check]
|
||||
|
||||
Agent: HubSpot credentials validated successfully!
|
||||
✓ Google Custom Search credentials valid
|
||||
|
||||
All credentials are now configured:
|
||||
- ANTHROPIC_API_KEY: Stored in encrypted credential store
|
||||
- HUBSPOT_ACCESS_TOKEN: Synced from Aden (OAuth completed on Aden)
|
||||
- Validation passed - your agent is ready to run!
|
||||
✓ anthropic (ANTHROPIC_API_KEY) — already in encrypted store
|
||||
✓ brave_search (BRAVE_SEARCH_API_KEY) — already in encrypted store
|
||||
✓ google_search (GOOGLE_API_KEY) — stored in encrypted store
|
||||
✓ google_cse (GOOGLE_CSE_ID) — stored in encrypted store
|
||||
Your agent is ready to run!
|
||||
```
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
---
|
||||
name: Bug Report
|
||||
about: Report a bug to help us improve
|
||||
title: '[Bug]: '
|
||||
labels: bug
|
||||
title: "[Bug]: "
|
||||
labels: bug, enhancement
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Describe the Bug
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
---
|
||||
name: Feature Request
|
||||
about: Suggest a new feature or enhancement
|
||||
title: '[Feature]: '
|
||||
title: "[Feature]: "
|
||||
labels: enhancement
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Problem Statement
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
---
|
||||
name: Integration Request
|
||||
about: Suggest a new integration
|
||||
title: "[Integration]: "
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Service
|
||||
|
||||
Name and brief description of the service and what it enables agents to do.
|
||||
|
||||
**Description:** [e.g., "API key for Slack Bot" — short one-liner for the credential spec]
|
||||
|
||||
## Credential Identity
|
||||
|
||||
- **credential_id:** [e.g., `slack`]
|
||||
- **env_var:** [e.g., `SLACK_BOT_TOKEN`]
|
||||
- **credential_key:** [e.g., `access_token`, `api_key`, `bot_token`]
|
||||
|
||||
## Tools
|
||||
|
||||
Tool function names that require this credential:
|
||||
|
||||
- [e.g., `slack_send_message`]
|
||||
- [e.g., `slack_list_channels`]
|
||||
|
||||
## Auth Methods
|
||||
|
||||
- **Direct API key supported:** Yes / No
|
||||
- **Aden OAuth supported:** Yes / No
|
||||
|
||||
If Aden OAuth is supported, describe the OAuth scopes/permissions required.
|
||||
|
||||
## How to Get the Credential
|
||||
|
||||
Link where users obtain the key/token:
|
||||
|
||||
[e.g., https://api.slack.com/apps]
|
||||
|
||||
Step-by-step instructions:
|
||||
|
||||
1. Go to ...
|
||||
2. Create a ...
|
||||
3. Select scopes/permissions: ...
|
||||
4. Copy the key/token
|
||||
|
||||
## Health Check
|
||||
|
||||
A lightweight API call to validate the credential (no writes, no charges).
|
||||
|
||||
- **Endpoint:** [e.g., `https://slack.com/api/auth.test`]
|
||||
- **Method:** [e.g., `GET` or `POST`]
|
||||
- **Auth header:** [e.g., `Authorization: Bearer {token}` or `X-Api-Key: {key}`]
|
||||
- **Parameters (if any):** [e.g., `?limit=1`]
|
||||
- **200 means:** [e.g., key is valid]
|
||||
- **401 means:** [e.g., invalid or expired]
|
||||
- **429 means:** [e.g., rate limited but key is valid]
|
||||
|
||||
## Credential Group
|
||||
|
||||
Does this require multiple credentials configured together? (e.g., Google Custom Search needs
|
||||
both an API key and a CSE ID)
|
||||
|
||||
- [ ] No, single credential
|
||||
- [ ] Yes — list the other credential IDs in the group:
|
||||
|
||||
## Additional Context
|
||||
|
||||
Links to API docs, rate limits, free tier availability, or anything else relevant.
|
||||
+40
-19
@@ -21,23 +21,22 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd core
|
||||
pip install -e .
|
||||
pip install -r requirements-dev.txt
|
||||
run: uv sync --project core --group dev
|
||||
|
||||
- name: Ruff lint
|
||||
run: |
|
||||
ruff check core/
|
||||
ruff check tools/
|
||||
uv run --project core ruff check core/
|
||||
uv run --project core ruff check tools/
|
||||
|
||||
- name: Ruff format
|
||||
run: |
|
||||
ruff format --check core/
|
||||
ruff format --check tools/
|
||||
uv run --project core ruff format --check core/
|
||||
uv run --project core ruff format --check tools/
|
||||
|
||||
test:
|
||||
name: Test Python Framework
|
||||
@@ -52,23 +51,23 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd core
|
||||
pip install -e .
|
||||
pip install -r requirements-dev.txt
|
||||
uv sync
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
cd core
|
||||
pytest tests/ -v
|
||||
uv run pytest tests/ -v
|
||||
|
||||
validate:
|
||||
name: Validate Agent Exports
|
||||
test-tools:
|
||||
name: Test Tools
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint, test]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
@@ -76,13 +75,35 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies and run tests
|
||||
run: |
|
||||
cd tools
|
||||
uv sync --extra dev
|
||||
uv run pytest tests/ -v
|
||||
|
||||
validate:
|
||||
name: Validate Agent Exports
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint, test, test-tools]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd core
|
||||
pip install -e .
|
||||
pip install -r requirements-dev.txt
|
||||
uv sync
|
||||
|
||||
- name: Validate exported agents
|
||||
run: |
|
||||
|
||||
@@ -80,7 +80,13 @@ jobs:
|
||||
- help wanted: Extra attention is needed (if issue needs community input)
|
||||
- backlog: Tracked for the future, but not currently planned or prioritized
|
||||
|
||||
You may apply multiple labels if appropriate (e.g., "bug" and "help wanted").
|
||||
### 6. Estimate size (if NOT a duplicate, spam, or invalid)
|
||||
Apply exactly ONE size label to help contributors match their capacity to the task:
|
||||
- "size: small": Docs, typos, single-file fixes, config changes
|
||||
- "size: medium": Bug fixes with tests, adding a single tool, changes within one package
|
||||
- "size: large": Cross-package changes (core + tools), new modules, complex logic, architectural refactors
|
||||
|
||||
You may apply multiple labels if appropriate (e.g., "bug", "size: small", and "good first issue").
|
||||
|
||||
## Tools Available:
|
||||
- mcp__github__get_issue: Get issue details
|
||||
|
||||
@@ -21,18 +21,19 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd core
|
||||
pip install -e .
|
||||
pip install -r requirements-dev.txt
|
||||
uv sync
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
cd core
|
||||
pytest tests/ -v
|
||||
uv run pytest tests/ -v
|
||||
|
||||
- name: Generate changelog
|
||||
id: changelog
|
||||
|
||||
@@ -69,6 +69,8 @@ exports/*
|
||||
|
||||
.agent-builder-sessions/*
|
||||
|
||||
.claude/settings.local.json
|
||||
|
||||
.venv
|
||||
|
||||
docs/github-issues/*
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"agent-builder": {
|
||||
"command": ".venv/bin/python",
|
||||
"args": ["-m", "framework.mcp.agent_builder_server"],
|
||||
"cwd": "core",
|
||||
"env": {
|
||||
"PYTHONPATH": "../tools/src"
|
||||
}
|
||||
"command": "uv",
|
||||
"args": ["run", "-m", "framework.mcp.agent_builder_server"],
|
||||
"cwd": "core"
|
||||
},
|
||||
"tools": {
|
||||
"command": ".venv/bin/python",
|
||||
"args": ["mcp_server.py", "--stdio"],
|
||||
"cwd": "tools",
|
||||
"env": {
|
||||
"PYTHONPATH": "src:../core"
|
||||
}
|
||||
"command": "uv",
|
||||
"args": ["run", "mcp_server.py", "--stdio"],
|
||||
"cwd": "tools"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+20
-22
@@ -44,7 +44,7 @@ Aden Agent Framework is a Python-based system for building goal-driven, self-imp
|
||||
Ensure you have installed:
|
||||
|
||||
- **Python 3.11+** - [Download](https://www.python.org/downloads/) (3.12 or 3.13 recommended)
|
||||
- **pip** - Package installer for Python (comes with Python)
|
||||
- **uv** - Python package manager ([Install](https://docs.astral.sh/uv/getting-started/installation/))
|
||||
- **git** - Version control
|
||||
- **Claude Code** - [Install](https://docs.anthropic.com/claude/docs/claude-code) (optional, for using building skills)
|
||||
|
||||
@@ -52,7 +52,7 @@ Verify installation:
|
||||
|
||||
```bash
|
||||
python --version # Should be 3.11+
|
||||
pip --version # Should be latest
|
||||
uv --version # Should be latest
|
||||
git --version # Any recent version
|
||||
```
|
||||
|
||||
@@ -128,8 +128,12 @@ hive/ # Repository root
|
||||
│
|
||||
├── .github/ # GitHub configuration
|
||||
│ ├── workflows/
|
||||
│ │ ├── ci.yml # Runs on every PR
|
||||
│ │ └── release.yml # Runs on tags
|
||||
│ │ ├── ci.yml # Lint, test, validate on every PR
|
||||
│ │ ├── release.yml # Runs on tags
|
||||
│ │ ├── pr-requirements.yml # PR requirement checks
|
||||
│ │ ├── pr-check-command.yml # PR check commands
|
||||
│ │ ├── claude-issue-triage.yml # Automated issue triage
|
||||
│ │ └── auto-close-duplicates.yml # Close duplicate issues
|
||||
│ ├── ISSUE_TEMPLATE/ # Bug report, feature request & integration request templates
|
||||
│ ├── PULL_REQUEST_TEMPLATE.md # PR description template
|
||||
│ └── CODEOWNERS # Auto-assign reviewers
|
||||
@@ -166,7 +170,6 @@ hive/ # Repository root
|
||||
│ │ ├── testing/ # Testing utilities
|
||||
│ │ └── __init__.py
|
||||
│ ├── pyproject.toml # Package metadata and dependencies
|
||||
│ ├── requirements.txt # Python dependencies
|
||||
│ ├── README.md # Framework documentation
|
||||
│ ├── MCP_INTEGRATION_GUIDE.md # MCP server integration guide
|
||||
│ └── docs/ # Protocol documentation
|
||||
@@ -182,7 +185,6 @@ hive/ # Repository root
|
||||
│ │ ├── mcp_server.py # HTTP MCP server
|
||||
│ │ └── __init__.py
|
||||
│ ├── pyproject.toml # Package metadata
|
||||
│ ├── requirements.txt # Python dependencies
|
||||
│ └── README.md # Tools documentation
|
||||
│
|
||||
├── exports/ # AGENT PACKAGES (user-created, gitignored)
|
||||
@@ -191,14 +193,16 @@ hive/ # Repository root
|
||||
├── docs/ # Documentation
|
||||
│ ├── getting-started.md # Quick start guide
|
||||
│ ├── configuration.md # Configuration reference
|
||||
│ ├── architecture.md # System architecture
|
||||
│ └── articles/ # Technical articles
|
||||
│ ├── architecture/ # System architecture
|
||||
│ ├── articles/ # Technical articles
|
||||
│ ├── quizzes/ # Developer quizzes
|
||||
│ └── i18n/ # Translations
|
||||
│
|
||||
├── scripts/ # Build & utility scripts
|
||||
│ ├── setup-python.sh # Python environment setup
|
||||
│ └── setup.sh # Legacy setup script
|
||||
│
|
||||
├── quickstart.sh # Install Claude Code skills
|
||||
├── quickstart.sh # Interactive setup wizard
|
||||
├── ENVIRONMENT_SETUP.md # Complete Python setup guide
|
||||
├── README.md # Project overview
|
||||
├── DEVELOPER.md # This file
|
||||
@@ -375,7 +379,7 @@ def test_ticket_categorization():
|
||||
- **PEP 8** - Follow Python style guide
|
||||
- **Type hints** - Use for function signatures and class attributes
|
||||
- **Docstrings** - Document classes and public functions
|
||||
- **Black** - Code formatter (run with `black .`)
|
||||
- **Ruff** - Linter and formatter (run with `make check`)
|
||||
|
||||
```python
|
||||
# Good
|
||||
@@ -509,8 +513,8 @@ chore(deps): update React to 18.2.0
|
||||
|
||||
1. Create a feature branch from `main`
|
||||
2. Make your changes with clear commits
|
||||
3. Run tests locally: `PYTHONPATH=core:exports python -m pytest`
|
||||
4. Run linting: `black --check .`
|
||||
3. Run tests locally: `make test`
|
||||
4. Run linting: `make check`
|
||||
5. Push and create a PR
|
||||
6. Fill out the PR template
|
||||
7. Request review from CODEOWNERS
|
||||
@@ -528,16 +532,11 @@ chore(deps): update React to 18.2.0
|
||||
```bash
|
||||
# Add to core framework
|
||||
cd core
|
||||
pip install <package>
|
||||
# Then add to requirements.txt or pyproject.toml
|
||||
uv add <package>
|
||||
|
||||
# Add to tools package
|
||||
cd tools
|
||||
pip install <package>
|
||||
# Then add to requirements.txt or pyproject.toml
|
||||
|
||||
# Reinstall in editable mode
|
||||
pip install -e .
|
||||
uv add <package>
|
||||
```
|
||||
|
||||
### Creating a New Agent
|
||||
@@ -670,9 +669,8 @@ cat .env
|
||||
# Or check shell environment
|
||||
echo $ANTHROPIC_API_KEY
|
||||
|
||||
# Copy from .env.example if needed
|
||||
cp .env.example .env
|
||||
# Then edit .env with your API keys
|
||||
# Create .env if needed
|
||||
# Then add your API keys
|
||||
```
|
||||
|
||||
|
||||
|
||||
+86
-3
@@ -21,6 +21,43 @@ This will:
|
||||
- Fix package compatibility issues (openai + litellm)
|
||||
- Verify all installations
|
||||
|
||||
## Quick Setup (Windows – PowerShell)
|
||||
|
||||
Windows users can use the native PowerShell setup script.
|
||||
|
||||
Before running the script, allow script execution for the current session:
|
||||
|
||||
```powershell
|
||||
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
|
||||
```
|
||||
|
||||
Run setup from the project root:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
This will:
|
||||
|
||||
- Check Python version (requires 3.11+)
|
||||
- Create a local `.venv` virtual environment
|
||||
- Install the core framework package (`framework`)
|
||||
- Install the tools package (`aden_tools`)
|
||||
- Fix package compatibility issues (openai + litellm)
|
||||
- Verify all installations
|
||||
|
||||
After setup, activate the virtual environment:
|
||||
|
||||
```powershell
|
||||
.\.venv\Scripts\Activate.ps1
|
||||
```
|
||||
|
||||
Set `PYTHONPATH` (required in every new PowerShell session):
|
||||
|
||||
```powershell
|
||||
$env:PYTHONPATH="core;exports"
|
||||
```
|
||||
|
||||
## Alpine Linux Setup
|
||||
|
||||
If you are using Alpine Linux (e.g., inside a Docker container), you must install system dependencies and use a virtual environment before running the setup script:
|
||||
@@ -100,6 +137,12 @@ For running agents with real LLMs:
|
||||
export ANTHROPIC_API_KEY="your-key-here"
|
||||
```
|
||||
|
||||
Windows (PowerShell):
|
||||
|
||||
```powershell
|
||||
$env:ANTHROPIC_API_KEY="your-key-here"
|
||||
```
|
||||
|
||||
## Running Agents
|
||||
|
||||
All agent commands must be run from the project root with `PYTHONPATH` set:
|
||||
@@ -109,9 +152,14 @@ All agent commands must be run from the project root with `PYTHONPATH` set:
|
||||
PYTHONPATH=core:exports python -m agent_name COMMAND
|
||||
```
|
||||
|
||||
### Example Commands
|
||||
Windows (PowerShell):
|
||||
|
||||
After building an agent via `/building-agents-construction`, use these commands:
|
||||
```powershell
|
||||
$env:PYTHONPATH="core;exports"
|
||||
python -m agent_name COMMAND
|
||||
```
|
||||
|
||||
### Example: Support Ticket Agent
|
||||
|
||||
```bash
|
||||
# Validate agent structure
|
||||
@@ -248,6 +296,14 @@ source .venv/bin/activate
|
||||
PYTHONPATH=core:exports python -m your_agent_name demo
|
||||
```
|
||||
|
||||
### PowerShell: “running scripts is disabled on this system”
|
||||
|
||||
Run once per session:
|
||||
|
||||
```powershell
|
||||
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
|
||||
```
|
||||
|
||||
### "ModuleNotFoundError: No module named 'framework'"
|
||||
|
||||
**Solution:** Install the core package:
|
||||
@@ -270,6 +326,12 @@ Or run the setup script:
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
### "ModuleNotFoundError: No module named 'openai.\_models'"
|
||||
|
||||
**Cause:** Outdated `openai` package (0.27.x) incompatible with `litellm`
|
||||
@@ -284,12 +346,21 @@ pip install --upgrade "openai>=1.0.0"
|
||||
|
||||
**Cause:** Not running from project root, missing PYTHONPATH, or agent not yet created
|
||||
|
||||
**Solution:** Ensure you're in the project root directory, have built an agent, and use:
|
||||
**Solution:** Ensure you're in `/hive/` and use:
|
||||
|
||||
Linux/macOS:
|
||||
|
||||
```bash
|
||||
PYTHONPATH=core:exports python -m your_agent_name validate
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
$env:PYTHONPATH="core;exports"
|
||||
python -m your_agent_name validate
|
||||
```
|
||||
|
||||
### Agent imports fail with "broken installation"
|
||||
|
||||
**Symptom:** `pip list` shows packages pointing to non-existent directories
|
||||
@@ -304,6 +375,12 @@ pip uninstall -y framework tools
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
## Package Structure
|
||||
|
||||
The Hive framework consists of three Python packages:
|
||||
@@ -402,6 +479,12 @@ This design allows agents in `exports/` to be:
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
### 2. Build Agent (Claude Code)
|
||||
|
||||
```
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
|
||||
[](https://github.com/adenhq/hive/blob/main/LICENSE)
|
||||
[](https://www.ycombinator.com/companies/aden)
|
||||
[](https://hub.docker.com/u/adenhq)
|
||||
[](https://discord.com/invite/MXE49hrKDk)
|
||||
[](https://x.com/aden_hq)
|
||||
[](https://www.linkedin.com/company/teamaden/)
|
||||
@@ -40,6 +39,31 @@ Build reliable, self-improving AI agents without hardcoding workflows. Define yo
|
||||
|
||||
Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.
|
||||
|
||||
## Who Is Hive For?
|
||||
|
||||
Hive is designed for developers and teams who want to build **production-grade AI agents** without manually wiring complex workflows.
|
||||
|
||||
Hive is a good fit if you:
|
||||
|
||||
- Want AI agents that **execute real business processes**, not demos
|
||||
- Prefer **goal-driven development** over hardcoded workflows
|
||||
- Need **self-healing and adaptive agents** that improve over time
|
||||
- Require **human-in-the-loop control**, observability, and cost limits
|
||||
- Plan to run agents in **production environments**
|
||||
|
||||
Hive may not be the best fit if you’re only experimenting with simple agent chains or one-off scripts.
|
||||
|
||||
## When Should You Use Hive?
|
||||
|
||||
Use Hive when you need:
|
||||
|
||||
- Long-running, autonomous agents
|
||||
- Multi-agent coordination
|
||||
- Continuous improvement based on failures
|
||||
- Strong monitoring, safety, and budget controls
|
||||
- A framework that evolves with your goals
|
||||
|
||||
|
||||
## What is Aden
|
||||
|
||||
<p align="center">
|
||||
@@ -64,11 +88,13 @@ Aden is a platform for building, deploying, operating, and adapting AI agents:
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Prerequisites
|
||||
## Prerequisites
|
||||
|
||||
- [Python 3.11+](https://www.python.org/downloads/) for agent development
|
||||
- Python 3.11+ for agent development
|
||||
- Claude Code or Cursor for utilizing agent skills
|
||||
|
||||
> **Note for Windows Users:** It is strongly recommended to use **WSL (Windows Subsystem for Linux)** or **Git Bash** to run this framework. Some core automation scripts may not execute correctly in standard Command Prompt or PowerShell.
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
|
||||
+1
-1
@@ -268,7 +268,7 @@ classDef done fill:#9e9e9e,color:#fff,stroke:#757575
|
||||
- [ ] Wake-up Tool (resume agent tasks)
|
||||
|
||||
### Deployment (Self-Hosted)
|
||||
- [ ] Docker container standardization
|
||||
- [ ] Worker agent Docker container standardization
|
||||
- [ ] Headless backend execution
|
||||
- [ ] Exposed API for frontend attachment
|
||||
- [ ] Local monitoring & observability
|
||||
|
||||
@@ -42,7 +42,7 @@ from framework.credentials.storage import ( # noqa: E402
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
from framework.graph.event_loop_node import EventLoopNode, JudgeVerdict, LoopConfig # noqa: E402
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.provider import Tool # noqa: E402
|
||||
@@ -316,47 +316,6 @@ logger.info(
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# ChatJudge — keeps the event loop alive between user messages
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ChatJudge:
    """Judge that blocks between user messages, keeping the loop alive.

    After the LLM finishes responding, the judge awaits a signal indicating
    a new user message has been injected, then returns RETRY to continue.
    Once ``signal_shutdown()`` has been called it returns ACCEPT instead.

    Signals are sticky: a message or shutdown signalled *before*
    ``evaluate()`` starts waiting is still honoured, so a signal that
    arrives while the LLM is mid-turn is never lost.
    """

    def __init__(self, on_ready=None):
        """Create the judge.

        Args:
            on_ready: Optional async callable (no arguments) awaited each
                time the judge is about to wait for the next user message.
        """
        self._message_ready = asyncio.Event()
        self._shutdown = False
        self._on_ready = on_ready  # async callback fired when waiting for input

    async def evaluate(self, context: dict) -> JudgeVerdict:
        """Wait for the next user message and decide whether to continue.

        Args:
            context: Evaluation context supplied by the caller (unused here).

        Returns:
            A verdict with action ``"ACCEPT"`` if shutdown was signalled,
            otherwise ``"RETRY"`` once a new message has been signalled.
        """
        # Shutdown may have been requested while the LLM was still responding;
        # honour it immediately instead of waiting for a signal that already fired.
        if self._shutdown:
            return JudgeVerdict(action="ACCEPT")

        # Notify client that the LLM is done — ready for next input
        if self._on_ready:
            await self._on_ready()

        # Block until next user message (or shutdown). The event is NOT
        # cleared before waiting: clearing first would discard a signal that
        # arrived during the LLM turn and leave this coroutine blocked forever.
        await self._message_ready.wait()
        # Consume the signal so the next evaluation blocks again.
        self._message_ready.clear()

        if self._shutdown:
            return JudgeVerdict(action="ACCEPT")

        return JudgeVerdict(action="RETRY")

    def signal_message(self):
        """Unblock the judge — a new user message has been injected."""
        self._message_ready.set()

    def signal_shutdown(self):
        """Unblock the judge and let the loop exit cleanly."""
        # Set the flag before the event so a woken evaluate() sees it.
        self._shutdown = True
        self._message_ready.set()
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTML page (embedded)
|
||||
# -------------------------------------------------------------------------
|
||||
@@ -565,7 +524,7 @@ connect();
|
||||
|
||||
|
||||
async def handle_ws(websocket):
|
||||
"""Persistent WebSocket: long-lived EventLoopNode kept alive by ChatJudge."""
|
||||
"""Persistent WebSocket: long-lived EventLoopNode with client_facing blocking."""
|
||||
global STORE
|
||||
|
||||
# -- Event forwarding (WebSocket ← EventBus) ----------------------------
|
||||
@@ -593,15 +552,7 @@ async def handle_ws(websocket):
|
||||
handler=forward_event,
|
||||
)
|
||||
|
||||
# -- Ready callback (tells browser the LLM is done, waiting for input) --
|
||||
async def send_ready():
|
||||
try:
|
||||
await websocket.send(json.dumps({"type": "ready"}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -- Per-connection state -----------------------------------------------
|
||||
judge = ChatJudge(on_ready=send_ready)
|
||||
node = None
|
||||
loop_task = None
|
||||
|
||||
@@ -613,6 +564,7 @@ async def handle_ws(websocket):
|
||||
name="Chat Assistant",
|
||||
description="A conversational assistant that remembers context across messages",
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
system_prompt=(
|
||||
"You are a helpful assistant with access to tools. "
|
||||
"You can search the web, scrape webpages, and query HubSpot CRM. "
|
||||
@@ -621,6 +573,18 @@ async def handle_ws(websocket):
|
||||
),
|
||||
)
|
||||
|
||||
# -- Ready callback: subscribe to CLIENT_INPUT_REQUESTED on the bus ---
|
||||
async def on_input_requested(event):
|
||||
try:
|
||||
await websocket.send(json.dumps({"type": "ready"}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.CLIENT_INPUT_REQUESTED],
|
||||
handler=on_input_requested,
|
||||
)
|
||||
|
||||
async def start_loop(first_message: str):
|
||||
"""Create an EventLoopNode and run it as a background task."""
|
||||
nonlocal node, loop_task
|
||||
@@ -637,7 +601,6 @@ async def handle_ws(websocket):
|
||||
)
|
||||
node = EventLoopNode(
|
||||
event_bus=bus,
|
||||
judge=judge,
|
||||
config=LoopConfig(max_iterations=10_000, max_history_tokens=32_000),
|
||||
conversation_store=STORE,
|
||||
tool_executor=tool_executor,
|
||||
@@ -683,10 +646,11 @@ async def handle_ws(websocket):
|
||||
loop_task = asyncio.create_task(_run())
|
||||
|
||||
async def stop_loop():
|
||||
"""Signal the judge and wait for the loop task to finish."""
|
||||
"""Signal the node and wait for the loop task to finish."""
|
||||
nonlocal node, loop_task
|
||||
if loop_task and not loop_task.done():
|
||||
judge.signal_shutdown()
|
||||
if node:
|
||||
node.signal_shutdown()
|
||||
try:
|
||||
await asyncio.wait_for(loop_task, timeout=5.0)
|
||||
except (TimeoutError, asyncio.CancelledError):
|
||||
@@ -712,8 +676,6 @@ async def handle_ws(websocket):
|
||||
if conv_dir.exists():
|
||||
shutil.rmtree(conv_dir)
|
||||
STORE = FileConversationStore(conv_dir)
|
||||
# Reset judge for next session
|
||||
judge = ChatJudge(on_ready=send_ready)
|
||||
await websocket.send(json.dumps({"type": "cleared"}))
|
||||
logger.info("Conversation cleared")
|
||||
continue
|
||||
@@ -727,10 +689,9 @@ async def handle_ws(websocket):
|
||||
logger.info(f"Starting persistent loop: {topic}")
|
||||
await start_loop(topic)
|
||||
else:
|
||||
# Subsequent message — inject and unblock the judge
|
||||
# Subsequent message — inject into the running loop
|
||||
logger.info(f"Injecting message: {topic}")
|
||||
await node.inject_event(topic)
|
||||
judge.signal_message()
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -15,7 +15,7 @@ You cannot skip steps or bypass validation.
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -26,7 +26,7 @@ from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeSpec
|
||||
|
||||
|
||||
class BuildPhase(str, Enum):
|
||||
class BuildPhase(StrEnum):
|
||||
"""Current phase of the build process."""
|
||||
|
||||
INIT = "init" # Just started
|
||||
|
||||
@@ -8,7 +8,7 @@ containing one or more keys (e.g., api_key, access_token, refresh_token).
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
@@ -19,7 +19,7 @@ def _utc_now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class CredentialType(str, Enum):
|
||||
class CredentialType(StrEnum):
|
||||
"""Types of credentials the store can manage."""
|
||||
|
||||
API_KEY = "api_key"
|
||||
|
||||
@@ -11,11 +11,11 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class TokenPlacement(str, Enum):
|
||||
class TokenPlacement(StrEnum):
|
||||
"""Where to place the access token in HTTP requests."""
|
||||
|
||||
HEADER_BEARER = "header_bearer"
|
||||
|
||||
@@ -75,6 +75,16 @@ class Message:
|
||||
)
|
||||
|
||||
|
||||
def _extract_spillover_filename(content: str) -> str | None:
|
||||
"""Extract spillover filename from a truncated tool result.
|
||||
|
||||
Matches the pattern produced by EventLoopNode._truncate_tool_result():
|
||||
"saved to 'tool_github_list_stargazers_abc123.txt'"
|
||||
"""
|
||||
match = re.search(r"saved to '([^']+)'", content)
|
||||
return match.group(1) if match else None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationStore protocol (Phase 2)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -357,6 +367,85 @@ class NodeConversation:
|
||||
|
||||
# --- Lifecycle ---------------------------------------------------------
|
||||
|
||||
async def prune_old_tool_results(
    self,
    protect_tokens: int = 5000,
    min_prune_tokens: int = 2000,
) -> int:
    """Swap the content of older tool results for short placeholders.

    Messages are scanned from newest to oldest. Tool results are left
    untouched until their combined estimated size (``len(content) // 4``)
    reaches *protect_tokens*; every eligible tool result older than that
    is a pruning candidate. Error results and results that already carry
    the pruned marker are never candidates.

    Nothing is changed unless the candidates add up to at least
    *min_prune_tokens* estimated tokens. Each pruned message keeps its
    seq, role, tool_use_id, tool_calls and is_error; only the content is
    replaced, preserving any spillover filename reference. If a store is
    configured, each replaced message is written through to it.

    Returns:
        The number of messages pruned (0 if nothing was pruned).
    """
    if not self._messages:
        return 0

    # Pass 1: newest-to-oldest scan, splitting tool results into the
    # protected recent window and the older pruning candidates.
    kept_budget = 0
    reclaimable = 0
    targets: list[int] = []  # indices into self._messages
    for idx in reversed(range(len(self._messages))):
        candidate = self._messages[idx]
        skip = (
            candidate.role != "tool"
            or candidate.is_error  # errors are never pruned
            or candidate.content.startswith("[Pruned tool result")  # already pruned
        )
        if skip:
            continue

        approx = len(candidate.content) // 4
        if kept_budget < protect_tokens:
            kept_budget += approx
            continue
        targets.append(idx)
        reclaimable += approx

    # Pass 2: bail out when the saving would be too small to matter.
    if reclaimable < min_prune_tokens:
        return 0

    # Pass 3: rewrite each candidate with its compact placeholder.
    for idx in targets:
        msg = self._messages[idx]
        orig_len = len(msg.content)
        spillover = _extract_spillover_filename(msg.content)

        if not spillover:
            placeholder = f"[Pruned tool result: {orig_len} chars cleared from context.]"
        else:
            placeholder = (
                f"[Pruned tool result: {orig_len} chars. "
                f"Full data in '{spillover}'. "
                f"Use load_data('{spillover}') to retrieve.]"
            )

        replacement = Message(
            seq=msg.seq,
            role=msg.role,
            content=placeholder,
            tool_use_id=msg.tool_use_id,
            tool_calls=msg.tool_calls,
            is_error=msg.is_error,
        )
        self._messages[idx] = replacement

        if self._store:
            await self._store.write_part(msg.seq, replacement.to_storage_dict())

    # Content lengths changed, so the cached token estimate is stale.
    self._last_api_input_tokens = None
    return len(targets)
|
||||
|
||||
async def compact(self, summary: str, keep_recent: int = 2) -> None:
|
||||
"""Replace old messages with a summary, optionally keeping recent ones.
|
||||
|
||||
@@ -371,12 +460,18 @@ class NodeConversation:
|
||||
# Clamp: must discard at least 1 message
|
||||
keep_recent = max(0, min(keep_recent, len(self._messages) - 1))
|
||||
|
||||
if keep_recent > 0:
|
||||
old_messages = self._messages[:-keep_recent]
|
||||
recent_messages = self._messages[-keep_recent:]
|
||||
else:
|
||||
old_messages = self._messages
|
||||
recent_messages = []
|
||||
total = len(self._messages)
|
||||
split = total - keep_recent if keep_recent > 0 else total
|
||||
|
||||
# Advance split past orphaned tool results at the boundary.
|
||||
# Tool-role messages reference a tool_use from the preceding
|
||||
# assistant message; if that assistant message falls into the
|
||||
# compacted (old) portion the tool_result becomes invalid.
|
||||
while split < total and self._messages[split].role == "tool":
|
||||
split += 1
|
||||
|
||||
old_messages = list(self._messages[:split])
|
||||
recent_messages = list(self._messages[split:])
|
||||
|
||||
# Extract protected values from messages being discarded
|
||||
if self._output_keys:
|
||||
|
||||
@@ -21,7 +21,7 @@ allowing the LLM to evaluate whether proceeding along an edge makes sense
|
||||
given the current goal, context, and execution state.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -29,7 +29,7 @@ from pydantic import BaseModel, Field
|
||||
from framework.graph.safe_eval import safe_eval
|
||||
|
||||
|
||||
class EdgeCondition(str, Enum):
|
||||
class EdgeCondition(StrEnum):
|
||||
"""When an edge should be traversed."""
|
||||
|
||||
ALWAYS = "always" # Always after source completes
|
||||
|
||||
@@ -17,6 +17,7 @@ import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from framework.graph.conversation import ConversationStore, NodeConversation
|
||||
@@ -73,6 +74,15 @@ class LoopConfig:
|
||||
max_history_tokens: int = 32_000
|
||||
store_prefix: str = ""
|
||||
|
||||
# --- Tool result context management ---
|
||||
# When a tool result exceeds this character count, it is truncated in the
|
||||
# conversation context. If *spillover_dir* is set the full result is
|
||||
# written to a file and the truncated message includes the filename so
|
||||
# the agent can retrieve it with load_data(). If *spillover_dir* is
|
||||
# ``None`` the result is simply truncated with an explanatory note.
|
||||
max_tool_result_chars: int = 3_000
|
||||
spillover_dir: str | None = None # Path string; created on first use
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Output accumulator with write-through persistence
|
||||
@@ -133,13 +143,21 @@ class EventLoopNode(NodeProtocol):
|
||||
Lifecycle:
|
||||
1. Try to restore from durable state (crash recovery)
|
||||
2. If no prior state, init from NodeSpec.system_prompt + input_keys
|
||||
3. Loop: drain injection queue -> stream LLM -> execute tools -> judge
|
||||
3. Loop: drain injection queue -> stream LLM -> execute tools
|
||||
-> if client_facing + no tools: block for user input (inject_event)
|
||||
-> if not client_facing or tools present: judge evaluates
|
||||
(each add_* and set_output writes through to store immediately)
|
||||
4. Publish events to EventBus at each stage
|
||||
5. Write cursor after each iteration
|
||||
6. Terminate when judge returns ACCEPT (or max iterations)
|
||||
6. Terminate when judge returns ACCEPT, shutdown signaled, or max iterations
|
||||
7. Build output dict from OutputAccumulator
|
||||
|
||||
Client-facing blocking: When ``client_facing=True`` and the LLM produces
|
||||
text without tool calls (a natural conversational turn), the node blocks
|
||||
via ``_await_user_input()`` until ``inject_event()`` or ``signal_shutdown()``
|
||||
is called. This separates blocking (node concern) from output evaluation
|
||||
(judge concern).
|
||||
|
||||
Always returns NodeResult with retryable=False semantics. The executor
|
||||
must NOT retry event loop nodes -- retry is handled internally by the
|
||||
judge (RETRY action continues the loop). See WP-7 enforcement.
|
||||
@@ -159,6 +177,9 @@ class EventLoopNode(NodeProtocol):
|
||||
self._tool_executor = tool_executor
|
||||
self._conversation_store = conversation_store
|
||||
self._injection_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
# Client-facing input blocking state
|
||||
self._input_ready = asyncio.Event()
|
||||
self._shutdown = False
|
||||
|
||||
def validate_input(self, ctx: NodeContext) -> list[str]:
|
||||
"""Validate hard requirements only.
|
||||
@@ -211,6 +232,15 @@ class EventLoopNode(NodeProtocol):
|
||||
if set_output_tool:
|
||||
tools.append(set_output_tool)
|
||||
|
||||
logger.info(
|
||||
"[%s] Tools available (%d): %s | client_facing=%s | judge=%s",
|
||||
node_id,
|
||||
len(tools),
|
||||
[t.name for t in tools],
|
||||
ctx.node_spec.client_facing,
|
||||
type(self._judge).__name__ if self._judge else "None",
|
||||
)
|
||||
|
||||
# 4. Publish loop started
|
||||
await self._publish_loop_started(stream_id, node_id)
|
||||
|
||||
@@ -237,12 +267,27 @@ class EventLoopNode(NodeProtocol):
|
||||
|
||||
# 6d. Pre-turn compaction check (tiered)
|
||||
if conversation.needs_compaction():
|
||||
await self._compact_tiered(ctx, conversation)
|
||||
await self._compact_tiered(ctx, conversation, accumulator)
|
||||
|
||||
# 6e. Run single LLM turn
|
||||
logger.info(
|
||||
"[%s] iter=%d: running LLM turn (msgs=%d)",
|
||||
node_id,
|
||||
iteration,
|
||||
len(conversation.messages),
|
||||
)
|
||||
assistant_text, tool_results_list, turn_tokens = await self._run_single_turn(
|
||||
ctx, conversation, tools, iteration, accumulator
|
||||
)
|
||||
logger.info(
|
||||
"[%s] iter=%d: LLM done — text=%d chars, tool_calls=%d, tokens=%s, accumulator=%s",
|
||||
node_id,
|
||||
iteration,
|
||||
len(assistant_text),
|
||||
len(tool_results_list),
|
||||
turn_tokens,
|
||||
{k: ("set" if v is not None else "None") for k, v in accumulator.to_dict().items()},
|
||||
)
|
||||
total_input_tokens += turn_tokens.get("input", 0)
|
||||
total_output_tokens += turn_tokens.get("output", 0)
|
||||
|
||||
@@ -253,7 +298,7 @@ class EventLoopNode(NodeProtocol):
|
||||
|
||||
# 6e''. Post-turn compaction check (catches tool-result bloat)
|
||||
if conversation.needs_compaction():
|
||||
await self._compact_tiered(ctx, conversation)
|
||||
await self._compact_tiered(ctx, conversation, accumulator)
|
||||
|
||||
# 6f. Stall detection
|
||||
recent_responses.append(assistant_text)
|
||||
@@ -276,12 +321,85 @@ class EventLoopNode(NodeProtocol):
|
||||
# 6g. Write cursor checkpoint
|
||||
await self._write_cursor(ctx, conversation, accumulator, iteration)
|
||||
|
||||
# 6h. Judge evaluation
|
||||
# 6h. Client-facing input wait
|
||||
logger.info(
|
||||
"[%s] iter=%d: 6h check — client_facing=%s, tool_results=%d",
|
||||
node_id,
|
||||
iteration,
|
||||
ctx.node_spec.client_facing,
|
||||
len(tool_results_list),
|
||||
)
|
||||
if ctx.node_spec.client_facing and not tool_results_list:
|
||||
# LLM finished speaking (no tool calls) on a client-facing node.
|
||||
# This is a conversational turn boundary: block for user input
|
||||
# instead of running the judge.
|
||||
if self._shutdown:
|
||||
await self._publish_loop_completed(stream_id, node_id, iteration + 1)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return NodeResult(
|
||||
success=True,
|
||||
output=accumulator.to_dict(),
|
||||
tokens_used=total_input_tokens + total_output_tokens,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
logger.info("[%s] iter=%d: blocking for user input...", node_id, iteration)
|
||||
got_input = await self._await_user_input(ctx)
|
||||
logger.info("[%s] iter=%d: unblocked, got_input=%s", node_id, iteration, got_input)
|
||||
if not got_input:
|
||||
# Shutdown signaled during wait
|
||||
await self._publish_loop_completed(stream_id, node_id, iteration + 1)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return NodeResult(
|
||||
success=True,
|
||||
output=accumulator.to_dict(),
|
||||
tokens_used=total_input_tokens + total_output_tokens,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
# Clear stall detection — user input resets the conversation
|
||||
recent_responses.clear()
|
||||
|
||||
# For nodes with an explicit judge, fall through to judge
|
||||
# evaluation so the LLM gets structured feedback about missing
|
||||
# outputs (e.g. "Missing output keys: [...]"). Without this,
|
||||
# the LLM may generate text like "Ready to proceed!" without
|
||||
# ever calling set_output, and the judge feedback never reaches it.
|
||||
#
|
||||
# For nodes without a judge (HITL review/approval with all-
|
||||
# nullable keys), keep conversing UNLESS the LLM has already
|
||||
# set an output — in that case fall through to the implicit
|
||||
# judge which will ACCEPT and terminate the node.
|
||||
if self._judge is None:
|
||||
has_outputs = accumulator and any(
|
||||
v is not None for v in accumulator.to_dict().values()
|
||||
)
|
||||
if not has_outputs:
|
||||
logger.info(
|
||||
"[%s] iter=%d: no judge, no outputs, continuing",
|
||||
node_id,
|
||||
iteration,
|
||||
)
|
||||
continue
|
||||
logger.info(
|
||||
"[%s] iter=%d: no judge, outputs set — implicit judge",
|
||||
node_id,
|
||||
iteration,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"[%s] iter=%d: has judge, falling through to 6i",
|
||||
node_id,
|
||||
iteration,
|
||||
)
|
||||
|
||||
# 6i. Judge evaluation
|
||||
should_judge = (
|
||||
(iteration + 1) % self._config.judge_every_n_turns == 0
|
||||
or not tool_results_list # no tool calls = natural stop
|
||||
)
|
||||
|
||||
logger.info("[%s] iter=%d: 6i should_judge=%s", node_id, iteration, should_judge)
|
||||
if should_judge:
|
||||
verdict = await self._evaluate(
|
||||
ctx,
|
||||
@@ -291,15 +409,31 @@ class EventLoopNode(NodeProtocol):
|
||||
tool_results_list,
|
||||
iteration,
|
||||
)
|
||||
fb_preview = (verdict.feedback or "")[:200]
|
||||
logger.info(
|
||||
"[%s] iter=%d: judge verdict=%s feedback=%r",
|
||||
node_id,
|
||||
iteration,
|
||||
verdict.action,
|
||||
fb_preview,
|
||||
)
|
||||
|
||||
if verdict.action == "ACCEPT":
|
||||
# Check for missing output keys
|
||||
missing = self._get_missing_output_keys(accumulator, ctx.node_spec.output_keys)
|
||||
missing = self._get_missing_output_keys(
|
||||
accumulator, ctx.node_spec.output_keys, ctx.node_spec.nullable_output_keys
|
||||
)
|
||||
if missing and self._judge is not None:
|
||||
hint = (
|
||||
f"Missing required output keys: {missing}. "
|
||||
"Use set_output to provide them."
|
||||
)
|
||||
logger.info(
|
||||
"[%s] iter=%d: ACCEPT but missing keys %s",
|
||||
node_id,
|
||||
iteration,
|
||||
missing,
|
||||
)
|
||||
await conversation.add_user_message(hint)
|
||||
continue
|
||||
|
||||
@@ -348,8 +482,38 @@ class EventLoopNode(NodeProtocol):
|
||||
|
||||
The content becomes a user message prepended to the next iteration.
|
||||
Thread-safe via asyncio.Queue.
|
||||
Also unblocks _await_user_input() if the node is waiting.
|
||||
"""
|
||||
await self._injection_queue.put(content)
|
||||
self._input_ready.set()
|
||||
|
||||
def signal_shutdown(self) -> None:
    """Ask the node to leave its event loop at the next opportunity.

    Raises the shutdown flag and then releases the input-ready event so
    that a coroutine parked in _await_user_input() wakes up and observes
    the flag.
    """
    # Flag first, wake-up second: a woken waiter must already see the flag.
    self._shutdown = True
    self._input_ready.set()
|
||||
|
||||
async def _await_user_input(self, ctx: NodeContext) -> bool:
    """Block until user input arrives or shutdown is signaled.

    Called when a client_facing node produces text without tool calls —
    a natural conversational turn boundary.

    Args:
        ctx: Node context; its ``node_id`` is used as both stream id and
            node id of the emitted input-request event.

    Returns:
        True if input arrived, False if shutdown was signaled (including
        a shutdown that was signaled before this call started).
    """
    # A shutdown requested before we got here must not be lost: clearing
    # the event below would erase its wake-up and we would wait forever.
    if self._shutdown:
        return False

    # Reset the event BEFORE announcing that we are waiting. Emitting is
    # an await point; input or shutdown arriving while the event is being
    # published sets _input_ready, and clearing afterwards would discard
    # that signal.
    self._input_ready.clear()

    if self._event_bus:
        await self._event_bus.emit_client_input_requested(
            stream_id=ctx.node_id,
            node_id=ctx.node_id,
            prompt="",
        )

    await self._input_ready.wait()
    return not self._shutdown
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Single LLM turn with caller-managed tool orchestration
|
||||
@@ -382,7 +546,7 @@ class EventLoopNode(NodeProtocol):
|
||||
"Pre-send guard: context at %.0f%% of budget, compacting",
|
||||
conversation.usage_ratio() * 100,
|
||||
)
|
||||
await self._compact_tiered(ctx, conversation)
|
||||
await self._compact_tiered(ctx, conversation, accumulator)
|
||||
|
||||
messages = conversation.to_llm_messages()
|
||||
accumulated_text = ""
|
||||
@@ -414,6 +578,12 @@ class EventLoopNode(NodeProtocol):
|
||||
logger.warning(f"Recoverable stream error: {event.error}")
|
||||
|
||||
final_text = accumulated_text
|
||||
logger.info(
|
||||
"[%s] LLM response: text=%r tool_calls=%s",
|
||||
node_id,
|
||||
accumulated_text[:300] if accumulated_text else "(empty)",
|
||||
[tc.tool_name for tc in tool_calls] if tool_calls else "[]",
|
||||
)
|
||||
|
||||
# Record assistant message (write-through via conversation store)
|
||||
tc_dicts = None
|
||||
@@ -440,13 +610,14 @@ class EventLoopNode(NodeProtocol):
|
||||
|
||||
# Execute tool calls
|
||||
tool_results: list[dict] = []
|
||||
limit_hit = False
|
||||
executed_in_batch = 0
|
||||
for tc in tool_calls:
|
||||
tool_call_count += 1
|
||||
if tool_call_count > self._config.max_tool_calls_per_turn:
|
||||
logger.warning(
|
||||
f"Max tool calls per turn ({self._config.max_tool_calls_per_turn}) exceeded"
|
||||
)
|
||||
limit_hit = True
|
||||
break
|
||||
executed_in_batch += 1
|
||||
|
||||
# Publish tool call started
|
||||
await self._publish_tool_started(
|
||||
@@ -454,6 +625,12 @@ class EventLoopNode(NodeProtocol):
|
||||
)
|
||||
|
||||
# Handle set_output synthetic tool
|
||||
logger.info(
|
||||
"[%s] tool_call: %s(%s)",
|
||||
node_id,
|
||||
tc.tool_name,
|
||||
json.dumps(tc.tool_input)[:200],
|
||||
)
|
||||
if tc.tool_name == "set_output":
|
||||
result = self._handle_set_output(tc.tool_input, ctx.node_spec.output_keys)
|
||||
result = ToolResult(
|
||||
@@ -467,6 +644,8 @@ class EventLoopNode(NodeProtocol):
|
||||
else:
|
||||
# Execute real tool
|
||||
result = await self._execute_tool(tc)
|
||||
# Truncate large results to prevent context blowup
|
||||
result = self._truncate_tool_result(result, tc.tool_name)
|
||||
|
||||
# Record tool result in conversation (write-through)
|
||||
await conversation.add_tool_result(
|
||||
@@ -493,6 +672,57 @@ class EventLoopNode(NodeProtocol):
|
||||
result.is_error,
|
||||
)
|
||||
|
||||
# If the limit was hit, add error results for every remaining
|
||||
# tool call so the conversation stays consistent. Without this,
|
||||
# the assistant message contains tool_calls that have no
|
||||
# corresponding tool results, causing the LLM to repeat them
|
||||
# in the next turn (infinite loop).
|
||||
if limit_hit:
|
||||
max_tc = self._config.max_tool_calls_per_turn
|
||||
skipped = tool_calls[executed_in_batch:]
|
||||
logger.warning(
|
||||
"Max tool calls per turn (%d) exceeded — discarding %d remaining call(s): %s",
|
||||
max_tc,
|
||||
len(skipped),
|
||||
", ".join(tc.tool_name for tc in skipped),
|
||||
)
|
||||
discard_msg = (
|
||||
f"Tool call discarded: max tool calls per turn "
|
||||
f"({max_tc}) exceeded. Consolidate your work and "
|
||||
f"use fewer tool calls."
|
||||
)
|
||||
for tc in skipped:
|
||||
await conversation.add_tool_result(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=discard_msg,
|
||||
is_error=True,
|
||||
)
|
||||
tool_results.append(
|
||||
{
|
||||
"tool_use_id": tc.tool_use_id,
|
||||
"tool_name": tc.tool_name,
|
||||
"content": discard_msg,
|
||||
"is_error": True,
|
||||
}
|
||||
)
|
||||
# Limit hit — return from this turn so the judge can
|
||||
# evaluate instead of looping back for another stream.
|
||||
return final_text, tool_results, token_counts
|
||||
|
||||
# --- Mid-turn pruning: prevent context blowup within a single turn ---
|
||||
if conversation.usage_ratio() >= 0.6:
|
||||
protect = max(2000, self._config.max_history_tokens // 12)
|
||||
pruned = await conversation.prune_old_tool_results(
|
||||
protect_tokens=protect,
|
||||
min_prune_tokens=max(1000, protect // 3),
|
||||
)
|
||||
if pruned > 0:
|
||||
logger.info(
|
||||
"Mid-turn pruning: cleared %d old tool results (usage now %.0f%%)",
|
||||
pruned,
|
||||
conversation.usage_ratio() * 100,
|
||||
)
|
||||
|
||||
# Tool calls processed -- loop back to stream with updated conversation
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
@@ -533,8 +763,36 @@ class EventLoopNode(NodeProtocol):
|
||||
) -> ToolResult:
|
||||
"""Handle set_output tool call. Returns ToolResult (sync)."""
|
||||
key = tool_input.get("key", "")
|
||||
value = tool_input.get("value", "")
|
||||
valid_keys = output_keys or []
|
||||
|
||||
# Recover from truncated JSON (max_tokens hit mid-argument).
|
||||
# The _raw key is set by litellm when json.loads fails.
|
||||
if not key and "_raw" in tool_input:
|
||||
import re
|
||||
|
||||
raw = tool_input["_raw"]
|
||||
key_match = re.search(r'"key"\s*:\s*"(\w+)"', raw)
|
||||
if key_match:
|
||||
key = key_match.group(1)
|
||||
val_match = re.search(r'"value"\s*:\s*"', raw)
|
||||
if val_match:
|
||||
start = val_match.end()
|
||||
value = raw[start:].rstrip()
|
||||
for suffix in ('"}\n', '"}', '"'):
|
||||
if value.endswith(suffix):
|
||||
value = value[: -len(suffix)]
|
||||
break
|
||||
if key:
|
||||
logger.warning(
|
||||
"Recovered set_output args from truncated JSON: key=%s, value_len=%d",
|
||||
key,
|
||||
len(value),
|
||||
)
|
||||
# Re-inject so the caller sees proper key/value
|
||||
tool_input["key"] = key
|
||||
tool_input["value"] = value
|
||||
|
||||
if key not in valid_keys:
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
@@ -567,18 +825,21 @@ class EventLoopNode(NodeProtocol):
|
||||
"assistant_text": assistant_text,
|
||||
"tool_calls": tool_results,
|
||||
"output_accumulator": accumulator.to_dict(),
|
||||
"accumulator": accumulator,
|
||||
"iteration": iteration,
|
||||
"conversation_summary": conversation.export_summary(),
|
||||
"output_keys": ctx.node_spec.output_keys,
|
||||
"missing_keys": self._get_missing_output_keys(
|
||||
accumulator, ctx.node_spec.output_keys
|
||||
accumulator, ctx.node_spec.output_keys, ctx.node_spec.nullable_output_keys
|
||||
),
|
||||
}
|
||||
return await self._judge.evaluate(context)
|
||||
|
||||
# Implicit judge: accept when no tool calls and all output keys present
|
||||
if not tool_results:
|
||||
missing = self._get_missing_output_keys(accumulator, ctx.node_spec.output_keys)
|
||||
missing = self._get_missing_output_keys(
|
||||
accumulator, ctx.node_spec.output_keys, ctx.node_spec.nullable_output_keys
|
||||
)
|
||||
if not missing:
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
else:
|
||||
@@ -596,6 +857,60 @@ class EventLoopNode(NodeProtocol):
|
||||
# Helpers
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
@staticmethod
def _extract_tool_call_history(
    conversation: NodeConversation,
    max_entries: int = 30,
) -> str:
    """Build a compact tool call history from the conversation.

    Used in compaction summaries to prevent the LLM from re-calling
    tools it already called. Extracts:
    - Tool call counts (e.g. "github_list_pull_requests (6x)")
    - Files saved via save_data
    - Outputs set via set_output
    - Errors encountered

    Args:
        conversation: Object exposing ``messages``; each message is read
            for ``role``, ``tool_calls``, ``is_error`` and ``content``.
        max_entries: Maximum number of distinct tools listed in the
            "TOOLS ALREADY CALLED" section.

    Returns:
        The non-empty sections joined by blank lines, or an empty string
        when the conversation contains no tool activity.
    """
    tool_counts: dict[str, int] = {}
    files_saved: list[str] = []
    outputs_set: list[str] = []
    errors: list[str] = []

    for msg in conversation.messages:
        if msg.role == "assistant" and msg.tool_calls:
            for tc in msg.tool_calls:
                func = tc.get("function", {})
                name = func.get("name", "unknown")
                tool_counts[name] = tool_counts.get(name, 0) + 1
                try:
                    args = json.loads(func.get("arguments", "{}"))
                except (json.JSONDecodeError, TypeError):
                    args = {}
                # Valid JSON is not necessarily an object ("null", a list,
                # a bare string); only dicts support the lookups below.
                if not isinstance(args, dict):
                    args = {}
                if name == "save_data" and args.get("filename"):
                    files_saved.append(args["filename"])
                if name == "set_output" and args.get("key"):
                    outputs_set.append(args["key"])

        if msg.role == "tool" and msg.is_error:
            # Tolerate a missing/None content on error results.
            preview = (msg.content or "")[:120].replace("\n", " ")
            errors.append(preview)

    parts: list[str] = []
    if tool_counts:
        lines = [f" {n} ({c}x)" for n, c in tool_counts.items()]
        parts.append("TOOLS ALREADY CALLED:\n" + "\n".join(lines[:max_entries]))
    if files_saved:
        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique = list(dict.fromkeys(files_saved))
        parts.append("FILES SAVED: " + ", ".join(unique))
    if outputs_set:
        unique = list(dict.fromkeys(outputs_set))
        parts.append("OUTPUTS SET: " + ", ".join(unique))
    if errors:
        parts.append(
            "ERRORS (do NOT retry these):\n" + "\n".join(f" - {e}" for e in errors[:10])
        )
    return "\n\n".join(parts)
|
||||
|
||||
def _build_initial_message(self, ctx: NodeContext) -> str:
|
||||
"""Build the initial user message from input data and memory.
|
||||
|
||||
@@ -624,11 +939,13 @@ class EventLoopNode(NodeProtocol):
|
||||
self,
|
||||
accumulator: OutputAccumulator,
|
||||
output_keys: list[str] | None,
|
||||
nullable_keys: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Return output keys that have not been set yet."""
|
||||
"""Return output keys that have not been set yet (excluding nullable keys)."""
|
||||
if not output_keys:
|
||||
return []
|
||||
return [k for k in output_keys if accumulator.get(k) is None]
|
||||
skip = set(nullable_keys) if nullable_keys else set()
|
||||
return [k for k in output_keys if k not in skip and accumulator.get(k) is None]
|
||||
|
||||
def _is_stalled(self, recent_responses: list[str]) -> bool:
|
||||
"""Detect stall: N consecutive identical non-empty responses."""
|
||||
@@ -652,10 +969,84 @@ class EventLoopNode(NodeProtocol):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
def _truncate_tool_result(
    self,
    result: ToolResult,
    tool_name: str,
) -> ToolResult:
    """Shrink an oversized tool result before it enters the conversation.

    Results that are errors, or whose content fits within
    *max_tool_result_chars* (or when that limit is non-positive), are
    returned untouched. Otherwise the content is replaced by a short
    header plus a preview. When *spillover_dir* is configured the full
    content is first written to a file there and the header names that
    file; without it the header only notes that truncation happened.
    """
    max_chars = self._config.max_tool_result_chars
    if max_chars <= 0 or result.is_error:
        return result

    body = result.content
    total = len(body)
    if total <= max_chars:
        return result

    # Keep some headroom for the metadata wrapper around the preview.
    preview_chars = max(max_chars - 300, max_chars // 2)
    head = body[:preview_chars]

    spill_dir = self._config.spillover_dir
    if not spill_dir:
        replacement = (
            f"[Result from {tool_name}: {total} chars — "
            f"truncated to fit context budget. Only the first "
            f"{preview_chars} chars are shown.]\n\n{head}…"
        )
        logger.info(
            "Tool result truncated in-place: %s (%d → %d chars)",
            tool_name,
            total,
            len(replacement),
        )
    else:
        target_dir = Path(spill_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        # tool_use_id keeps the name unique; slashes are not filesystem-safe.
        id_part = result.tool_use_id.replace("/", "_")[:60]
        filename = f"tool_{tool_name}_{id_part}.txt"

        # load_data paginates by line, so JSON is re-serialised with
        # indentation; compact single-line JSON would defeat pagination.
        try:
            to_write = json.dumps(json.loads(body), indent=2, ensure_ascii=False)
        except (json.JSONDecodeError, TypeError, ValueError):
            to_write = body  # Not JSON — store verbatim

        (target_dir / filename).write_text(to_write, encoding="utf-8")

        replacement = (
            f"[Result from {tool_name}: {total} chars — "
            f"too large for context, saved to '{filename}'. "
            f"Use load_data('{filename}') to read the full result.]\n\n"
            f"Preview:\n{head}…"
        )
        logger.info(
            "Tool result spilled to file: %s (%d chars → %s)",
            tool_name,
            total,
            filename,
        )

    return ToolResult(
        tool_use_id=result.tool_use_id,
        content=replacement,
        is_error=False,
    )
|
||||
|
||||
async def _compact_tiered(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
accumulator: OutputAccumulator | None = None,
|
||||
) -> None:
|
||||
"""Run compaction with aggressiveness scaled to usage level.
|
||||
|
||||
@@ -667,27 +1058,90 @@ class EventLoopNode(NodeProtocol):
|
||||
"""
|
||||
ratio = conversation.usage_ratio()
|
||||
|
||||
if ratio >= 1.2:
|
||||
# Emergency -- don't risk another LLM call on a bloated context
|
||||
logger.warning("Emergency compaction triggered (usage %.0f%%)", ratio * 100)
|
||||
await conversation.compact(
|
||||
"Previous conversation context (emergency compaction).",
|
||||
keep_recent=1,
|
||||
# --- Tier 0: Prune old tool results (zero-cost, no LLM call) ---
|
||||
protect = max(2000, self._config.max_history_tokens // 12)
|
||||
pruned = await conversation.prune_old_tool_results(
|
||||
protect_tokens=protect,
|
||||
min_prune_tokens=max(1000, protect // 3),
|
||||
)
|
||||
if pruned > 0:
|
||||
new_ratio = conversation.usage_ratio()
|
||||
logger.info(
|
||||
"Pruned %d old tool results: %.0f%% -> %.0f%%",
|
||||
pruned,
|
||||
ratio * 100,
|
||||
new_ratio * 100,
|
||||
)
|
||||
if not conversation.needs_compaction():
|
||||
# Pruning freed enough — skip full compaction entirely
|
||||
if self._event_bus:
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
await self._event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CUSTOM,
|
||||
stream_id=ctx.node_id,
|
||||
node_id=ctx.node_id,
|
||||
data={
|
||||
"custom_type": "node_compaction",
|
||||
"node_id": ctx.node_id,
|
||||
"level": "prune_only",
|
||||
"usage_before": round(ratio * 100),
|
||||
"usage_after": round(new_ratio * 100),
|
||||
},
|
||||
)
|
||||
)
|
||||
return
|
||||
ratio = new_ratio
|
||||
|
||||
if ratio >= 1.2:
|
||||
level = "emergency"
|
||||
logger.warning("Emergency compaction triggered (usage %.0f%%)", ratio * 100)
|
||||
summary = self._build_emergency_summary(ctx, accumulator, conversation)
|
||||
await conversation.compact(summary, keep_recent=1)
|
||||
elif ratio >= 1.0:
|
||||
level = "aggressive"
|
||||
logger.info("Aggressive compaction triggered (usage %.0f%%)", ratio * 100)
|
||||
summary = await self._generate_compaction_summary(ctx, conversation)
|
||||
await conversation.compact(summary, keep_recent=2)
|
||||
else:
|
||||
level = "normal"
|
||||
summary = await self._generate_compaction_summary(ctx, conversation)
|
||||
await conversation.compact(summary, keep_recent=4)
|
||||
|
||||
new_ratio = conversation.usage_ratio()
|
||||
logger.info(
|
||||
"Compaction complete (%s): %.0f%% -> %.0f%%",
|
||||
level,
|
||||
ratio * 100,
|
||||
new_ratio * 100,
|
||||
)
|
||||
if self._event_bus:
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
await self._event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CUSTOM,
|
||||
stream_id=ctx.node_id,
|
||||
node_id=ctx.node_id,
|
||||
data={
|
||||
"custom_type": "node_compaction",
|
||||
"node_id": ctx.node_id,
|
||||
"level": level,
|
||||
"usage_before": round(ratio * 100),
|
||||
"usage_after": round(new_ratio * 100),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def _generate_compaction_summary(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
) -> str:
|
||||
"""Use LLM to generate a conversation summary for compaction."""
|
||||
tool_history = self._extract_tool_call_history(conversation)
|
||||
|
||||
messages_text = "\n".join(
|
||||
f"[{m.role}]: {m.content[:200]}" for m in conversation.messages[-10:]
|
||||
)
|
||||
@@ -696,17 +1150,106 @@ class EventLoopNode(NodeProtocol):
|
||||
"preserving key decisions and results:\n\n"
|
||||
f"{messages_text}"
|
||||
)
|
||||
if tool_history:
|
||||
prompt += (
|
||||
"\n\nINCLUDE this tool history verbatim in your summary "
|
||||
"(the agent needs it to avoid re-calling tools):\n\n"
|
||||
f"{tool_history}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = ctx.llm.complete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
system="Summarize conversations concisely.",
|
||||
max_tokens=300,
|
||||
system=(
|
||||
"Summarize conversations concisely. Always preserve the tool history section."
|
||||
),
|
||||
max_tokens=500,
|
||||
)
|
||||
return response.content
|
||||
summary = response.content
|
||||
# Ensure tool history is present even if LLM dropped it
|
||||
if tool_history and "TOOLS ALREADY CALLED" not in summary:
|
||||
summary += "\n\n" + tool_history
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.warning(f"Compaction summary generation failed: {e}")
|
||||
if tool_history:
|
||||
return f"Previous conversation context (summary unavailable).\n\n{tool_history}"
|
||||
return "Previous conversation context (summary unavailable)."
|
||||
|
||||
def _build_emergency_summary(
    self,
    ctx: NodeContext,
    accumulator: OutputAccumulator | None = None,
    conversation: NodeConversation | None = None,
) -> str:
    """Build a structured emergency compaction summary.

    Unlike normal/aggressive compaction which uses an LLM summary,
    emergency compaction cannot afford an LLM call (context is already
    way over budget). Instead, build a deterministic summary from the
    node's known state so the LLM can continue working after
    compaction without losing track of its task and inputs.

    Args:
        ctx: Node context supplying the node spec, input data and memory.
        accumulator: Output accumulator, if any; used to list outputs
            already set and those still missing.
        conversation: Conversation to mine for tool call history, if any.

    Returns:
        The summary sections joined by blank lines.
    """
    parts = [
        "EMERGENCY COMPACTION — previous conversation was too large "
        "and has been replaced with this summary.\n"
    ]

    # 1. Node identity
    spec = ctx.node_spec
    parts.append(f"NODE: {spec.name} (id={spec.id})")
    if spec.description:
        parts.append(f"PURPOSE: {spec.description}")

    # 2. Inputs the node received
    input_lines = []
    for key in spec.input_keys:
        # Fall back to memory only when the input is truly absent; a
        # falsy-but-real input (0, False, "") must still be reported.
        value = ctx.input_data.get(key)
        if value is None:
            try:
                value = ctx.memory.read(key)
            except (PermissionError, KeyError):
                # An unreadable key must not abort emergency compaction.
                value = None
        if value is not None:
            # Truncate long values but keep them recognisable
            v_str = str(value)
            if len(v_str) > 200:
                v_str = v_str[:200] + "…"
            input_lines.append(f" {key}: {v_str}")
    if input_lines:
        parts.append("INPUTS:\n" + "\n".join(input_lines))

    # 3. Output accumulator state (what's been set so far)
    if accumulator:
        acc_state = accumulator.to_dict()
        set_keys = {k: v for k, v in acc_state.items() if v is not None}
        missing = [k for k, v in acc_state.items() if v is None]
        if set_keys:
            lines = [f" {k}: {str(v)[:150]}" for k, v in set_keys.items()]
            parts.append("OUTPUTS ALREADY SET:\n" + "\n".join(lines))
        if missing:
            parts.append(f"OUTPUTS STILL NEEDED: {', '.join(missing)}")
    elif spec.output_keys:
        parts.append(f"OUTPUTS STILL NEEDED: {', '.join(spec.output_keys)}")

    # 4. Available tools reminder
    if spec.tools:
        parts.append(f"AVAILABLE TOOLS: {', '.join(spec.tools)}")

    # 5. Spillover files hint
    if self._config.spillover_dir:
        parts.append(
            "NOTE: Large tool results were saved to files. "
            "Use load_data('<filename>') to read them."
        )

    # 6. Tool call history (prevent re-calling tools)
    if conversation is not None:
        tool_history = self._extract_tool_call_history(conversation)
        if tool_history:
            parts.append(tool_history)

    parts.append(
        "\nContinue working towards setting the remaining outputs. "
        "Use your tools and the inputs above."
    )
    return "\n\n".join(parts)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Persistence: restore, cursor, injection, pause
|
||||
# -------------------------------------------------------------------
|
||||
@@ -761,6 +1304,10 @@ class EventLoopNode(NodeProtocol):
|
||||
while not self._injection_queue.empty():
|
||||
try:
|
||||
content = self._injection_queue.get_nowait()
|
||||
logger.info(
|
||||
"[drain] injected message: %s",
|
||||
content[:200] if content else "(empty)",
|
||||
)
|
||||
await conversation.add_user_message(f"[External event]: {content}")
|
||||
count += 1
|
||||
except asyncio.QueueEmpty:
|
||||
@@ -776,7 +1323,10 @@ class EventLoopNode(NodeProtocol):
|
||||
"""Check if pause has been requested. Returns True if paused."""
|
||||
pause_requested = ctx.input_data.get("pause_requested", False)
|
||||
if not pause_requested:
|
||||
pause_requested = ctx.memory.read("pause_requested") or False
|
||||
try:
|
||||
pause_requested = ctx.memory.read("pause_requested") or False
|
||||
except (PermissionError, KeyError):
|
||||
pause_requested = False
|
||||
if pause_requested:
|
||||
logger.info(f"Pause requested at iteration {iteration}")
|
||||
return True
|
||||
|
||||
@@ -16,7 +16,7 @@ from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.edge import EdgeSpec, GraphSpec
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import (
|
||||
FunctionNode,
|
||||
@@ -55,6 +55,9 @@ class ExecutionResult:
|
||||
had_partial_failures: bool = False # True if any node failed but eventually succeeded
|
||||
execution_quality: str = "clean" # "clean", "degraded", or "failed"
|
||||
|
||||
# Visit tracking (for feedback/callback edges)
|
||||
node_visit_counts: dict[str, int] = field(default_factory=dict) # {node_id: visit_count}
|
||||
|
||||
@property
|
||||
def is_clean_success(self) -> bool:
|
||||
"""True only if execution succeeded with no retries or failures."""
|
||||
@@ -251,6 +254,8 @@ class GraphExecutor:
|
||||
total_tokens = 0
|
||||
total_latency = 0
|
||||
node_retry_counts: dict[str, int] = {} # Track retries per node
|
||||
node_visit_counts: dict[str, int] = {} # Track visits for feedback loops
|
||||
_is_retry = False # True when looping back for a retry (not a new visit)
|
||||
|
||||
# Determine entry point (may differ if resuming)
|
||||
current_node_id = graph.get_entry_point(session_state)
|
||||
@@ -279,6 +284,34 @@ class GraphExecutor:
|
||||
if node_spec is None:
|
||||
raise RuntimeError(f"Node not found: {current_node_id}")
|
||||
|
||||
# Enforce max_node_visits (feedback/callback edge support)
|
||||
# Don't increment visit count on retries — retries are not new visits
|
||||
if not _is_retry:
|
||||
cnt = node_visit_counts.get(current_node_id, 0) + 1
|
||||
node_visit_counts[current_node_id] = cnt
|
||||
_is_retry = False
|
||||
max_visits = getattr(node_spec, "max_node_visits", 1)
|
||||
if max_visits > 0 and node_visit_counts[current_node_id] > max_visits:
|
||||
self.logger.warning(
|
||||
f" ⊘ Node '{node_spec.name}' visit limit reached "
|
||||
f"({node_visit_counts[current_node_id]}/{max_visits}), skipping"
|
||||
)
|
||||
# Skip execution — follow outgoing edges using current memory
|
||||
skip_result = NodeResult(success=True, output=memory.read_all())
|
||||
next_node = self._follow_edges(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
current_node_id=current_node_id,
|
||||
current_node_spec=node_spec,
|
||||
result=skip_result,
|
||||
memory=memory,
|
||||
)
|
||||
if next_node is None:
|
||||
self.logger.info(" → No more edges after visit limit, ending")
|
||||
break
|
||||
current_node_id = next_node
|
||||
continue
|
||||
|
||||
path.append(current_node_id)
|
||||
|
||||
# Check if pause (HITL) before execution
|
||||
@@ -405,6 +438,7 @@ class GraphExecutor:
|
||||
self.logger.info(
|
||||
f" ↻ Retrying ({node_retry_counts[current_node_id]}/{max_retries})..."
|
||||
)
|
||||
_is_retry = True
|
||||
continue
|
||||
else:
|
||||
# Max retries exceeded - fail the execution
|
||||
@@ -447,6 +481,7 @@ class GraphExecutor:
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality="failed",
|
||||
node_visit_counts=dict(node_visit_counts),
|
||||
)
|
||||
|
||||
# Check if we just executed a pause node - if so, save state and return
|
||||
@@ -486,6 +521,7 @@ class GraphExecutor:
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality=exec_quality,
|
||||
node_visit_counts=dict(node_visit_counts),
|
||||
)
|
||||
|
||||
# Check if this is a terminal node - if so, we're done
|
||||
@@ -606,6 +642,7 @@ class GraphExecutor:
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality=exec_quality,
|
||||
node_visit_counts=dict(node_visit_counts),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -632,6 +669,7 @@ class GraphExecutor:
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality="failed",
|
||||
node_visit_counts=dict(node_visit_counts),
|
||||
)
|
||||
|
||||
def _build_context(
|
||||
@@ -859,6 +897,21 @@ class GraphExecutor:
|
||||
):
|
||||
traversable.append(edge)
|
||||
|
||||
# Priority filtering for CONDITIONAL edges:
|
||||
# When multiple CONDITIONAL edges match, keep only the highest-priority
|
||||
# group. This prevents mutually-exclusive conditional branches (e.g.
|
||||
# forward vs. feedback) from incorrectly triggering fan-out.
|
||||
# ON_SUCCESS / other edge types are unaffected.
|
||||
if len(traversable) > 1:
|
||||
conditionals = [e for e in traversable if e.condition == EdgeCondition.CONDITIONAL]
|
||||
if len(conditionals) > 1:
|
||||
max_prio = max(e.priority for e in conditionals)
|
||||
traversable = [
|
||||
e
|
||||
for e in traversable
|
||||
if e.condition != EdgeCondition.CONDITIONAL or e.priority == max_prio
|
||||
]
|
||||
|
||||
return traversable
|
||||
|
||||
def _find_convergence_node(
|
||||
|
||||
@@ -12,13 +12,13 @@ Goals are:
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GoalStatus(str, Enum):
|
||||
class GoalStatus(StrEnum):
|
||||
"""Lifecycle status of a goal."""
|
||||
|
||||
DRAFT = "draft" # Being defined
|
||||
|
||||
@@ -6,11 +6,11 @@ where agents need to gather input from humans.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class HITLInputType(str, Enum):
|
||||
class HITLInputType(StrEnum):
|
||||
"""Type of input expected from human."""
|
||||
|
||||
FREE_TEXT = "free_text" # Open-ended text response
|
||||
|
||||
@@ -16,10 +16,12 @@ Protocol:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -208,6 +210,15 @@ class NodeSpec(BaseModel):
|
||||
max_retries: int = Field(default=3)
|
||||
retry_on: list[str] = Field(default_factory=list, description="Error types to retry on")
|
||||
|
||||
# Visit limits (for feedback/callback edges)
|
||||
max_node_visits: int = Field(
|
||||
default=1,
|
||||
description=(
|
||||
"Max times this node executes in one graph run. "
|
||||
"Set >1 for feedback loops. 0 = unlimited (max_steps guards)."
|
||||
),
|
||||
)
|
||||
|
||||
# Pydantic model for output validation
|
||||
output_model: type[BaseModel] | None = Field(
|
||||
default=None,
|
||||
@@ -1523,6 +1534,8 @@ Do NOT fabricate data or return empty objects."""
|
||||
|
||||
def _build_system_prompt(self, ctx: NodeContext) -> str:
|
||||
"""Build the system prompt."""
|
||||
from datetime import datetime
|
||||
|
||||
parts = []
|
||||
|
||||
if ctx.node_spec.system_prompt:
|
||||
@@ -1545,6 +1558,15 @@ Do NOT fabricate data or return empty objects."""
|
||||
|
||||
parts.append(prompt)
|
||||
|
||||
# Inject current datetime so LLM knows "now"
|
||||
utc_dt = datetime.now(UTC)
|
||||
local_dt = datetime.now().astimezone()
|
||||
local_tz_name = local_dt.tzname() or "Unknown"
|
||||
parts.append("\n## Runtime Context")
|
||||
parts.append(f"- Current Date/Time (UTC): {utc_dt.isoformat()}")
|
||||
parts.append(f"- Local Timezone: {local_tz_name}")
|
||||
parts.append(f"- Current Date/Time (Local): {local_dt.isoformat()}")
|
||||
|
||||
if ctx.goal_context:
|
||||
parts.append("\n# Goal Context")
|
||||
parts.append(ctx.goal_context)
|
||||
@@ -1746,8 +1768,19 @@ class FunctionNode(NodeProtocol):
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Call the function
|
||||
result = self.func(**ctx.input_data)
|
||||
# Filter input_data to only declared input_keys to prevent
|
||||
# leaking extra memory keys from upstream nodes.
|
||||
if ctx.node_spec.input_keys:
|
||||
filtered = {
|
||||
k: v for k, v in ctx.input_data.items() if k in ctx.node_spec.input_keys
|
||||
}
|
||||
else:
|
||||
filtered = ctx.input_data
|
||||
|
||||
# Call the function (supports both sync and async)
|
||||
result = self.func(**filtered)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
|
||||
latency_ms = int((time.time() - start) * 1000)
|
||||
|
||||
|
||||
@@ -11,13 +11,13 @@ The Plan is the contract between the external planner and the executor:
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ActionType(str, Enum):
|
||||
class ActionType(StrEnum):
|
||||
"""Types of actions a PlanStep can perform."""
|
||||
|
||||
LLM_CALL = "llm_call" # Call LLM for generation
|
||||
@@ -27,7 +27,7 @@ class ActionType(str, Enum):
|
||||
CODE_EXECUTION = "code_execution" # Execute dynamic code (sandboxed)
|
||||
|
||||
|
||||
class StepStatus(str, Enum):
|
||||
class StepStatus(StrEnum):
|
||||
"""Status of a plan step."""
|
||||
|
||||
PENDING = "pending"
|
||||
@@ -56,7 +56,7 @@ class StepStatus(str, Enum):
|
||||
return self == StepStatus.COMPLETED
|
||||
|
||||
|
||||
class ApprovalDecision(str, Enum):
|
||||
class ApprovalDecision(StrEnum):
|
||||
"""Human decision on a step requiring approval."""
|
||||
|
||||
APPROVE = "approve" # Execute as planned
|
||||
@@ -91,7 +91,7 @@ class ApprovalResult(BaseModel):
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class JudgmentAction(str, Enum):
|
||||
class JudgmentAction(StrEnum):
|
||||
"""Actions the judge can take after evaluating a step."""
|
||||
|
||||
ACCEPT = "accept" # Step completed successfully, continue
|
||||
@@ -423,7 +423,7 @@ class Plan(BaseModel):
|
||||
}
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
|
||||
class ExecutionStatus(StrEnum):
|
||||
"""Status of plan execution."""
|
||||
|
||||
COMPLETED = "completed"
|
||||
|
||||
@@ -75,16 +75,6 @@ class SafeEvalVisitor(ast.NodeVisitor):
|
||||
def visit_Constant(self, node: ast.Constant) -> Any:
|
||||
return node.value
|
||||
|
||||
# --- Number/String/Bytes/NameConstant (Python < 3.8 compat if needed) ---
|
||||
def visit_Num(self, node: ast.Num) -> Any:
|
||||
return node.n
|
||||
|
||||
def visit_Str(self, node: ast.Str) -> Any:
|
||||
return node.s
|
||||
|
||||
def visit_NameConstant(self, node: ast.NameConstant) -> Any:
|
||||
return node.value
|
||||
|
||||
# --- Data Structures ---
|
||||
def visit_List(self, node: ast.List) -> list:
|
||||
return [self.visit(elt) for elt in node.elts]
|
||||
|
||||
@@ -126,14 +126,16 @@ class OutputValidator:
|
||||
|
||||
for key in expected_keys:
|
||||
if key not in output:
|
||||
errors.append(f"Missing required output key: '{key}'")
|
||||
if key not in nullable_keys:
|
||||
errors.append(f"Missing required output key: '{key}'")
|
||||
elif not allow_empty:
|
||||
value = output[key]
|
||||
if value is None:
|
||||
if key not in nullable_keys:
|
||||
errors.append(f"Output key '{key}' is None")
|
||||
elif isinstance(value, str) and len(value.strip()) == 0:
|
||||
errors.append(f"Output key '{key}' is empty string")
|
||||
if key not in nullable_keys:
|
||||
errors.append(f"Output key '{key}' is empty string")
|
||||
|
||||
return ValidationResult(success=len(errors) == 0, errors=errors)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ def _get_api_key_from_credential_store() -> str | None:
|
||||
try:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
creds = CredentialStoreAdapter.with_env_storage()
|
||||
creds = CredentialStoreAdapter.default()
|
||||
if creds.is_available("anthropic"):
|
||||
return creds.get("anthropic")
|
||||
except ImportError:
|
||||
|
||||
@@ -163,11 +163,24 @@ class LiteLLMProvider(LLMProvider):
|
||||
content = response.choices[0].message.content if response.choices else None
|
||||
has_tool_calls = bool(response.choices and response.choices[0].message.tool_calls)
|
||||
if not content and not has_tool_calls:
|
||||
# If the conversation ends with an assistant message,
|
||||
# an empty response is expected — don't retry.
|
||||
messages = kwargs.get("messages", [])
|
||||
last_role = next(
|
||||
(m["role"] for m in reversed(messages) if m.get("role") != "system"),
|
||||
None,
|
||||
)
|
||||
if last_role == "assistant":
|
||||
logger.debug(
|
||||
"[retry] Empty response after assistant message — "
|
||||
"expected, not retrying."
|
||||
)
|
||||
return response
|
||||
|
||||
finish_reason = (
|
||||
response.choices[0].finish_reason if response.choices else "unknown"
|
||||
)
|
||||
# Dump full request to file for debugging
|
||||
messages = kwargs.get("messages", [])
|
||||
token_count, token_method = _estimate_tokens(model, messages)
|
||||
dump_path = _dump_failed_request(
|
||||
model=model,
|
||||
@@ -380,11 +393,18 @@ class LiteLLMProvider(LLMProvider):
|
||||
|
||||
# Execute tools and add results.
|
||||
for tool_call in message.tool_calls:
|
||||
# Parse arguments
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
# Surface error to LLM and skip tool execution
|
||||
current_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "Invalid JSON arguments provided to tool.",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
tool_use = ToolUse(
|
||||
id=tool_call.id,
|
||||
@@ -474,7 +494,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
|
||||
for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
|
||||
buffered_events: list[StreamEvent] = []
|
||||
# Post-stream events (ToolCall, TextEnd, Finish) are buffered
|
||||
# because they depend on the full stream. TextDeltaEvents are
|
||||
# yielded immediately so callers see tokens in real time.
|
||||
tail_events: list[StreamEvent] = []
|
||||
accumulated_text = ""
|
||||
tool_calls_acc: dict[int, dict[str, str]] = {}
|
||||
input_tokens = 0
|
||||
@@ -490,14 +513,12 @@ class LiteLLMProvider(LLMProvider):
|
||||
|
||||
delta = choice.delta
|
||||
|
||||
# --- Text content ---
|
||||
# --- Text content — yield immediately for real-time streaming ---
|
||||
if delta and delta.content:
|
||||
accumulated_text += delta.content
|
||||
buffered_events.append(
|
||||
TextDeltaEvent(
|
||||
content=delta.content,
|
||||
snapshot=accumulated_text,
|
||||
)
|
||||
yield TextDeltaEvent(
|
||||
content=delta.content,
|
||||
snapshot=accumulated_text,
|
||||
)
|
||||
|
||||
# --- Tool calls (accumulate across chunks) ---
|
||||
@@ -521,7 +542,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
parsed_args = json.loads(tc_data["arguments"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
parsed_args = {"_raw": tc_data.get("arguments", "")}
|
||||
buffered_events.append(
|
||||
tail_events.append(
|
||||
ToolCallEvent(
|
||||
tool_use_id=tc_data["id"],
|
||||
tool_name=tc_data["name"],
|
||||
@@ -530,14 +551,14 @@ class LiteLLMProvider(LLMProvider):
|
||||
)
|
||||
|
||||
if accumulated_text:
|
||||
buffered_events.append(TextEndEvent(full_text=accumulated_text))
|
||||
tail_events.append(TextEndEvent(full_text=accumulated_text))
|
||||
|
||||
usage = getattr(chunk, "usage", None)
|
||||
if usage:
|
||||
input_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
output_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
|
||||
buffered_events.append(
|
||||
tail_events.append(
|
||||
FinishEvent(
|
||||
stop_reason=choice.finish_reason,
|
||||
input_tokens=input_tokens,
|
||||
@@ -547,8 +568,25 @@ class LiteLLMProvider(LLMProvider):
|
||||
)
|
||||
|
||||
# Check whether the stream produced any real content.
|
||||
# (If text deltas were yielded above, has_content is True
|
||||
# and we skip the retry path — nothing was yielded in vain.)
|
||||
has_content = accumulated_text or tool_calls_acc
|
||||
if not has_content and attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
# If the conversation ends with an assistant message,
|
||||
# an empty stream is expected (nothing new to say).
|
||||
# Don't retry — just flush whatever we have.
|
||||
last_role = next(
|
||||
(m["role"] for m in reversed(full_messages) if m.get("role") != "system"),
|
||||
None,
|
||||
)
|
||||
if last_role == "assistant":
|
||||
logger.debug(
|
||||
"[stream] Empty response after assistant message — "
|
||||
"expected, not retrying."
|
||||
)
|
||||
for event in tail_events:
|
||||
yield event
|
||||
return
|
||||
wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
|
||||
token_count, token_method = _estimate_tokens(
|
||||
self.model,
|
||||
@@ -570,8 +608,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
|
||||
# Success (or final attempt) — flush buffered events.
|
||||
for event in buffered_events:
|
||||
# Success (or final attempt) — flush remaining events.
|
||||
for event in tail_events:
|
||||
yield event
|
||||
return
|
||||
|
||||
|
||||
@@ -516,6 +516,36 @@ def _validate_tool_credentials(tools_list: list[str]) -> dict | None:
|
||||
return None
|
||||
|
||||
|
||||
def _validate_agent_path(agent_path: str) -> tuple[Path | None, str | None]:
|
||||
"""
|
||||
Validate and normalize agent_path.
|
||||
|
||||
Returns:
|
||||
(Path, None) if valid
|
||||
(None, error_json) if invalid
|
||||
"""
|
||||
if not agent_path:
|
||||
return None, json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "agent_path is required (e.g., 'exports/my_agent')",
|
||||
}
|
||||
)
|
||||
|
||||
path = Path(agent_path)
|
||||
|
||||
if not path.exists():
|
||||
return None, json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Agent path not found: {path}",
|
||||
"hint": "Run export_graph to create an agent in exports/ first",
|
||||
}
|
||||
)
|
||||
|
||||
return path, None
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def add_node(
|
||||
node_id: Annotated[str, "Unique identifier for the node"],
|
||||
@@ -2597,10 +2627,11 @@ def generate_constraint_tests(
|
||||
if not agent_path and _session:
|
||||
agent_path = f"exports/{_session.name}"
|
||||
|
||||
if not agent_path:
|
||||
return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
|
||||
path, err = _validate_agent_path(agent_path)
|
||||
if err:
|
||||
return err
|
||||
|
||||
agent_module = _get_agent_module_from_path(agent_path)
|
||||
agent_module = _get_agent_module_from_path(path)
|
||||
|
||||
# Format constraints for display
|
||||
constraints_formatted = (
|
||||
@@ -2619,9 +2650,9 @@ def generate_constraint_tests(
|
||||
return json.dumps(
|
||||
{
|
||||
"goal_id": goal_id,
|
||||
"agent_path": agent_path,
|
||||
"agent_path": str(path),
|
||||
"agent_module": agent_module,
|
||||
"output_file": f"{agent_path}/tests/test_constraints.py",
|
||||
"output_file": f"{str(path)}/tests/test_constraints.py",
|
||||
"constraints": [c.model_dump() for c in goal.constraints] if goal.constraints else [],
|
||||
"constraints_formatted": constraints_formatted,
|
||||
"test_guidelines": {
|
||||
@@ -2677,10 +2708,11 @@ def generate_success_tests(
|
||||
if not agent_path and _session:
|
||||
agent_path = f"exports/{_session.name}"
|
||||
|
||||
if not agent_path:
|
||||
return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
|
||||
path, err = _validate_agent_path(agent_path)
|
||||
if err:
|
||||
return err
|
||||
|
||||
agent_module = _get_agent_module_from_path(agent_path)
|
||||
agent_module = _get_agent_module_from_path(path)
|
||||
|
||||
# Parse node/tool names for context
|
||||
nodes = [n.strip() for n in node_names.split(",") if n.strip()]
|
||||
@@ -2705,9 +2737,9 @@ def generate_success_tests(
|
||||
return json.dumps(
|
||||
{
|
||||
"goal_id": goal_id,
|
||||
"agent_path": agent_path,
|
||||
"agent_path": str(path),
|
||||
"agent_module": agent_module,
|
||||
"output_file": f"{agent_path}/tests/test_success_criteria.py",
|
||||
"output_file": f"{str(path)}/tests/test_success_criteria.py",
|
||||
"success_criteria": [c.model_dump() for c in goal.success_criteria]
|
||||
if goal.success_criteria
|
||||
else [],
|
||||
@@ -2766,7 +2798,11 @@ def run_tests(
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
tests_dir = Path(agent_path) / "tests"
|
||||
path, err = _validate_agent_path(agent_path)
|
||||
if err:
|
||||
return err
|
||||
|
||||
tests_dir = path / "tests"
|
||||
|
||||
if not tests_dir.exists():
|
||||
return json.dumps(
|
||||
@@ -2957,10 +2993,11 @@ def debug_test(
|
||||
if not agent_path and _session:
|
||||
agent_path = f"exports/{_session.name}"
|
||||
|
||||
if not agent_path:
|
||||
return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
|
||||
path, err = _validate_agent_path(agent_path)
|
||||
if err:
|
||||
return err
|
||||
|
||||
tests_dir = Path(agent_path) / "tests"
|
||||
tests_dir = path / "tests"
|
||||
|
||||
if not tests_dir.exists():
|
||||
return json.dumps(
|
||||
@@ -3101,10 +3138,11 @@ def list_tests(
|
||||
if not agent_path and _session:
|
||||
agent_path = f"exports/{_session.name}"
|
||||
|
||||
if not agent_path:
|
||||
return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
|
||||
path, err = _validate_agent_path(agent_path)
|
||||
if err:
|
||||
return err
|
||||
|
||||
tests_dir = Path(agent_path) / "tests"
|
||||
tests_dir = path / "tests"
|
||||
|
||||
if not tests_dir.exists():
|
||||
return json.dumps(
|
||||
@@ -3379,7 +3417,7 @@ def store_credential(
|
||||
display_name: Annotated[str, "Human-readable name (e.g., 'HubSpot Access Token')"] = "",
|
||||
) -> str:
|
||||
"""
|
||||
Store a credential securely in the encrypted credential store at ~/.hive/credentials.
|
||||
Store a credential securely in the local encrypted store at ~/.hive/credentials.
|
||||
|
||||
Uses Fernet encryption (AES-128-CBC + HMAC). Requires HIVE_CREDENTIAL_KEY env var.
|
||||
"""
|
||||
@@ -3421,7 +3459,7 @@ def store_credential(
|
||||
@mcp.tool()
|
||||
def list_stored_credentials() -> str:
|
||||
"""
|
||||
List all credentials currently stored in the encrypted credential store.
|
||||
List all credentials currently stored in the local encrypted store.
|
||||
|
||||
Returns credential IDs and metadata (never returns secret values).
|
||||
"""
|
||||
@@ -3461,7 +3499,7 @@ def delete_stored_credential(
|
||||
credential_name: Annotated[str, "Logical credential name to delete (e.g., 'hubspot')"],
|
||||
) -> str:
|
||||
"""
|
||||
Delete a credential from the encrypted credential store.
|
||||
Delete a credential from the local encrypted store.
|
||||
"""
|
||||
try:
|
||||
store = _get_credential_store()
|
||||
|
||||
@@ -411,25 +411,8 @@ class AgentRunner:
|
||||
return self._tool_registry.register_mcp_server(server_config)
|
||||
|
||||
def _load_mcp_servers_from_config(self, config_path: Path) -> None:
|
||||
"""
|
||||
Load and register MCP servers from a configuration file.
|
||||
|
||||
Args:
|
||||
config_path: Path to mcp_servers.json file
|
||||
"""
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
servers = config.get("servers", [])
|
||||
for server_config in servers:
|
||||
try:
|
||||
self._tool_registry.register_mcp_server(server_config)
|
||||
except Exception as e:
|
||||
server_name = server_config.get("name", "unknown")
|
||||
logger.warning(f"Failed to register MCP server '{server_name}': {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load MCP servers config from {config_path}: {e}")
|
||||
"""Load and register MCP servers from a configuration file."""
|
||||
self._tool_registry.load_mcp_config(config_path)
|
||||
|
||||
def set_approval_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
|
||||
@@ -257,6 +257,34 @@ class ToolRegistry:
|
||||
"""
|
||||
self._session_context.update(context)
|
||||
|
||||
def load_mcp_config(self, config_path: Path) -> None:
    """
    Load and register MCP servers from a config file.

    Resolves relative ``cwd`` paths against the config file's parent
    directory so callers never need to handle path resolution themselves.

    Loading is best-effort: an unreadable or malformed file, a config whose
    top level is not a JSON object, or a malformed server entry is logged
    as a warning and skipped rather than raised, so one bad entry does not
    prevent the remaining servers from being registered.

    Args:
        config_path: Path to an ``mcp_servers.json`` file.
    """
    try:
        # JSON is UTF-8 by specification; don't depend on the locale default.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
    except Exception as e:
        logger.warning(f"Failed to load MCP config from {config_path}: {e}")
        return

    if not isinstance(config, dict):
        logger.warning(
            f"Failed to load MCP config from {config_path}: "
            f"expected a JSON object, got {type(config).__name__}"
        )
        return

    servers = config.get("servers", [])
    if not isinstance(servers, list):
        logger.warning(
            f"Failed to load MCP config from {config_path}: "
            f"'servers' must be a list, got {type(servers).__name__}"
        )
        return

    # Path() also accepts a plain string path here.
    base_dir = Path(config_path).parent
    for server_config in servers:
        if not isinstance(server_config, dict):
            logger.warning(
                f"Skipping invalid MCP server entry in {config_path}: {server_config!r}"
            )
            continue
        try:
            # Inside the try so a malformed ``cwd`` only skips this entry.
            cwd = server_config.get("cwd")
            if cwd and not Path(cwd).is_absolute():
                server_config["cwd"] = str((base_dir / cwd).resolve())
            self.register_mcp_server(server_config)
        except Exception as e:
            name = server_config.get("name", "unknown")
            logger.warning(f"Failed to register MCP server '{name}': {e}")
|
||||
|
||||
def register_mcp_server(
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
@@ -309,11 +337,21 @@ class ToolRegistry:
|
||||
tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
|
||||
|
||||
# Create executor that calls the MCP server
|
||||
def make_mcp_executor(client_ref: MCPClient, tool_name: str, registry_ref):
|
||||
def make_mcp_executor(
|
||||
client_ref: MCPClient,
|
||||
tool_name: str,
|
||||
registry_ref,
|
||||
tool_params: set[str],
|
||||
):
|
||||
def executor(inputs: dict) -> Any:
|
||||
try:
|
||||
# Inject session context for tools that need it
|
||||
merged_inputs = {**registry_ref._session_context, **inputs}
|
||||
# Only inject session context params the tool accepts
|
||||
filtered_context = {
|
||||
k: v
|
||||
for k, v in registry_ref._session_context.items()
|
||||
if k in tool_params
|
||||
}
|
||||
merged_inputs = {**filtered_context, **inputs}
|
||||
result = client_ref.call_tool(tool_name, merged_inputs)
|
||||
# MCP tools return content array, extract the result
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
@@ -327,10 +365,11 @@ class ToolRegistry:
|
||||
|
||||
return executor
|
||||
|
||||
tool_params = set(mcp_tool.input_schema.get("properties", {}).keys())
|
||||
self.register(
|
||||
mcp_tool.name,
|
||||
tool,
|
||||
make_mcp_executor(client, mcp_tool.name, self),
|
||||
make_mcp_executor(client, mcp_tool.name, self, tool_params),
|
||||
)
|
||||
count += 1
|
||||
|
||||
|
||||
@@ -12,13 +12,13 @@ import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventType(str, Enum):
|
||||
class EventType(StrEnum):
|
||||
"""Types of events that can be published."""
|
||||
|
||||
# Execution lifecycle
|
||||
|
||||
@@ -11,13 +11,13 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IsolationLevel(str, Enum):
|
||||
class IsolationLevel(StrEnum):
|
||||
"""State isolation level for concurrent executions."""
|
||||
|
||||
ISOLATED = "isolated" # Private state per execution
|
||||
@@ -25,7 +25,7 @@ class IsolationLevel(str, Enum):
|
||||
SYNCHRONIZED = "synchronized" # Shared with write locks (strong consistency)
|
||||
|
||||
|
||||
class StateScope(str, Enum):
|
||||
class StateScope(StrEnum):
|
||||
"""Scope for state operations."""
|
||||
|
||||
EXECUTION = "execution" # Local to a single execution
|
||||
|
||||
@@ -10,13 +10,13 @@ This is MORE important than actions because:
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
|
||||
class DecisionType(str, Enum):
|
||||
class DecisionType(StrEnum):
|
||||
"""Types of decisions an agent can make."""
|
||||
|
||||
TOOL_SELECTION = "tool_selection" # Which tool to use
|
||||
|
||||
@@ -6,7 +6,7 @@ summaries and metrics that Builder needs to understand what happened.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
@@ -14,7 +14,7 @@ from pydantic import BaseModel, Field, computed_field
|
||||
from framework.schemas.decision import Decision, Outcome
|
||||
|
||||
|
||||
class RunStatus(str, Enum):
|
||||
class RunStatus(StrEnum):
|
||||
"""Status of a run."""
|
||||
|
||||
RUNNING = "running"
|
||||
|
||||
@@ -167,14 +167,18 @@ class ConcurrentStorage:
|
||||
run: Run to save
|
||||
immediate: If True, save immediately (bypasses batching)
|
||||
"""
|
||||
# Invalidate summary cache since the run data is changing
|
||||
# This ensures load_summary() fetches fresh data after the save
|
||||
self._cache.pop(f"summary:{run.id}", None)
|
||||
|
||||
if immediate or not self._running:
|
||||
await self._save_run_locked(run)
|
||||
# Update cache only after successful immediate write
|
||||
self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())
|
||||
else:
|
||||
# For batched writes, cache will be updated in _flush_batch after successful write
|
||||
await self._write_queue.put(("run", run))
|
||||
|
||||
# Update cache
|
||||
self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())
|
||||
|
||||
async def _save_run_locked(self, run: Run) -> None:
|
||||
"""Save a run with file locking, including index locks."""
|
||||
lock_key = f"run:{run.id}"
|
||||
@@ -363,8 +367,12 @@ class ConcurrentStorage:
|
||||
try:
|
||||
if item_type == "run":
|
||||
await self._save_run_locked(item)
|
||||
# Update cache only after successful batched write
|
||||
# This fixes the race condition where cache was updated before write completed
|
||||
self._cache[f"run:{item.id}"] = CacheEntry(item, time.time())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save {item_type}: {e}")
|
||||
# Cache is NOT updated on failure - prevents stale/inconsistent cache state
|
||||
|
||||
async def _flush_pending(self) -> None:
|
||||
"""Flush all pending writes."""
|
||||
|
||||
@@ -6,13 +6,13 @@ programmatic/MCP-based approval.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ApprovalAction(str, Enum):
|
||||
class ApprovalAction(StrEnum):
|
||||
"""Actions a user can take on a generated test."""
|
||||
|
||||
APPROVE = "approve" # Accept as-is
|
||||
|
||||
@@ -24,7 +24,7 @@ def _get_api_key():
|
||||
# 1. Try CredentialStoreAdapter for Anthropic
|
||||
try:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
creds = CredentialStoreAdapter.with_env_storage()
|
||||
creds = CredentialStoreAdapter.default()
|
||||
if creds.is_available("anthropic"):
|
||||
return creds.get("anthropic")
|
||||
except (ImportError, KeyError):
|
||||
@@ -57,7 +57,7 @@ def _get_api_key():
|
||||
"""Get API key from CredentialStoreAdapter or environment."""
|
||||
try:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
creds = CredentialStoreAdapter.with_env_storage()
|
||||
creds = CredentialStoreAdapter.default()
|
||||
if creds.is_available("anthropic"):
|
||||
return creds.get("anthropic")
|
||||
except (ImportError, KeyError):
|
||||
|
||||
@@ -6,13 +6,13 @@ but require mandatory user approval before being stored.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ApprovalStatus(str, Enum):
|
||||
class ApprovalStatus(StrEnum):
|
||||
"""Status of user approval for a generated test."""
|
||||
|
||||
PENDING = "pending" # Awaiting user review
|
||||
@@ -21,7 +21,7 @@ class ApprovalStatus(str, Enum):
|
||||
REJECTED = "rejected" # User declined (with reason)
|
||||
|
||||
|
||||
class TestType(str, Enum):
|
||||
class TestType(StrEnum):
|
||||
"""Type of test based on what it validates."""
|
||||
|
||||
__test__ = False # Not a pytest test class
|
||||
|
||||
@@ -6,13 +6,13 @@ categorization for guiding iteration strategy.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ErrorCategory(str, Enum):
|
||||
class ErrorCategory(StrEnum):
|
||||
"""
|
||||
Category of test failure for guiding iteration.
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ dependencies = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.23",
|
||||
"pytest-xdist>=3.0",
|
||||
"tools",
|
||||
]
|
||||
|
||||
# [project.optional-dependencies]
|
||||
@@ -21,6 +22,9 @@ dependencies = [
|
||||
[project.scripts]
|
||||
hive = "framework.cli:main"
|
||||
|
||||
[tool.uv.sources]
|
||||
tools = { workspace = true }
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
@@ -43,6 +47,7 @@ lint.select = [
|
||||
"W", # pycodestyle warnings
|
||||
]
|
||||
|
||||
lint.per-file-ignores."demos/*" = ["E501"]
|
||||
lint.isort.combine-as-imports = true
|
||||
lint.isort.known-first-party = ["framework"]
|
||||
lint.isort.section-order = [
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
# Development dependencies
|
||||
-r requirements.txt
|
||||
|
||||
# Testing
|
||||
pytest>=8.0
|
||||
pytest-asyncio>=0.23
|
||||
|
||||
# Linting & type checking
|
||||
ruff>=0.1.0
|
||||
mypy>=1.0
|
||||
@@ -1,14 +0,0 @@
|
||||
# Core dependencies
|
||||
pydantic>=2.0
|
||||
anthropic>=0.40.0
|
||||
httpx>=0.27.0
|
||||
litellm>=1.81.0
|
||||
|
||||
# MCP server dependencies
|
||||
mcp
|
||||
fastmcp
|
||||
|
||||
# Testing (required for test framework)
|
||||
pytest>=8.0
|
||||
pytest-asyncio>=0.23
|
||||
pytest-xdist>=3.0
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Tests for ConcurrentStorage race condition and cache invalidation fixes."""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.schemas.run import Run, RunMetrics, RunStatus
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
|
||||
|
||||
def create_test_run(
    run_id: str, goal_id: str = "test-goal", status: RunStatus = RunStatus.RUNNING
) -> Run:
    """Build the smallest Run the storage tests need.

    Args:
        run_id: Identifier assigned to the run.
        goal_id: Goal the run belongs to.
        status: Initial status of the run.

    Returns:
        A Run with empty metrics, decisions and problems.
    """
    empty_metrics = RunMetrics(nodes_executed=[])
    run_fields = {
        "id": run_id,
        "goal_id": goal_id,
        "status": status,
        "narrative": "Test run",
        "metrics": empty_metrics,
        "decisions": [],
        "problems": [],
    }
    return Run(**run_fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cache_invalidation_on_save(tmp_path: Path):
    """Saving a run must drop its cached summary.

    Regression test for the bug where load_summary() kept returning the
    summary cached before the run was updated.
    """
    store = ConcurrentStorage(tmp_path)
    await store.start()

    try:
        rid = "test-run-1"

        # Persist the run in its initial state.
        tracked = create_test_run(rid, status=RunStatus.RUNNING)
        await store.save_run(tracked, immediate=True)

        # First load populates the summary cache.
        before = await store.load_summary(rid)
        assert before is not None
        assert before.status == RunStatus.RUNNING

        # Change the status and persist again.
        tracked.status = RunStatus.COMPLETED
        await store.save_run(tracked, immediate=True)

        # The second load has to reflect the update, not the cached entry.
        after = await store.load_summary(rid)
        assert after is not None
        assert after.status == RunStatus.COMPLETED, (
            "Summary cache should be invalidated on save - got stale data"
        )
    finally:
        await store.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_batched_write_cache_consistency(tmp_path: Path):
    """The run cache is filled only once a batched write has been flushed.

    Regression test for the race where the cache entry was written before
    the queued (batched) save had actually completed.
    """
    store = ConcurrentStorage(tmp_path, batch_interval=0.05)
    await store.start()

    try:
        rid = "test-run-2"
        key = f"run:{rid}"

        # Queue the save instead of writing immediately.
        queued = create_test_run(rid, status=RunStatus.RUNNING)
        await store.save_run(queued, immediate=False)

        # Nothing may be cached until the batch is flushed.
        assert key not in store._cache, (
            "Cache should not be updated before batch is flushed"
        )

        # Poll for the flush (up to 500 * 0.01s = 5s) rather than using one
        # fixed sleep, which keeps the test reliable on slow CI machines.
        polls_left = 500
        while polls_left > 0 and key not in store._cache:
            await asyncio.sleep(0.01)
            polls_left -= 1

        # Once flushed, the entry has to be present.
        assert key in store._cache, "Cache should be updated after batch flush"

        # What landed on disk must match what was queued.
        from_disk = await store.load_run(rid, use_cache=False)
        assert from_disk is not None
        assert from_disk.id == rid
        assert from_disk.status == RunStatus.RUNNING
    finally:
        await store.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_immediate_write_updates_cache(tmp_path: Path):
    """An immediate save populates the run cache right away."""
    store = ConcurrentStorage(tmp_path)
    await store.start()

    try:
        rid = "test-run-3"
        key = f"run:{rid}"

        # Write synchronously, bypassing the batch queue.
        finished = create_test_run(rid, status=RunStatus.COMPLETED)
        await store.save_run(finished, immediate=True)

        # The entry must already be cached when save_run returns.
        assert key in store._cache, "Cache should be updated after immediate write"

        # The cached object has to be the run that was just saved.
        entry = store._cache[key]
        cached = entry.value
        assert cached.id == rid
        assert cached.status == RunStatus.COMPLETED
    finally:
        await store.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_summary_cache_invalidated_on_multiple_saves(tmp_path: Path):
    """Test that summary cache is invalidated on each save, not just the first.

    Every save below changes the run's status, and the summary is loaded
    (and therefore cached again) after each one. A save that failed to
    invalidate the cache would surface as a stale status on the next load.
    """
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    try:
        run_id = "test-run-4"

        # First save
        run = create_test_run(run_id, status=RunStatus.RUNNING)
        await storage.save_run(run, immediate=True)

        # Load summary to cache it
        summary = await storage.load_summary(run_id)
        assert summary.status == RunStatus.RUNNING

        # Alternate the status so each save is observable through the
        # summary; the sequence still ends on COMPLETED.
        for new_status in (RunStatus.COMPLETED, RunStatus.RUNNING, RunStatus.COMPLETED):
            run.status = new_status
            await storage.save_run(run, immediate=True)

            # Load summary - should be fresh, and this load re-caches it
            summary = await storage.load_summary(run_id)
            assert summary.status == new_status
    finally:
        await storage.stop()
|
||||
@@ -8,6 +8,7 @@ Set HIVE_TEST_LLM_MODEL=<model> to override the real model.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from dataclasses import dataclass
|
||||
@@ -950,7 +951,15 @@ async def test_client_facing_node_streams_output():
|
||||
event_bus=bus,
|
||||
config=LoopConfig(max_iterations=5),
|
||||
)
|
||||
|
||||
# client_facing + text-only blocks for user input; use shutdown to unblock
|
||||
async def auto_shutdown():
|
||||
await asyncio.sleep(0.05)
|
||||
node.signal_shutdown()
|
||||
|
||||
task = asyncio.create_task(auto_shutdown())
|
||||
result = await node.execute(ctx)
|
||||
await task
|
||||
|
||||
assert result.success
|
||||
|
||||
|
||||
@@ -446,12 +446,172 @@ class TestEventBusLifecycle:
|
||||
|
||||
ctx = build_ctx(runtime, spec, memory, llm)
|
||||
node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
|
||||
|
||||
# client_facing + text-only blocks for user input; use shutdown to unblock
|
||||
async def auto_shutdown():
|
||||
await asyncio.sleep(0.05)
|
||||
node.signal_shutdown()
|
||||
|
||||
task = asyncio.create_task(auto_shutdown())
|
||||
await node.execute(ctx)
|
||||
await task
|
||||
|
||||
assert EventType.CLIENT_OUTPUT_DELTA in received_types
|
||||
assert EventType.LLM_TEXT_DELTA not in received_types
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Client-facing blocking
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestClientFacingBlocking:
    """Exercise EventLoopNode's built-in input blocking for client_facing nodes."""

    @pytest.fixture
    def client_spec(self):
        """Provide a client-facing event_loop NodeSpec that declares no output keys."""
        return NodeSpec(
            client_facing=True,
            output_keys=[],
            node_type="event_loop",
            description="chat node",
            name="Chat",
            id="chat",
        )

    @pytest.mark.asyncio
    async def test_client_facing_blocks_on_text(self, runtime, memory, client_spec):
        """A text-only reply on a client_facing node waits for inject_event."""
        replies = [
            text_scenario("Hello!"),
            text_scenario("Got your message."),
        ]
        llm = MockStreamingLLM(scenarios=replies)
        event_bus = EventBus()
        node = EventLoopNode(event_bus=event_bus, config=LoopConfig(max_iterations=5))
        ctx = build_ctx(runtime, client_spec, memory, llm)

        async def simulate_user():
            # Give the node time to start, answer it, then end the session.
            await asyncio.sleep(0.05)
            await node.inject_event("I need help")
            await asyncio.sleep(0.05)
            node.signal_shutdown()

        background = asyncio.create_task(simulate_user())
        outcome = await node.execute(ctx)
        await background

        assert outcome.success is True
        # One LLM call for the greeting plus at least one after the injected input.
        assert llm._call_index >= 2

    @pytest.mark.asyncio
    async def test_client_facing_does_not_block_on_tools(self, runtime, memory):
        """A tool-call turn on a client_facing node must not wait for user input."""
        chat_spec = NodeSpec(
            client_facing=True,
            output_keys=["result"],
            node_type="event_loop",
            description="chat node",
            name="Chat",
            id="chat",
        )
        # Turn 1: the LLM calls set_output. A tool call is present, so the node
        #         does not block and the implicit judge asks for another turn.
        # Turn 2: the LLM answers with text only. On a client_facing node the
        #         blocking check runs before the judge, so this turn would wait
        #         for input; the delayed shutdown below releases it.
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("set_output", {"key": "result", "value": "done"}),
                text_scenario("All set!"),
            ]
        )
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        ctx = build_ctx(runtime, chat_spec, memory, llm)

        async def release_after_delay():
            await asyncio.sleep(0.1)
            node.signal_shutdown()

        background = asyncio.create_task(release_after_delay())
        outcome = await node.execute(ctx)
        await background

        assert outcome.success is True
        assert outcome.output["result"] == "done"

    @pytest.mark.asyncio
    async def test_non_client_facing_unchanged(self, runtime, memory):
        """Nodes without client_facing keep the existing non-blocking behaviour."""
        internal_spec = NodeSpec(
            output_keys=[],
            node_type="event_loop",
            description="internal node",
            name="Internal",
            id="internal",
        )
        llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
        node = EventLoopNode(config=LoopConfig(max_iterations=2))
        ctx = build_ctx(runtime, internal_spec, memory, llm)

        # No tools and no output keys: the implicit judge accepts, so execute()
        # is expected to return without any external unblocking.
        outcome = await node.execute(ctx)
        assert outcome is not None

    @pytest.mark.asyncio
    async def test_signal_shutdown_unblocks(self, runtime, memory, client_spec):
        """signal_shutdown releases a client_facing node that is waiting for input."""
        llm = MockStreamingLLM(scenarios=[text_scenario("Waiting...")])
        event_bus = EventBus()
        node = EventLoopNode(event_bus=event_bus, config=LoopConfig(max_iterations=10))
        ctx = build_ctx(runtime, client_spec, memory, llm)

        async def release_after_delay():
            await asyncio.sleep(0.05)
            node.signal_shutdown()

        background = asyncio.create_task(release_after_delay())
        outcome = await node.execute(ctx)
        await background

        assert outcome.success is True

    @pytest.mark.asyncio
    async def test_client_input_requested_event_published(self, runtime, memory, client_spec):
        """A CLIENT_INPUT_REQUESTED event is emitted when the node blocks."""
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello!")])
        event_bus = EventBus()
        captured = []

        async def record(event):
            captured.append(event)

        event_bus.subscribe(
            event_types=[EventType.CLIENT_INPUT_REQUESTED],
            handler=record,
        )

        node = EventLoopNode(event_bus=event_bus, config=LoopConfig(max_iterations=5))
        ctx = build_ctx(runtime, client_spec, memory, llm)

        async def release_after_delay():
            await asyncio.sleep(0.05)
            node.signal_shutdown()

        background = asyncio.create_task(release_after_delay())
        await node.execute(ctx)
        await background

        assert len(captured) >= 1
        assert captured[0].type == EventType.CLIENT_INPUT_REQUESTED
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tool execution
|
||||
# ===========================================================================
|
||||
|
||||
@@ -0,0 +1,577 @@
|
||||
"""
|
||||
Tests for feedback/callback edges and max_node_visits in GraphExecutor.
|
||||
|
||||
Covers:
|
||||
- NodeSpec.max_node_visits default value
|
||||
- Visit limit enforcement (skip on exceed)
|
||||
- Multiple visits allowed when max_node_visits > 1
|
||||
- Unlimited visits with max_node_visits=0
|
||||
- Conditional feedback edges (backward traversal)
|
||||
- Conditional edge NOT firing (forward path taken)
|
||||
- node_visit_counts populated in ExecutionResult
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock node implementations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SuccessNode(NodeProtocol):
    """Mock node that always succeeds and returns a configurable output dict.

    ``execute_count`` records how many times ``execute`` has been awaited so
    tests can assert on the number of real executions.
    """

    def __init__(self, output: dict | None = None):
        """Store the output to return.

        Args:
            output: Dict returned as ``NodeResult.output`` on every execution.
                When omitted (``None``) the default ``{"result": "ok"}`` is used.
        """
        # Compare against None explicitly: ``output or default`` would silently
        # replace a deliberately empty dict with the default payload.
        self._output = {"result": "ok"} if output is None else output
        self.execute_count = 0

    async def execute(self, ctx: NodeContext) -> NodeResult:
        """Count the call and return a successful result with the stored output."""
        self.execute_count += 1
        return NodeResult(success=True, output=self._output, tokens_used=10, latency_ms=5)
|
||||
|
||||
|
||||
class StatefulNode(NodeProtocol):
    """Mock node whose output advances through a list on each execution.

    Once the list is exhausted the final entry is returned for every further
    execution.
    """

    def __init__(self, outputs: list[dict]):
        """Keep the sequence of outputs and start the execution counter at zero."""
        self.execute_count = 0
        self._outputs = outputs

    async def execute(self, ctx: NodeContext) -> NodeResult:
        """Return the output for the current execution, then bump the counter."""
        final_index = len(self._outputs) - 1
        # Clamp to the last entry so extra executions repeat the final output.
        if self.execute_count < final_index:
            position = self.execute_count
        else:
            position = final_index
        payload = self._outputs[position]
        self.execute_count += 1
        return NodeResult(success=True, output=payload, tokens_used=10, latency_ms=5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def runtime():
    """Return a MagicMock specced on Runtime with its recording methods stubbed."""
    from framework.runtime.core import Runtime

    mock_runtime = MagicMock(spec=Runtime)
    stubbed_methods = {
        "start_run": MagicMock(return_value="run_id"),
        "decide": MagicMock(return_value="decision_id"),
        "record_outcome": MagicMock(),
        "end_run": MagicMock(),
        "report_problem": MagicMock(),
        "set_node": MagicMock(),
    }
    for method_name, stub in stubbed_methods.items():
        setattr(mock_runtime, method_name, stub)
    return mock_runtime
|
||||
|
||||
|
||||
@pytest.fixture
def goal():
    """Return the Goal shared by the feedback-edge tests in this module."""
    shared_goal = Goal(description="Feedback edge tests", name="Test", id="g1")
    return shared_goal
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. NodeSpec default
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_max_node_visits_default():
    """A NodeSpec built without max_node_visits reports a limit of 1."""
    default_spec = NodeSpec(
        id="n",
        name="N",
        description="test",
        node_type="function",
        output_keys=["out"],
    )
    assert default_spec.max_node_visits == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Visit limit skips node
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_visit_limit_skips_node(runtime, goal):
    """In an A -> B -> A cycle with A limited to one visit, A runs only once.

    The graph has no terminal node, so ``max_steps`` is what ends the run:
    after A's limit is reached every return to A is skipped and control goes
    back to B until the step budget is used up.
    """
    limited_entry = NodeSpec(
        id="a",
        name="A",
        description="entry with visit limit",
        node_type="function",
        output_keys=["a_out"],
        max_node_visits=1,
    )
    unlimited_middle = NodeSpec(
        id="b",
        name="B",
        description="middle node",
        node_type="function",
        output_keys=["b_out"],
        max_node_visits=0,  # 0 = unlimited; max_steps is the guard
    )

    cycle_edges = [
        EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
        EdgeSpec(id="b_to_a", source="b", target="a", condition=EdgeCondition.ON_SUCCESS),
    ]
    graph = GraphSpec(
        id="cycle_graph",
        goal_id="g1",
        name="Cycle Graph",
        entry_node="a",
        nodes=[limited_entry, unlimited_middle],
        edges=cycle_edges,
        terminal_nodes=[],  # nothing terminal; the step budget stops the loop
        max_steps=10,
    )

    impls = {
        "a": SuccessNode({"a_out": "from_a"}),
        "b": SuccessNode({"b_out": "from_b"}),
    }
    executor = GraphExecutor(runtime=runtime)
    for node_id, impl in impls.items():
        executor.register_node(node_id, impl)

    outcome = await executor.execute(graph, goal, {})

    # A really executes a single time; later arrivals are skipped.
    assert impls["a"].execute_count == 1
    # Skipped arrivals are not appended to the path.
    assert outcome.path.count("a") == 1
    # The visit counter still counts the skipped arrivals.
    assert outcome.node_visit_counts["a"] >= 2
    # B has no limit, so it runs repeatedly.
    assert impls["b"].execute_count >= 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Visit limit allows multiple
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_visit_limit_allows_multiple(runtime, goal):
    """In an A -> B -> A cycle with A limited to two visits, A runs exactly twice."""
    twice_entry = NodeSpec(
        id="a",
        name="A",
        description="entry allows two visits",
        node_type="function",
        output_keys=["a_out"],
        max_node_visits=2,
    )
    unlimited_middle = NodeSpec(
        id="b",
        name="B",
        description="middle node",
        node_type="function",
        output_keys=["b_out"],
        max_node_visits=0,  # 0 = unlimited
    )

    cycle_edges = [
        EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
        EdgeSpec(id="b_to_a", source="b", target="a", condition=EdgeCondition.ON_SUCCESS),
    ]
    graph = GraphSpec(
        id="cycle_graph",
        goal_id="g1",
        name="Cycle Graph",
        entry_node="a",
        nodes=[twice_entry, unlimited_middle],
        edges=cycle_edges,
        terminal_nodes=[],
        max_steps=10,
    )

    impls = {
        "a": SuccessNode({"a_out": "from_a"}),
        "b": SuccessNode({"b_out": "from_b"}),
    }
    executor = GraphExecutor(runtime=runtime)
    for node_id, impl in impls.items():
        executor.register_node(node_id, impl)

    outcome = await executor.execute(graph, goal, {})

    # Two real executions of A, and two entries for it in the path.
    assert impls["a"].execute_count == 2
    assert outcome.path.count("a") == 2
    # Skipped arrivals are counted as visits as well.
    assert outcome.node_visit_counts["a"] >= 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Visit limit zero = unlimited
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_visit_limit_zero_unlimited(runtime, goal):
    """With max_node_visits=0 on both nodes, only max_steps bounds the cycle."""
    entry = NodeSpec(
        id="a",
        name="A",
        description="unlimited visits",
        node_type="function",
        output_keys=["a_out"],
        max_node_visits=0,
    )
    middle = NodeSpec(
        id="b",
        name="B",
        description="middle node",
        node_type="function",
        output_keys=["b_out"],
        max_node_visits=0,
    )

    cycle_edges = [
        EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
        EdgeSpec(id="b_to_a", source="b", target="a", condition=EdgeCondition.ON_SUCCESS),
    ]
    graph = GraphSpec(
        id="cycle_graph",
        goal_id="g1",
        name="Cycle Graph",
        entry_node="a",
        nodes=[entry, middle],
        edges=cycle_edges,
        terminal_nodes=[],
        max_steps=6,  # expected order: A, B, A, B, A, B
    )

    impls = {
        "a": SuccessNode({"a_out": "from_a"}),
        "b": SuccessNode({"b_out": "from_b"}),
    }
    executor = GraphExecutor(runtime=runtime)
    for node_id, impl in impls.items():
        executor.register_node(node_id, impl)

    outcome = await executor.execute(graph, goal, {})

    # Six steps alternating between the two nodes: three executions each.
    assert impls["a"].execute_count == 3
    assert impls["b"].execute_count == 3
    assert outcome.steps_executed == 6
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Conditional feedback edge fires
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_conditional_feedback_edge(runtime, goal):
    """The writer -> director feedback edge fires when needs_revision is True.

    The edge expressions read ``output`` (the result of the node that just
    ran) rather than ``memory``: per the original test notes, the writer's
    result has not been written to memory yet when its outgoing edges are
    evaluated.
    """
    director_spec = NodeSpec(
        id="director",
        name="Director",
        description="plans work",
        node_type="function",
        output_keys=["plan"],
        max_node_visits=2,
    )
    writer_spec = NodeSpec(
        id="writer",
        name="Writer",
        description="writes draft",
        node_type="function",
        output_keys=["draft", "needs_revision"],
        max_node_visits=2,
    )
    output_spec = NodeSpec(
        id="output",
        name="Output",
        description="final output",
        node_type="function",
        output_keys=["final"],
    )

    plan_to_draft = EdgeSpec(
        id="director_to_writer",
        source="director",
        target="writer",
        condition=EdgeCondition.ON_SUCCESS,
    )
    # Forward path: taken when the draft does not need another revision.
    draft_accepted = EdgeSpec(
        id="writer_to_output",
        source="writer",
        target="output",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="output.get('needs_revision') != True",
        priority=0,
    )
    # Feedback path: back to the director when a revision is requested.
    draft_rejected = EdgeSpec(
        id="writer_feedback",
        source="writer",
        target="director",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="output.get('needs_revision') == True",
        priority=-1,
    )
    graph = GraphSpec(
        id="feedback_graph",
        goal_id="g1",
        name="Feedback Graph",
        entry_node="director",
        nodes=[director_spec, writer_spec, output_spec],
        edges=[plan_to_draft, draft_accepted, draft_rejected],
        terminal_nodes=["output"],
        max_steps=10,
    )

    director_impl = SuccessNode({"plan": "research AI"})
    # First draft asks for a revision, the second one does not.
    writer_impl = StatefulNode(
        [
            {"draft": "draft_v1", "needs_revision": True},
            {"draft": "draft_v2", "needs_revision": False},
        ]
    )
    output_impl = SuccessNode({"final": "done"})

    executor = GraphExecutor(runtime=runtime)
    registrations = (
        ("director", director_impl),
        ("writer", writer_impl),
        ("output", output_impl),
    )
    for node_id, impl in registrations:
        executor.register_node(node_id, impl)

    outcome = await executor.execute(graph, goal, {})

    assert outcome.success
    # Director: initial run plus one run triggered by the feedback edge.
    assert director_impl.execute_count == 2
    # Writer: rejected first draft plus accepted second draft.
    assert writer_impl.execute_count == 2
    # Output runs once at the end.
    assert output_impl.execute_count == 1
    assert outcome.path == ["director", "writer", "director", "writer", "output"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Conditional feedback edge does NOT fire
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_conditional_feedback_false(runtime, goal):
    """The writer -> director feedback edge stays idle when needs_revision is False."""
    director_spec = NodeSpec(
        id="director",
        name="Director",
        description="plans work",
        node_type="function",
        output_keys=["plan"],
        max_node_visits=2,
    )
    writer_spec = NodeSpec(
        id="writer",
        name="Writer",
        description="writes draft",
        node_type="function",
        output_keys=["draft", "needs_revision"],
    )
    output_spec = NodeSpec(
        id="output",
        name="Output",
        description="final output",
        node_type="function",
        output_keys=["final"],
    )

    plan_to_draft = EdgeSpec(
        id="director_to_writer",
        source="director",
        target="writer",
        condition=EdgeCondition.ON_SUCCESS,
    )
    draft_accepted = EdgeSpec(
        id="writer_to_output",
        source="writer",
        target="output",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="output.get('needs_revision') != True",
        priority=0,
    )
    draft_rejected = EdgeSpec(
        id="writer_feedback",
        source="writer",
        target="director",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="output.get('needs_revision') == True",
        priority=-1,
    )
    graph = GraphSpec(
        id="feedback_graph",
        goal_id="g1",
        name="Feedback Graph",
        entry_node="director",
        nodes=[director_spec, writer_spec, output_spec],
        edges=[plan_to_draft, draft_accepted, draft_rejected],
        terminal_nodes=["output"],
        max_steps=10,
    )

    director_impl = SuccessNode({"plan": "research AI"})
    # The writer's draft never requests a revision.
    writer_impl = SuccessNode({"draft": "perfect_draft", "needs_revision": False})
    output_impl = SuccessNode({"final": "done"})

    executor = GraphExecutor(runtime=runtime)
    registrations = (
        ("director", director_impl),
        ("writer", writer_impl),
        ("output", output_impl),
    )
    for node_id, impl in registrations:
        executor.register_node(node_id, impl)

    outcome = await executor.execute(graph, goal, {})

    assert outcome.success
    # No feedback loop: every node runs exactly once, in a straight line.
    assert director_impl.execute_count == 1
    assert writer_impl.execute_count == 1
    assert output_impl.execute_count == 1
    assert outcome.path == ["director", "writer", "output"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Visit counts in ExecutionResult
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_visit_counts_in_result(runtime, goal):
    """A linear A -> B run reports one visit per node in node_visit_counts."""
    first = NodeSpec(
        id="a",
        name="A",
        description="entry",
        node_type="function",
        output_keys=["a_out"],
    )
    last = NodeSpec(
        id="b",
        name="B",
        description="terminal",
        node_type="function",
        input_keys=["a_out"],
        output_keys=["b_out"],
    )

    only_edge = EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)
    graph = GraphSpec(
        id="linear_graph",
        goal_id="g1",
        name="Linear Graph",
        entry_node="a",
        nodes=[first, last],
        edges=[only_edge],
        terminal_nodes=["b"],
    )

    executor = GraphExecutor(runtime=runtime)
    for node_id, payload in (("a", {"a_out": "x"}), ("b", {"b_out": "y"})):
        executor.register_node(node_id, SuccessNode(payload))

    outcome = await executor.execute(graph, goal, {})

    assert outcome.success
    assert outcome.node_visit_counts == {"a": 1, "b": 1}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Conditional priority prevents fan-out
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_conditional_priority_prevents_fanout(runtime, goal):
    """Only the highest-priority CONDITIONAL edge fires when several match.

    The writer's output satisfies both the forward and the feedback
    expression. The forward edge has the higher priority and must be the one
    followed; the executor must not turn the double match into a fan-out.
    """
    writer_spec = NodeSpec(
        id="writer",
        name="Writer",
        description="produces output",
        node_type="function",
        output_keys=["draft", "needs_revision"],
    )
    output_spec = NodeSpec(
        id="output",
        name="Output",
        description="forward target",
        node_type="function",
        output_keys=["final"],
    )
    director_spec = NodeSpec(
        id="director",
        name="Director",
        description="feedback target",
        node_type="function",
        output_keys=["plan"],
        max_node_visits=2,
    )

    # Forward edge: higher priority (1).
    forward_edge = EdgeSpec(
        id="writer_to_output",
        source="writer",
        target="output",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="output.get('draft') is not None",
        priority=1,
    )
    # Feedback edge: lower priority (-1).
    feedback_edge = EdgeSpec(
        id="writer_to_director",
        source="writer",
        target="director",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="output.get('needs_revision') == True",
        priority=-1,
    )
    graph = GraphSpec(
        id="priority_graph",
        goal_id="g1",
        name="Priority Graph",
        entry_node="writer",
        nodes=[writer_spec, output_spec, director_spec],
        edges=[forward_edge, feedback_edge],
        terminal_nodes=["output"],
        max_steps=10,
    )

    # Both keys are set, so both edge expressions evaluate to true.
    writer_impl = SuccessNode({"draft": "my draft", "needs_revision": True})
    output_impl = SuccessNode({"final": "done"})
    director_impl = SuccessNode({"plan": "plan"})

    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    registrations = (
        ("writer", writer_impl),
        ("output", output_impl),
        ("director", director_impl),
    )
    for node_id, impl in registrations:
        executor.register_node(node_id, impl)

    outcome = await executor.execute(graph, goal, {})

    assert outcome.success
    # The priority-1 forward edge wins: output runs, director never does.
    assert output_impl.execute_count == 1
    assert director_impl.execute_count == 0
    assert outcome.path == ["writer", "output"]
|
||||
@@ -209,6 +209,62 @@ class TestLiteLLMProviderToolUse:
|
||||
assert result.output_tokens == 25 # 15 + 10
|
||||
assert mock_completion.call_count == 2
|
||||
|
||||
@patch("litellm.completion")
def test_complete_with_tools_invalid_json_arguments_are_handled(self, mock_completion):
    """A tool call whose arguments are not valid JSON must never reach the executor."""
    # First completion: a single tool call carrying malformed JSON arguments.
    malformed_call = MagicMock()
    malformed_call.id = "call_123"
    malformed_call.function.name = "test_tool"
    malformed_call.function.arguments = "{invalid json"

    tool_choice = MagicMock()
    tool_choice.message.content = None
    tool_choice.message.tool_calls = [malformed_call]
    tool_choice.finish_reason = "tool_calls"

    tool_call_response = MagicMock()
    tool_call_response.choices = [tool_choice]
    tool_call_response.model = "gpt-4o-mini"
    tool_call_response.usage.prompt_tokens = 10
    tool_call_response.usage.completion_tokens = 5

    # Second completion: the model carries on with plain text after the tool error.
    final_choice = MagicMock()
    final_choice.message.content = "Handled error"
    final_choice.message.tool_calls = None
    final_choice.finish_reason = "stop"

    final_response = MagicMock()
    final_response.choices = [final_choice]
    final_response.model = "gpt-4o-mini"
    final_response.usage.prompt_tokens = 5
    final_response.usage.completion_tokens = 5

    mock_completion.side_effect = [tool_call_response, final_response]

    provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")

    available_tools = [
        Tool(
            name="test_tool",
            description="Test tool",
            parameters={"properties": {}, "required": []},
        )
    ]

    # Flag flipped by the executor if the provider ever invokes it.
    executor_state = {"value": False}

    def tool_executor(tool_use: ToolUse) -> ToolResult:
        executor_state["value"] = True
        return ToolResult(
            tool_use_id=tool_use.id, content="should not be called", is_error=False
        )

    outcome = provider.complete_with_tools(
        messages=[{"role": "user", "content": "Run tool"}],
        system="You are a test assistant.",
        tools=available_tools,
        tool_executor=tool_executor,
    )

    assert executor_state["value"] is False
    assert outcome.content == "Handled error"
|
||||
|
||||
|
||||
class TestToolConversion:
|
||||
"""Test tool format conversion."""
|
||||
|
||||
@@ -362,7 +362,6 @@ class AgentRequest(BaseModel):
|
||||
raise ValueError('max_tokens too high')
|
||||
return v
|
||||
```
|
||||
|
||||
### Output Sanitization
|
||||
> **Note:** The following snippet is illustrative and shows a simplified example
|
||||
> of output sanitization logic. Actual implementations may differ.
|
||||
|
||||
@@ -1757,7 +1757,7 @@ tools/src/aden_tools/credentials/
|
||||
|
||||
### Manual Testing
|
||||
|
||||
- [ ] Create encrypted credential store
|
||||
- [ ] Create local encrypted store
|
||||
- [ ] Save and load multi-key credentials
|
||||
- [ ] Verify template resolution in tool headers
|
||||
- [ ] Test OAuth2 token refresh
|
||||
|
||||
+1
-3
@@ -14,7 +14,6 @@
|
||||
|
||||
[](https://github.com/adenhq/hive/blob/main/LICENSE)
|
||||
[](https://www.ycombinator.com/companies/aden)
|
||||
[](https://hub.docker.com/u/adenhq)
|
||||
[](https://discord.com/invite/MXE49hrKDk)
|
||||
[](https://x.com/aden_hq)
|
||||
[](https://www.linkedin.com/company/teamaden/)
|
||||
@@ -42,7 +41,7 @@ Visita [adenhq.com](https://adenhq.com) para documentación completa, ejemplos y
|
||||
## ¿Qué es Aden?
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden es una plataforma para construir, desplegar, operar y adaptar agentes de IA:
|
||||
@@ -66,7 +65,6 @@ Aden es una plataforma para construir, desplegar, operar y adaptar agentes de IA
|
||||
### Prerrequisitos
|
||||
|
||||
- [Python 3.11+](https://www.python.org/downloads/) - Para desarrollo de agentes
|
||||
- [Docker](https://docs.docker.com/get-docker/) (v20.10+) - Opcional, para herramientas en contenedores
|
||||
|
||||
### Instalación
|
||||
|
||||
|
||||
+1
-1
@@ -44,7 +44,7 @@
|
||||
# Aden क्या है?
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden एक ऐसा प्लेटफ़ॉर्म है जो AI एजेंट्स को बनाने, डिप्लॉय करने, ऑपरेट करने और अनुकूलित करने के लिए उपयोग होता है:
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@
|
||||
## Adenとは
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Adenは、AIエージェントの構築、デプロイ、運用、適応のためのプラットフォームです:
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@
|
||||
## Aden이란 무엇인가
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden은 AI 에이전트를 구축, 배포, 운영, 적응시키기 위한 플랫폼입니다:
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@ Visite [adenhq.com](https://adenhq.com) para documentação completa, exemplos e
|
||||
## O que é Aden
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden é uma plataforma para construir, implantar, operar e adaptar agentes de IA:
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@
|
||||
## Что такое Aden
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden — это платформа для создания, развёртывания, эксплуатации и адаптации ИИ-агентов:
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@
|
||||
## 什么是 Aden
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden 是一个用于构建、部署、运营和适应 AI 智能体的平台:
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
# Examples
|
||||
|
||||
This directory contains two types of examples to help you build agents with the Hive framework.
|
||||
|
||||
## Recipes vs Templates
|
||||
|
||||
### [recipes/](recipes/) — "How to make it"
|
||||
|
||||
A recipe is a **prompt-only** description of an agent. It tells you the goal, the nodes, the prompts, the edge routing logic, and what tools to wire in — but it's not runnable code. You read the recipe, then build the agent yourself.
|
||||
|
||||
Use recipes when you want to:
|
||||
- Understand a pattern before committing to an implementation
|
||||
- Adapt an idea to your own codebase or tooling
|
||||
- Learn how to think about agent design (goals, nodes, edges, prompts)
|
||||
|
||||
### [templates/](templates/) — "Ready to eat"
|
||||
|
||||
A template is a **working agent scaffold** that follows the standard Hive export structure. Copy the folder, rename it, swap in your own prompts and tools, and run it.
|
||||
|
||||
Use templates when you want to:
|
||||
- Get a new agent running quickly
|
||||
- Start from a known-good structure instead of from scratch
|
||||
- See how all the pieces (goal, nodes, edges, config, CLI) fit together in real code
|
||||
|
||||
## How to use a template
|
||||
|
||||
```bash
|
||||
# 1. Copy the template
|
||||
cp -r examples/templates/marketing_agent exports/my_agent
|
||||
|
||||
# 2. Edit the goal, nodes, and edges in agent.py and nodes/__init__.py
|
||||
|
||||
# 3. Run it
|
||||
PYTHONPATH=core python -m exports.my_agent
|
||||
```
|
||||
|
||||
## How to use a recipe
|
||||
|
||||
1. Read the recipe markdown file
|
||||
2. Use the patterns described to build your own agent — either manually or with the builder agent (`/agent-workflow`)
|
||||
3. Refer to the [core README](../core/README.md) for framework API details
|
||||
@@ -0,0 +1,27 @@
|
||||
# Recipes
|
||||
|
||||
A recipe describes an agent's design — the goal, nodes, prompts, edge logic, and tools — without providing runnable code. Think of it as a blueprint: it tells you *how* to build the agent, but you do the building.
|
||||
|
||||
## What's in a recipe
|
||||
|
||||
Each recipe is a markdown file (or folder with a markdown file) containing:
|
||||
|
||||
- **Goal**: What the agent accomplishes, including success criteria and constraints
|
||||
- **Nodes**: Each step in the workflow, with the system prompt, node type, and input/output keys
|
||||
- **Edges**: How nodes connect, including conditions and routing logic
|
||||
- **Tools**: What external tools or MCP servers the agent needs
|
||||
- **Usage notes**: Tips, gotchas, and suggested variations
|
||||
|
||||
## How to use a recipe
|
||||
|
||||
1. Read through the recipe to understand the design
|
||||
2. Create a new agent using the standard export structure (see [templates/](../templates/) for a scaffold)
|
||||
3. Translate the recipe's goal, nodes, and edges into code
|
||||
4. Wire in the tools described
|
||||
5. Test and iterate
|
||||
|
||||
## Available recipes
|
||||
|
||||
| Recipe | Description |
|
||||
|--------|-------------|
|
||||
| [marketing_agent](marketing_agent/) | Multi-channel marketing content generator with audience analysis and A/B copy variants |
|
||||
@@ -0,0 +1,156 @@
|
||||
# Recipe: Marketing Content Agent
|
||||
|
||||
A multi-channel marketing content generator. Given a product description and target audience, this agent analyzes the audience, generates tailored copy for multiple channels, and produces A/B variants.
|
||||
|
||||
## Goal
|
||||
|
||||
```
|
||||
Name: Marketing Content Generator
|
||||
Description: Generate targeted marketing content across multiple channels
|
||||
for a given product and audience.
|
||||
|
||||
Success criteria:
|
||||
- Audience analysis is produced with demographics and pain points
|
||||
- At least 2 channel-specific content pieces are generated
|
||||
- A/B variants are provided for each piece
|
||||
- All content aligns with the specified brand voice
|
||||
|
||||
Constraints:
|
||||
- (hard) No competitor brand names in generated content
|
||||
- (soft) Content should be under 280 characters for social media channels
|
||||
```
|
||||
|
||||
## Input / Output
|
||||
|
||||
**Input:**
|
||||
- `product_description` (str) — What the product is and does
|
||||
- `target_audience` (str) — Who the content is for
|
||||
- `brand_voice` (str) — Tone and style guidelines (e.g., "professional but approachable")
|
||||
- `channels` (list[str]) — Target channels, e.g. `["email", "twitter", "linkedin"]`
|
||||
|
||||
**Output:**
|
||||
- `audience_analysis` (dict) — Demographics, pain points, motivations
|
||||
- `content` (list[dict]) — Per-channel content with A/B variants
|
||||
|
||||
## Workflow
|
||||
|
||||
```
|
||||
[analyze_audience] → [generate_content] → [review_and_refine]
|
||||
|
|
||||
(conditional)
|
||||
|
|
||||
needs_revision == True → [generate_content]
|
||||
needs_revision == False → (done)
|
||||
```
|
||||
|
||||
## Nodes
|
||||
|
||||
### 1. analyze_audience
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| Type | `llm_generate` |
|
||||
| Input keys | `product_description`, `target_audience` |
|
||||
| Output keys | `audience_analysis` |
|
||||
| Tools | None |
|
||||
|
||||
**System prompt:**
|
||||
```
|
||||
You are a marketing strategist. Analyze the target audience for a product.
|
||||
|
||||
Product: {product_description}
|
||||
Target audience: {target_audience}
|
||||
|
||||
Produce a structured analysis in JSON:
|
||||
{{
|
||||
"audience_analysis": {{
|
||||
"demographics": "...",
|
||||
"pain_points": ["..."],
|
||||
"motivations": ["..."],
|
||||
"preferred_channels": ["..."],
|
||||
"messaging_angle": "..."
|
||||
}}
|
||||
}}
|
||||
```
|
||||
|
||||
### 2. generate_content
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| Type | `llm_generate` |
|
||||
| Input keys | `product_description`, `audience_analysis`, `brand_voice`, `channels` |
|
||||
| Output keys | `content` |
|
||||
| Tools | None |
|
||||
|
||||
**System prompt:**
|
||||
```
|
||||
You are a marketing copywriter. Generate content for each channel.
|
||||
|
||||
Product: {product_description}
|
||||
Audience analysis: {audience_analysis}
|
||||
Brand voice: {brand_voice}
|
||||
Channels: {channels}
|
||||
|
||||
For each channel, produce two variants (A and B).
|
||||
|
||||
Output as JSON:
|
||||
{{
|
||||
"content": [
|
||||
{{
|
||||
"channel": "twitter",
|
||||
"variant_a": "...",
|
||||
"variant_b": "..."
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
### 3. review_and_refine
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| Type | `llm_generate` |
|
||||
| Input keys | `content`, `brand_voice` |
|
||||
| Output keys | `content`, `needs_revision` |
|
||||
| Tools | None |
|
||||
|
||||
**System prompt:**
|
||||
```
|
||||
You are a senior marketing editor. Review the following content for brand
|
||||
voice alignment, clarity, and channel appropriateness.
|
||||
|
||||
Content: {content}
|
||||
Brand voice: {brand_voice}
|
||||
|
||||
If any piece needs revision, fix it and set needs_revision to true.
|
||||
If everything looks good, return the content unchanged with needs_revision false.
|
||||
|
||||
Output as JSON:
|
||||
{{
|
||||
"content": [...],
|
||||
"needs_revision": false
|
||||
}}
|
||||
```
|
||||
|
||||
## Edges
|
||||
|
||||
| Source | Target | Condition | Priority |
|
||||
|--------|--------|-----------|----------|
|
||||
| analyze_audience | generate_content | `on_success` | 0 |
|
||||
| generate_content | review_and_refine | `on_success` | 0 |
|
||||
| review_and_refine | generate_content | `conditional: needs_revision == True` | 10 |
|
||||
|
||||
The `review_and_refine → generate_content` edge has a higher priority (10) than the other edges (0), so its condition is checked first. If `needs_revision` is false, that condition does not match and execution ends at `review_and_refine` (the terminal node).
|
||||
|
||||
## Tools
|
||||
|
||||
This recipe uses no external tools — all nodes are `llm_generate`. To extend it, consider adding:
|
||||
- A web search tool for competitive analysis in the `analyze_audience` node
|
||||
- A URL shortener tool for social media content
|
||||
- An image generation tool for visual content variants
|
||||
|
||||
## Variations
|
||||
|
||||
- **Single-channel mode**: Remove the `channels` input and hardcode to one channel for simpler output
|
||||
- **With approval gate**: Add a `human_input` node between `review_and_refine` and the terminal to require human sign-off
|
||||
- **With analytics**: Add a `function` node that logs generated content to a tracking system
|
||||
@@ -0,0 +1,38 @@
|
||||
# Templates
|
||||
|
||||
A template is a working agent scaffold that follows the standard Hive export structure. Copy it, rename it, customize the goal/nodes/edges, and run it.
|
||||
|
||||
## What's in a template
|
||||
|
||||
Each template is a complete agent package:
|
||||
|
||||
```
|
||||
template_name/
|
||||
├── __init__.py # Package exports
|
||||
├── __main__.py # CLI entry point
|
||||
├── agent.py # Goal, edges, graph spec, agent class
|
||||
├── config.py # Runtime configuration
|
||||
├── nodes/
|
||||
│ └── __init__.py # Node definitions (NodeSpec instances)
|
||||
└── README.md # What this template demonstrates
|
||||
```
|
||||
|
||||
## How to use a template
|
||||
|
||||
```bash
|
||||
# 1. Copy to your exports directory
|
||||
cp -r examples/templates/marketing_agent exports/my_marketing_agent
|
||||
|
||||
# 2. Update the module references in __main__.py and __init__.py
|
||||
|
||||
# 3. Customize goal, nodes, edges, and prompts
|
||||
|
||||
# 4. Run it
|
||||
PYTHONPATH=core python -m exports.my_marketing_agent --input '{"product_description": "..."}'
|
||||
```
|
||||
|
||||
## Available templates
|
||||
|
||||
| Template | Description |
|
||||
|----------|-------------|
|
||||
| [marketing_agent](marketing_agent/) | Multi-channel marketing content generator with audience analysis, content generation, and editorial review nodes |
|
||||
@@ -0,0 +1,57 @@
|
||||
# Template: Marketing Content Agent
|
||||
|
||||
A multi-channel marketing content generator. Given a product and audience, this agent analyzes the audience, generates tailored copy for multiple channels with A/B variants, and reviews the output for quality.
|
||||
|
||||
## Workflow
|
||||
|
||||
```
|
||||
[analyze-audience] → [generate-content] → [review-and-refine]
|
||||
|
|
||||
(conditional)
|
||||
|
|
||||
needs_revision == True → [generate-content]
|
||||
needs_revision == False → (done)
|
||||
```
|
||||
|
||||
## Nodes
|
||||
|
||||
| Node | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `analyze-audience` | `llm_generate` | Produces structured audience analysis |
|
||||
| `generate-content` | `llm_generate` | Creates per-channel copy with A/B variants |
|
||||
| `review-and-refine` | `llm_generate` | Reviews and optionally revises content |
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# From the repo root
|
||||
PYTHONPATH=core python -m examples.templates.marketing_agent
|
||||
|
||||
# With custom input
|
||||
PYTHONPATH=core python -m examples.templates.marketing_agent --input '{
|
||||
"product_description": "A fitness tracking app",
|
||||
"target_audience": "Health-conscious millennials",
|
||||
"brand_voice": "Energetic and motivational",
|
||||
"channels": ["instagram", "email"]
|
||||
}'
|
||||
```
|
||||
|
||||
## Customization ideas
|
||||
|
||||
- Add a `function` node to call an analytics API and inform audience analysis with real data
|
||||
- Add a `human_input` pause node before final output for editorial approval
|
||||
- Swap `llm_generate` nodes to `llm_tool_use` and add web search tools for competitive research
|
||||
- Add an image generation tool to produce visual assets alongside copy
|
||||
|
||||
## File structure
|
||||
|
||||
```
|
||||
marketing_agent/
|
||||
├── __init__.py # Package exports
|
||||
├── __main__.py # CLI entry point
|
||||
├── agent.py # Goal, edges, graph spec, MarketingAgent class
|
||||
├── config.py # RuntimeConfig and AgentMetadata
|
||||
├── nodes/
|
||||
│ └── __init__.py # NodeSpec definitions
|
||||
└── README.md # This file
|
||||
```
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Marketing Content Agent — template example."""
|
||||
|
||||
from .agent import MarketingAgent, goal, edges, nodes
|
||||
from .config import default_config
|
||||
|
||||
__all__ = ["MarketingAgent", "goal", "edges", "nodes", "default_config"]
|
||||
@@ -0,0 +1,31 @@
|
||||
"""CLI entry point for Marketing Content Agent."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
def main():
    """Run the Marketing Content Agent once and print its result as JSON.

    With no arguments the agent runs on a built-in sample input. When the
    first argument is ``--input``, the second argument is decoded as JSON and
    used as the agent input instead.

    Exits with an error message and a non-zero status when ``--input`` is
    given without a value, when the value is not valid JSON, or when it does
    not decode to a JSON object.
    """
    # Simple CLI — replace with Click for production use
    input_data = {
        "product_description": "An AI-powered project management tool for remote teams",
        "target_audience": "Engineering managers at mid-size tech companies",
        "brand_voice": "Professional but approachable, concise, data-driven",
        "channels": ["email", "twitter", "linkedin"],
    }

    # Accept JSON input from command line
    if len(sys.argv) > 1 and sys.argv[1] == "--input":
        if len(sys.argv) < 3:
            # Previously this surfaced as a bare IndexError traceback.
            sys.exit("error: --input requires a JSON argument")
        try:
            input_data = json.loads(sys.argv[2])
        except json.JSONDecodeError as exc:
            sys.exit(f"error: --input is not valid JSON: {exc}")
        if not isinstance(input_data, dict):
            # agent.run() is annotated to take a dict, so reject lists/scalars.
            sys.exit("error: --input must be a JSON object")

    # Imported only after the command line has been validated, so a malformed
    # invocation fails fast without loading the agent package.
    from .agent import MarketingAgent
    from .config import default_config

    agent = MarketingAgent(config=default_config)
    result = asyncio.run(agent.run(input_data))

    print(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,161 @@
|
||||
"""Marketing Content Agent — goal, edges, graph spec, and agent class."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from framework.graph import EdgeCondition, EdgeSpec, Goal, SuccessCriterion, Constraint
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
|
||||
from .config import default_config, RuntimeConfig
|
||||
from .nodes import all_nodes
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Goal
|
||||
# ---------------------------------------------------------------------------
|
||||
# Goal definition for the marketing agent: three success criteria, one hard
# and one soft constraint, plus the schemas of the expected input and output.
goal = Goal(
    id="marketing-content",
    name="Marketing Content Generator",
    description=(
        "Generate targeted marketing content "
        "across multiple channels for a given product and audience."
    ),
    success_criteria=[
        # The output must contain an "audience_analysis" entry.
        SuccessCriterion(
            id="audience-analyzed",
            metric="output_contains",
            target="audience_analysis",
            description="Audience analysis is produced with demographics and pain points",
        ),
        # Custom metric: at least two content pieces.
        SuccessCriterion(
            id="content-generated",
            metric="custom",
            target="len(content) >= 2",
            description="At least 2 channel-specific content pieces are generated",
        ),
        # Custom metric: every piece carries its A/B variants.
        SuccessCriterion(
            id="variants-provided",
            metric="custom",
            target="all variants present",
            description="A/B variants are provided for each content piece",
        ),
    ],
    constraints=[
        Constraint(
            id="no-competitor-names",
            constraint_type="hard",
            category="safety",
            description="No competitor brand names in generated content",
        ),
        Constraint(
            id="social-length",
            constraint_type="soft",
            category="quality",
            description="Social media content should be under 280 characters",
        ),
    ],
    input_schema={
        # The three free-text inputs share the same plain-string schema.
        **{
            text_field: {"type": "string"}
            for text_field in ("product_description", "target_audience", "brand_voice")
        },
        "channels": {"type": "array", "items": {"type": "string"}},
    },
    output_schema={
        "audience_analysis": {"type": "object"},
        "content": {"type": "array"},
    },
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edges
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edges of the workflow graph: a linear analyze -> generate -> review chain,
# plus one conditional edge that loops review back to generation.
edges = [
    EdgeSpec(
        id="analyze-to-generate",
        description="After audience analysis, generate content",
        source="analyze-audience",
        target="generate-content",
        condition=EdgeCondition.ON_SUCCESS,
    ),
    EdgeSpec(
        id="generate-to-review",
        description="After content generation, review and refine",
        source="generate-content",
        target="review-and-refine",
        condition=EdgeCondition.ON_SUCCESS,
    ),
    # The only edge with an explicit priority and a condition expression.
    EdgeSpec(
        id="review-to-regenerate",
        description="If revision needed, loop back to content generation",
        source="review-and-refine",
        target="generate-content",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="needs_revision == True",
        priority=10,
    ),
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graph structure
|
||||
# ---------------------------------------------------------------------------
|
||||
# Execution starts at the audience-analysis node; the same node id is also
# registered under the "start" entry point.
entry_node = "analyze-audience"
entry_points = dict(start=entry_node)

# review-and-refine is the only node listed as terminal; no pause nodes.
terminal_nodes = ["review-and-refine"]
pause_nodes = list()

# Re-export the node list from the nodes package under the name used here.
nodes = all_nodes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent class
|
||||
# ---------------------------------------------------------------------------
|
||||
class MarketingAgent:
    """Bundle the marketing goal, nodes and edges and run them as a graph.

    The instance keeps references to the module-level ``goal``, ``nodes``,
    ``edges``, ``entry_node`` and ``terminal_nodes``; ``executor`` stays
    ``None`` until ``_create_executor`` has been called.
    """

    def __init__(self, config: RuntimeConfig | None = None):
        """Store *config*, falling back to ``default_config`` when it is falsy."""
        self.config = default_config if not config else config
        self.goal, self.nodes, self.edges = goal, nodes, edges
        self.entry_node = entry_node
        self.terminal_nodes = terminal_nodes
        self.executor = None

    def _build_graph(self) -> GraphSpec:
        """Assemble the ``GraphSpec`` from instance state and module constants."""
        spec_kwargs = {
            "id": "marketing-content-graph",
            "goal_id": self.goal.id,
            "entry_node": self.entry_node,
            # entry_points and pause_nodes are read from module scope,
            # not from the instance.
            "entry_points": entry_points,
            "terminal_nodes": self.terminal_nodes,
            "pause_nodes": pause_nodes,
            "nodes": self.nodes,
            "edges": self.edges,
            "default_model": self.config.model,
            "max_tokens": self.config.max_tokens,
            "description": "Marketing content generation workflow",
        }
        return GraphSpec(**spec_kwargs)

    def _create_executor(self):
        """Create a new ``GraphExecutor``, remember it on ``self`` and return it."""
        # "~" in the configured storage path is expanded to the home directory.
        storage_dir = Path(self.config.storage_path).expanduser()
        self.executor = GraphExecutor(
            runtime=Runtime(storage_path=storage_dir),
            llm=AnthropicProvider(model=self.config.model),
        )
        return self.executor

    async def run(self, context: dict, mock_mode: bool = False) -> dict:
        """Execute the graph on *context* and return a summary dict.

        Args:
            context: Input data handed to the executor as ``input_data``.
            mock_mode: Accepted but not read anywhere in this method.

        Returns:
            A dict with the keys ``success``, ``output``, ``steps`` and
            ``path`` copied from the execution result.
        """
        graph_spec = self._build_graph()
        # A fresh executor is created for every run.
        runner = self._create_executor()
        outcome = await runner.execute(
            graph=graph_spec,
            goal=self.goal,
            input_data=context,
        )
        # Map result attributes onto the public summary keys.
        summary_fields = (
            ("success", "success"),
            ("output", "output"),
            ("steps", "steps_executed"),
            ("path", "path"),
        )
        return {key: getattr(outcome, attr) for key, attr in summary_fields}


# Module-level agent instance built with the default configuration.
default_agent = MarketingAgent()
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Runtime configuration for Marketing Content Agent."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
# Default tag set for AgentMetadata; copied into a new list per instance.
_DEFAULT_TAGS = ("marketing", "content", "template")


@dataclass
class RuntimeConfig:
    """Runtime settings for the agent: model, token limit, storage, mock flag."""

    # Model identifier
    model: str = "claude-haiku-4-5-20251001"
    # Token limit
    max_tokens: int = 2048
    # Storage location; may start with "~"
    storage_path: str = "~/.hive/storage"
    # Mock-mode flag, off by default
    mock_mode: bool = False


@dataclass
class AgentMetadata:
    """Descriptive metadata for the agent package."""

    name: str = "marketing_agent"
    version: str = "0.1.0"
    description: str = "Multi-channel marketing content generator"
    author: str = ""
    # default_factory gives every instance its own list object.
    tags: list[str] = field(default_factory=lambda: list(_DEFAULT_TAGS))


# Shared default instances, importable by the rest of the package.
default_config = RuntimeConfig()
metadata = AgentMetadata()
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Node definitions for Marketing Content Agent."""
|
||||
|
||||
from framework.graph import NodeSpec
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node 1: Analyze the target audience
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reads the product description and target audience and writes
# "audience_analysis". Uses no tools.
analyze_audience_node = NodeSpec(
    id="analyze-audience",
    name="Analyze Audience",
    node_type="llm_generate",
    description=(
        "Produce a structured audience analysis "
        "from the product and target audience description."
    ),
    input_keys=["product_description", "target_audience"],
    output_keys=["audience_analysis"],
    tools=[],
    max_retries=2,
    # NOTE(review): the doubled braces look like escapes for str.format-style
    # templating of the {placeholders} — confirm against the framework.
    system_prompt="""\
You are a marketing strategist. Analyze the target audience for a product.

Product: {product_description}
Target audience: {target_audience}

Produce a structured analysis as raw JSON (no markdown):
{{
"audience_analysis": {{
"demographics": "...",
"pain_points": ["..."],
"motivations": ["..."],
"preferred_channels": ["..."],
"messaging_angle": "..."
}}
}}
""",
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node 2: Generate channel-specific content with A/B variants
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reads the product description, audience analysis, brand voice and channel
# list and writes "content". Uses no tools.
generate_content_node = NodeSpec(
    id="generate-content",
    name="Generate Content",
    node_type="llm_generate",
    description=(
        "Create marketing copy for each requested channel "
        "with two variants per channel."
    ),
    input_keys=["product_description", "audience_analysis", "brand_voice", "channels"],
    output_keys=["content"],
    tools=[],
    max_retries=2,
    # NOTE(review): the doubled braces look like escapes for str.format-style
    # templating of the {placeholders} — confirm against the framework.
    system_prompt="""\
You are a marketing copywriter. Generate content for each channel.

Product: {product_description}
Audience analysis: {audience_analysis}
Brand voice: {brand_voice}
Channels: {channels}

For each channel, produce two variants (A and B).

Output as raw JSON (no markdown):
{{
"content": [
{{
"channel": "twitter",
"variant_a": "...",
"variant_b": "..."
}}
]
}}
""",
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node 3: Review and refine content
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reads "content" and "brand_voice" and writes back "content" together with
# the "needs_revision" flag. Uses no tools.
review_and_refine_node = NodeSpec(
    id="review-and-refine",
    name="Review and Refine",
    node_type="llm_generate",
    description=(
        "Review generated content for brand voice alignment and channel fit. "
        "Revise if needed."
    ),
    input_keys=["content", "brand_voice"],
    output_keys=["content", "needs_revision"],
    tools=[],
    max_retries=2,
    # NOTE(review): the doubled braces look like escapes for str.format-style
    # templating of the {placeholders} — confirm against the framework.
    system_prompt="""\
You are a senior marketing editor. Review the following content for brand
voice alignment, clarity, and channel appropriateness.

Content: {content}
Brand voice: {brand_voice}

If any piece needs revision, fix it and set needs_revision to true.
If everything looks good, return the content unchanged with needs_revision false.

Output as raw JSON (no markdown):
{{
"content": [...],
"needs_revision": false
}}
""",
)
|
||||
|
||||
# Every node of the workflow in definition order, collected in one list so
# importers can pull them in with a single name.
all_nodes = [analyze_audience_node, generate_content_node, review_and_refine_node]
|
||||
@@ -0,0 +1,2 @@
|
||||
[tool.uv.workspace]
|
||||
members = ["core", "tools"]
|
||||
+149
-153
@@ -11,6 +11,14 @@
|
||||
|
||||
set -e
|
||||
|
||||
# Detect Bash version for compatibility
|
||||
BASH_MAJOR_VERSION="${BASH_VERSINFO[0]}"
|
||||
USE_ASSOC_ARRAYS=false
|
||||
if [ "$BASH_MAJOR_VERSION" -ge 4 ]; then
|
||||
USE_ASSOC_ARRAYS=true
|
||||
fi
|
||||
echo "[debug] Bash version: ${BASH_VERSION}"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
@@ -52,7 +60,7 @@ prompt_choice() {
|
||||
echo -e "${BOLD}$prompt${NC}"
|
||||
for opt in "${options[@]}"; do
|
||||
echo -e " ${CYAN}$i)${NC} $opt"
|
||||
((i++))
|
||||
i=$((i + 1))
|
||||
done
|
||||
echo ""
|
||||
|
||||
@@ -60,7 +68,8 @@ prompt_choice() {
|
||||
while true; do
|
||||
read -r -p "Enter choice (1-${#options[@]}): " choice
|
||||
if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le "${#options[@]}" ]; then
|
||||
return $((choice - 1))
|
||||
PROMPT_CHOICE=$((choice - 1))
|
||||
return 0
|
||||
fi
|
||||
echo -e "${RED}Invalid choice. Please enter 1-${#options[@]}${NC}"
|
||||
done
|
||||
@@ -174,63 +183,26 @@ echo ""
|
||||
echo -e "${DIM}This may take a minute...${NC}"
|
||||
echo ""
|
||||
|
||||
# Upgrade pip, setuptools, and wheel
|
||||
echo -n " Upgrading pip... "
|
||||
$PYTHON_CMD -m pip install --upgrade pip setuptools wheel > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Install framework package from core/
|
||||
echo -n " Installing framework... "
|
||||
cd "$SCRIPT_DIR/core"
|
||||
# Install all workspace packages (core + tools) from workspace root
|
||||
echo -n " Installing workspace packages... "
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
if [ -f "pyproject.toml" ]; then
|
||||
uv sync > /dev/null 2>&1
|
||||
if [ $? -eq 0 ]; then
|
||||
echo -e "${GREEN} ✓ framework package installed${NC}"
|
||||
if uv sync > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ workspace packages installed${NC}"
|
||||
else
|
||||
echo -e "${YELLOW} ⚠ framework installation had issues (may be OK)${NC}"
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}failed (no pyproject.toml)${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install aden_tools package from tools/
|
||||
echo -n " Installing tools... "
|
||||
cd "$SCRIPT_DIR/tools"
|
||||
|
||||
if [ -f "pyproject.toml" ]; then
|
||||
uv sync > /dev/null 2>&1
|
||||
if [ $? -eq 0 ]; then
|
||||
echo -e "${GREEN} ✓ aden_tools package installed${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ aden_tools installation failed${NC}"
|
||||
echo -e "${RED} ✗ workspace installation failed${NC}"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}failed${NC}"
|
||||
echo -e "${RED}failed (no root pyproject.toml)${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install MCP dependencies
|
||||
echo -n " Installing MCP... "
|
||||
$PYTHON_CMD -m pip install mcp fastmcp > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Fix openai version compatibility
|
||||
echo -n " Checking openai... "
|
||||
$PYTHON_CMD -m pip install "openai>=1.0.0" > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Install click for CLI
|
||||
echo -n " Installing CLI tools... "
|
||||
$PYTHON_CMD -m pip install click > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Install Playwright browser
|
||||
echo -n " Installing Playwright browser... "
|
||||
if $PYTHON_CMD -c "import playwright" > /dev/null 2>&1; then
|
||||
if $PYTHON_CMD -m playwright install chromium > /dev/null 2>&1; then
|
||||
if uv run python -c "import playwright" > /dev/null 2>&1; then
|
||||
if uv run python -m playwright install chromium > /dev/null 2>&1; then
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
else
|
||||
echo -e "${YELLOW}⏭${NC}"
|
||||
@@ -249,33 +221,6 @@ echo ""
|
||||
# ============================================================
|
||||
|
||||
echo -e "${YELLOW}⬢${NC} ${BLUE}${BOLD}Step 3: Configuring LLM provider...${NC}"
|
||||
# Install MCP dependencies (in tools venv)
|
||||
echo " Installing MCP dependencies..."
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
uv pip install --python "$TOOLS_PYTHON" mcp fastmcp > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ MCP dependencies installed${NC}"
|
||||
|
||||
# Fix openai version compatibility (in tools venv)
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
OPENAI_VERSION=$($TOOLS_PYTHON -c "import openai; print(openai.__version__)" 2>/dev/null || echo "not_installed")
|
||||
if [ "$OPENAI_VERSION" = "not_installed" ]; then
|
||||
echo " Installing openai package..."
|
||||
uv pip install --python "$TOOLS_PYTHON" "openai>=1.0.0" > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ openai installed${NC}"
|
||||
elif [[ "$OPENAI_VERSION" =~ ^0\. ]]; then
|
||||
echo " Upgrading openai to 1.x+ for litellm compatibility..."
|
||||
uv pip install --python "$TOOLS_PYTHON" --upgrade "openai>=1.0.0" > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ openai upgraded${NC}"
|
||||
else
|
||||
echo -e "${GREEN} ✓ openai $OPENAI_VERSION is compatible${NC}"
|
||||
fi
|
||||
|
||||
# Install click for CLI (in tools venv)
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
uv pip install --python "$TOOLS_PYTHON" click > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ click installed${NC}"
|
||||
|
||||
cd "$SCRIPT_DIR"
|
||||
echo ""
|
||||
|
||||
# ============================================================
|
||||
@@ -287,42 +232,28 @@ echo ""
|
||||
|
||||
IMPORT_ERRORS=0
|
||||
|
||||
# Test imports using their respective venvs
|
||||
CORE_PYTHON="$SCRIPT_DIR/core/.venv/bin/python"
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
|
||||
# Test framework import (from core venv)
|
||||
if [ -f "$CORE_PYTHON" ] && $CORE_PYTHON -c "import framework" > /dev/null 2>&1; then
|
||||
# Test imports using workspace venv via uv run
|
||||
if uv run python -c "import framework" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ framework imports OK${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ framework import failed${NC}"
|
||||
IMPORT_ERRORS=$((IMPORT_ERRORS + 1))
|
||||
fi
|
||||
|
||||
# Test aden_tools import (from tools venv)
|
||||
if [ -f "$TOOLS_PYTHON" ] && $TOOLS_PYTHON -c "import aden_tools" > /dev/null 2>&1; then
|
||||
if uv run python -c "import aden_tools" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ aden_tools imports OK${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ aden_tools import failed${NC}"
|
||||
IMPORT_ERRORS=$((IMPORT_ERRORS + 1))
|
||||
fi
|
||||
|
||||
# Test litellm import (from core venv)
|
||||
if [ -f "$CORE_PYTHON" ] && $CORE_PYTHON -c "import litellm" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ litellm imports OK (core)${NC}"
|
||||
if uv run python -c "import litellm" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ litellm imports OK${NC}"
|
||||
else
|
||||
echo -e "${YELLOW} ⚠ litellm import issues in core (may be OK)${NC}"
|
||||
echo -e "${YELLOW} ⚠ litellm import issues (may be OK)${NC}"
|
||||
fi
|
||||
|
||||
# Test litellm import (from tools venv)
|
||||
if [ -f "$TOOLS_PYTHON" ] && $TOOLS_PYTHON -c "import litellm" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ litellm imports OK (tools)${NC}"
|
||||
else
|
||||
echo -e "${YELLOW} ⚠ litellm import issues in tools (may be OK)${NC}"
|
||||
fi
|
||||
|
||||
# Test MCP server module (from core venv)
|
||||
if [ -f "$CORE_PYTHON" ] && $CORE_PYTHON -c "from framework.mcp import agent_builder_server" > /dev/null 2>&1; then
|
||||
if uv run python -c "from framework.mcp import agent_builder_server" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ MCP server module OK${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ MCP server module failed${NC}"
|
||||
@@ -344,53 +275,105 @@ echo ""
|
||||
echo -e "${BLUE}Step 4: Verifying Claude Code skills...${NC}"
|
||||
echo ""
|
||||
|
||||
# Provider data as parallel indexed arrays (Bash 3.2 compatible — no declare -A)
|
||||
PROVIDER_ENV_VARS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GOOGLE_API_KEY GROQ_API_KEY CEREBRAS_API_KEY MISTRAL_API_KEY TOGETHER_API_KEY DEEPSEEK_API_KEY)
|
||||
PROVIDER_DISPLAY_NAMES=("Anthropic (Claude)" "OpenAI (GPT)" "Google Gemini" "Google AI" "Groq" "Cerebras" "Mistral" "Together AI" "DeepSeek")
|
||||
PROVIDER_ID_LIST=(anthropic openai gemini google groq cerebras mistral together deepseek)
|
||||
# Provider configuration - use associative arrays (Bash 4+) or indexed arrays (Bash 3.2)
|
||||
if [ "$USE_ASSOC_ARRAYS" = true ]; then
|
||||
# Bash 4+ - use associative arrays (cleaner and more efficient)
|
||||
declare -A PROVIDER_NAMES=(
|
||||
["ANTHROPIC_API_KEY"]="Anthropic (Claude)"
|
||||
["OPENAI_API_KEY"]="OpenAI (GPT)"
|
||||
["GEMINI_API_KEY"]="Google Gemini"
|
||||
["GOOGLE_API_KEY"]="Google AI"
|
||||
["GROQ_API_KEY"]="Groq"
|
||||
["CEREBRAS_API_KEY"]="Cerebras"
|
||||
["MISTRAL_API_KEY"]="Mistral"
|
||||
["TOGETHER_API_KEY"]="Together AI"
|
||||
["DEEPSEEK_API_KEY"]="DeepSeek"
|
||||
)
|
||||
|
||||
# Default models by provider id (parallel arrays)
|
||||
MODEL_PROVIDER_IDS=(anthropic openai gemini groq cerebras mistral together_ai deepseek)
|
||||
MODEL_DEFAULTS=("claude-sonnet-4-5-20250929" "gpt-4o" "gemini-3.0-flash-preview" "moonshotai/kimi-k2-instruct-0905" "zai-glm-4.7" "mistral-large-latest" "meta-llama/Llama-3.3-70B-Instruct-Turbo" "deepseek-chat")
|
||||
declare -A PROVIDER_IDS=(
|
||||
["ANTHROPIC_API_KEY"]="anthropic"
|
||||
["OPENAI_API_KEY"]="openai"
|
||||
["GEMINI_API_KEY"]="gemini"
|
||||
["GOOGLE_API_KEY"]="google"
|
||||
["GROQ_API_KEY"]="groq"
|
||||
["CEREBRAS_API_KEY"]="cerebras"
|
||||
["MISTRAL_API_KEY"]="mistral"
|
||||
["TOGETHER_API_KEY"]="together"
|
||||
["DEEPSEEK_API_KEY"]="deepseek"
|
||||
)
|
||||
|
||||
# Helper: get provider display name for an env var
|
||||
get_provider_name() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_DISPLAY_NAMES[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
declare -A DEFAULT_MODELS=(
|
||||
["anthropic"]="claude-sonnet-4-5-20250929"
|
||||
["openai"]="gpt-4o"
|
||||
["gemini"]="gemini-3.0-flash-preview"
|
||||
["groq"]="moonshotai/kimi-k2-instruct-0905"
|
||||
["cerebras"]="zai-glm-4.7"
|
||||
["mistral"]="mistral-large-latest"
|
||||
["together_ai"]="meta-llama/Llama-3.3-70B-Instruct-Turbo"
|
||||
["deepseek"]="deepseek-chat"
|
||||
)
|
||||
|
||||
# Helper: get provider id for an env var
|
||||
get_provider_id() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_ID_LIST[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
# Helper functions for Bash 4+
|
||||
get_provider_name() {
|
||||
echo "${PROVIDER_NAMES[$1]}"
|
||||
}
|
||||
|
||||
# Helper: get default model for a provider id
|
||||
get_default_model() {
|
||||
local provider_id="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#MODEL_PROVIDER_IDS[@]} ]; do
|
||||
if [ "${MODEL_PROVIDER_IDS[$i]}" = "$provider_id" ]; then
|
||||
echo "${MODEL_DEFAULTS[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
get_provider_id() {
|
||||
echo "${PROVIDER_IDS[$1]}"
|
||||
}
|
||||
|
||||
get_default_model() {
|
||||
echo "${DEFAULT_MODELS[$1]}"
|
||||
}
|
||||
else
|
||||
# Bash 3.2 - use parallel indexed arrays
|
||||
PROVIDER_ENV_VARS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GOOGLE_API_KEY GROQ_API_KEY CEREBRAS_API_KEY MISTRAL_API_KEY TOGETHER_API_KEY DEEPSEEK_API_KEY)
|
||||
PROVIDER_DISPLAY_NAMES=("Anthropic (Claude)" "OpenAI (GPT)" "Google Gemini" "Google AI" "Groq" "Cerebras" "Mistral" "Together AI" "DeepSeek")
|
||||
PROVIDER_ID_LIST=(anthropic openai gemini google groq cerebras mistral together deepseek)
|
||||
|
||||
# Default models by provider id (parallel arrays)
|
||||
MODEL_PROVIDER_IDS=(anthropic openai gemini groq cerebras mistral together_ai deepseek)
|
||||
MODEL_DEFAULTS=("claude-sonnet-4-5-20250929" "gpt-4o" "gemini-3.0-flash-preview" "moonshotai/kimi-k2-instruct-0905" "zai-glm-4.7" "mistral-large-latest" "meta-llama/Llama-3.3-70B-Instruct-Turbo" "deepseek-chat")
|
||||
|
||||
# Helper: get provider display name for an env var
|
||||
get_provider_name() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_DISPLAY_NAMES[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
|
||||
# Helper: get provider id for an env var
|
||||
get_provider_id() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_ID_LIST[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
|
||||
# Helper: get default model for a provider id
|
||||
get_default_model() {
|
||||
local provider_id="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#MODEL_PROVIDER_IDS[@]} ]; do
|
||||
if [ "${MODEL_PROVIDER_IDS[$i]}" = "$provider_id" ]; then
|
||||
echo "${MODEL_DEFAULTS[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
fi
|
||||
|
||||
# Configuration directory
|
||||
HIVE_CONFIG_DIR="$HOME/.hive"
|
||||
@@ -413,7 +396,7 @@ config = {
|
||||
'model': '$model',
|
||||
'api_key_env_var': '$env_var'
|
||||
},
|
||||
'created_at': '$(date -Iseconds)'
|
||||
'created_at': '$(date -u +"%Y-%m-%dT%H:%M:%S+00:00")'
|
||||
}
|
||||
with open('$HIVE_CONFIG_FILE', 'w') as f:
|
||||
json.dump(config, f, indent=2)
|
||||
@@ -442,13 +425,25 @@ FOUND_ENV_VARS=() # Corresponding env var names
|
||||
SELECTED_PROVIDER_ID="" # Will hold the chosen provider ID
|
||||
SELECTED_ENV_VAR="" # Will hold the chosen env var
|
||||
|
||||
for env_var in "${PROVIDER_ENV_VARS[@]}"; do
|
||||
value="${!env_var}"
|
||||
if [ -n "$value" ]; then
|
||||
FOUND_PROVIDERS+=("$(get_provider_name "$env_var")")
|
||||
FOUND_ENV_VARS+=("$env_var")
|
||||
fi
|
||||
done
|
||||
if [ "$USE_ASSOC_ARRAYS" = true ]; then
|
||||
# Bash 4+ - iterate over associative array keys
|
||||
for env_var in "${!PROVIDER_NAMES[@]}"; do
|
||||
value="${!env_var}"
|
||||
if [ -n "$value" ]; then
|
||||
FOUND_PROVIDERS+=("$(get_provider_name "$env_var")")
|
||||
FOUND_ENV_VARS+=("$env_var")
|
||||
fi
|
||||
done
|
||||
else
|
||||
# Bash 3.2 - iterate over indexed array
|
||||
for env_var in "${PROVIDER_ENV_VARS[@]}"; do
|
||||
value="${!env_var}"
|
||||
if [ -n "$value" ]; then
|
||||
FOUND_PROVIDERS+=("$(get_provider_name "$env_var")")
|
||||
FOUND_ENV_VARS+=("$env_var")
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${#FOUND_PROVIDERS[@]} -gt 0 ]; then
|
||||
echo "Found API keys:"
|
||||
@@ -476,7 +471,7 @@ if [ ${#FOUND_PROVIDERS[@]} -gt 0 ]; then
|
||||
i=1
|
||||
for provider in "${FOUND_PROVIDERS[@]}"; do
|
||||
echo -e " ${CYAN}$i)${NC} $provider"
|
||||
((i++))
|
||||
i=$((i + 1))
|
||||
done
|
||||
echo ""
|
||||
|
||||
@@ -507,7 +502,7 @@ if [ -z "$SELECTED_PROVIDER_ID" ]; then
|
||||
"Groq - Fast, free tier" \
|
||||
"Cerebras - Fast, free tier" \
|
||||
"Skip for now"
|
||||
choice=$?
|
||||
choice=$PROMPT_CHOICE
|
||||
|
||||
case $choice in
|
||||
0)
|
||||
@@ -542,7 +537,8 @@ if [ -z "$SELECTED_PROVIDER_ID" ]; then
|
||||
;;
|
||||
5)
|
||||
echo ""
|
||||
echo -e "${YELLOW}Skipped.${NC} Add your API key later:"
|
||||
echo -e "${YELLOW}Skipped.${NC} An LLM API key is required to test and use worker agents."
|
||||
echo -e "Add your API key later by running:"
|
||||
echo ""
|
||||
echo -e " ${CYAN}echo 'ANTHROPIC_API_KEY=your-key' >> .env${NC}"
|
||||
echo ""
|
||||
@@ -595,7 +591,7 @@ ERRORS=0
|
||||
|
||||
# Test imports
|
||||
echo -n " ⬡ framework... "
|
||||
if $PYTHON_CMD -c "import framework" > /dev/null 2>&1; then
|
||||
if uv run python -c "import framework" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
else
|
||||
echo -e "${RED}failed${NC}"
|
||||
@@ -603,7 +599,7 @@ else
|
||||
fi
|
||||
|
||||
echo -n " ⬡ aden_tools... "
|
||||
if $PYTHON_CMD -c "import aden_tools" > /dev/null 2>&1; then
|
||||
if uv run python -c "import aden_tools" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
else
|
||||
echo -e "${RED}failed${NC}"
|
||||
@@ -611,7 +607,7 @@ else
|
||||
fi
|
||||
|
||||
echo -n " ⬡ litellm... "
|
||||
if $PYTHON_CMD -c "import litellm" > /dev/null 2>&1; then
|
||||
if uv run python -c "import litellm" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
else
|
||||
echo -e "${YELLOW}--${NC}"
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
<#
|
||||
|
||||
setup-python.ps1 - Python Environment Setup for Aden Agent Framework
|
||||
|
||||
This script sets up the Python environment with all required packages
|
||||
for building and running goal-driven agents.
|
||||
#>
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
# Colors for output
|
||||
$RED = "Red"
|
||||
$GREEN = "Green"
|
||||
$YELLOW = "Yellow"
|
||||
$BLUE = "Cyan"
|
||||
|
||||
# Get the directory where this script is located
|
||||
$SCRIPT_DIR = Split-Path -Parent $MyInvocation.MyCommand.Path
|
||||
$PROJECT_ROOT = Split-Path -Parent $SCRIPT_DIR
|
||||
|
||||
Write-Host ""
|
||||
Write-Host "=================================================="
|
||||
Write-Host " Aden Agent Framework - Python Setup"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
# Check for Python
|
||||
$pythonCmd = $null
|
||||
if (Get-Command python -ErrorAction SilentlyContinue) {
|
||||
$pythonCmd = "python"
|
||||
}
|
||||
|
||||
if (-not $pythonCmd) {
|
||||
Write-Host "Error: Python is not installed." -ForegroundColor $RED
|
||||
Write-Host "Please install Python 3.11+ from https://python.org"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Check Python version
#
# The interpreter prints its version components as text. They are cast to
# [int] before comparing: PowerShell converts the right-hand operand to the
# type of the left-hand one, so comparing the raw strings would be a textual
# comparison ("9" is not less than "11") and Python 3.2-3.9 would slip
# through the 3.11+ gate.
$versionInfo = & $pythonCmd -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')"
[int]$major = & $pythonCmd -c "import sys; print(sys.version_info.major)"
[int]$minor = & $pythonCmd -c "import sys; print(sys.version_info.minor)"

Write-Host "Detected Python: $versionInfo" -ForegroundColor $BLUE

# Hard requirement: anything older than 3.11 aborts the setup.
if ($major -lt 3 -or ($major -eq 3 -and $minor -lt 11)) {
    Write-Host "Error: Python 3.11+ is required (found $versionInfo)" -ForegroundColor $RED
    Write-Host "Please upgrade your Python installation"
    exit 1
}

# Only reachable for a major version above 3 whose minor is below 11
# (every 3.x below 3.11 has already exited above).
if ($minor -lt 11) {
    Write-Host "Warning: Python 3.11+ is recommended for best compatibility" -ForegroundColor $YELLOW
    Write-Host "You have Python $versionInfo which may work but is not officially supported" -ForegroundColor $YELLOW
    Write-Host ""
}
|
||||
|
||||
Write-Host "[OK] Python version check passed" -ForegroundColor $GREEN
|
||||
Write-Host ""
|
||||
|
||||
# Create and activate virtual environment
|
||||
Write-Host "=================================================="
|
||||
Write-Host "Setting up Python Virtual Environment"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
$VENV_PATH = Join-Path $PROJECT_ROOT ".venv"
|
||||
$VENV_PYTHON = Join-Path $VENV_PATH "Scripts\python.exe"
|
||||
$VENV_ACTIVATE = Join-Path $VENV_PATH "Scripts\Activate.ps1"
|
||||
|
||||
if (-not (Test-Path $VENV_PYTHON)) {
|
||||
Write-Host "Creating virtual environment at .venv..."
|
||||
& $pythonCmd -m venv $VENV_PATH
|
||||
Write-Host "[OK] Virtual environment created" -ForegroundColor $GREEN
|
||||
}
|
||||
else {
|
||||
Write-Host "[OK] Virtual environment already exists" -ForegroundColor $GREEN
|
||||
}
|
||||
|
||||
# Activate venv
|
||||
Write-Host "Activating virtual environment..."
|
||||
& $VENV_ACTIVATE
|
||||
Write-Host "[OK] Virtual environment activated" -ForegroundColor $GREEN
|
||||
|
||||
# From here on, always use venv python
|
||||
$pythonCmd = $VENV_PYTHON
|
||||
|
||||
Write-Host ""
|
||||
|
||||
# Check for pip
#
# A native command that exits non-zero does not necessarily raise a
# PowerShell exception, so the exit code is inspected explicitly. The
# try/catch is kept for hosts that do surface native failures as errors.
$pipAvailable = $false
try {
    & $pythonCmd -m pip --version | Out-Null
    $pipAvailable = ($LASTEXITCODE -eq 0)
}
catch {
    $pipAvailable = $false
}

if (-not $pipAvailable) {
    Write-Host "Error: pip is not installed" -ForegroundColor $RED
    Write-Host "Please install pip for Python $versionInfo"
    exit 1
}
|
||||
|
||||
Write-Host "[OK] pip detected" -ForegroundColor $GREEN
|
||||
Write-Host ""
|
||||
|
||||
# Upgrade pip, setuptools, and wheel
|
||||
Write-Host "Upgrading pip, setuptools, and wheel..."
|
||||
& $pythonCmd -m pip install --upgrade pip setuptools wheel
|
||||
Write-Host "[OK] Core packages upgraded" -ForegroundColor $GREEN
|
||||
Write-Host ""
|
||||
|
||||
# Install core framework package
|
||||
Write-Host "=================================================="
|
||||
Write-Host "Installing Core Framework Package"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
Set-Location "$PROJECT_ROOT\core"
|
||||
|
||||
if (Test-Path "pyproject.toml") {
|
||||
Write-Host "Installing framework from core/ (editable mode)..."
|
||||
& $pythonCmd -m pip install -e . | Out-Null
|
||||
Write-Host "[OK] Framework package installed" -ForegroundColor $GREEN
|
||||
}
|
||||
else {
|
||||
Write-Host "[WARN] No pyproject.toml found in core/, skipping framework installation" -ForegroundColor $YELLOW
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
|
||||
# Install tools package
|
||||
Write-Host "=================================================="
|
||||
Write-Host "Installing Tools Package (aden_tools)"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
Set-Location "$PROJECT_ROOT\tools"
|
||||
|
||||
if (Test-Path "pyproject.toml") {
|
||||
Write-Host "Installing aden_tools from tools/ (editable mode)..."
|
||||
& $pythonCmd -m pip install -e . | Out-Null
|
||||
Write-Host "[OK] Tools package installed" -ForegroundColor $GREEN
|
||||
}
|
||||
else {
|
||||
Write-Host "Error: No pyproject.toml found in tools/" -ForegroundColor $RED
|
||||
exit 1
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
|
||||
# Fix openai version compatibility with litellm
|
||||
Write-Host "=================================================="
|
||||
Write-Host "Fixing Package Compatibility"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
# Detect the installed openai version. Start from "not_installed" and only
# accept the detected value when the interpreter exited cleanly and printed
# something: a failed import leaves the captured output empty, and calling
# .StartsWith() on that null value would abort the script.
$openaiVersion = "not_installed"
try {
    $detectedOpenai = & $pythonCmd -c "import openai; print(openai.__version__)"
    if ($LASTEXITCODE -eq 0 -and $detectedOpenai) {
        $openaiVersion = "$detectedOpenai".Trim()
    }
}
catch {
    $openaiVersion = "not_installed"
}

if ($openaiVersion -eq "not_installed") {
    Write-Host "Installing openai package..."
    & $pythonCmd -m pip install "openai>=1.0.0" | Out-Null
    Write-Host "[OK] openai package installed" -ForegroundColor $GREEN
}
elseif ($openaiVersion.StartsWith("0.")) {
    # 0.x releases predate the 1.x client API; upgrade for litellm.
    Write-Host "Found old openai version: $openaiVersion" -ForegroundColor $YELLOW
    Write-Host "Upgrading to openai 1.x+ for litellm compatibility..."
    & $pythonCmd -m pip install --upgrade "openai>=1.0.0" | Out-Null
    $openaiVersion = & $pythonCmd -c "import openai; print(openai.__version__)"
    Write-Host "[OK] openai upgraded to $openaiVersion" -ForegroundColor $GREEN
}
else {
    Write-Host "[OK] openai $openaiVersion is compatible" -ForegroundColor $GREEN
}
|
||||
|
||||
Write-Host ""
|
||||
|
||||
# Verify installations
|
||||
Write-Host "=================================================="
|
||||
Write-Host "Verifying Installation"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
Set-Location $PROJECT_ROOT
|
||||
|
||||
# Test framework import
|
||||
& $pythonCmd -c "import framework" 2>$null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Host "[OK] framework package imports successfully" -ForegroundColor Green
|
||||
}
|
||||
else {
|
||||
Write-Host "[FAIL] framework package import failed" -ForegroundColor Red
|
||||
}
|
||||
|
||||
# Test aden_tools import
|
||||
& $pythonCmd -c "import aden_tools" 2>$null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Host "[OK] aden_tools package imports successfully" -ForegroundColor Green
|
||||
}
|
||||
else {
|
||||
Write-Host "[FAIL] aden_tools package import failed" -ForegroundColor Red
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Test litellm
|
||||
& $pythonCmd -c "import litellm" 2>$null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Host "[OK] litellm package imports successfully" -ForegroundColor $GREEN
|
||||
}
|
||||
else {
|
||||
Write-Host "[WARN] litellm import had issues (may be OK if not using LLM features)" -ForegroundColor $YELLOW
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
|
||||
# Print agent commands
|
||||
Write-Host "=================================================="
|
||||
Write-Host " Setup Complete!"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
Write-Host "Python packages installed:"
|
||||
Write-Host " - framework (core agent runtime)"
|
||||
Write-Host " - aden_tools (tools and MCP servers)"
|
||||
Write-Host " - All dependencies and compatibility fixes applied"
|
||||
Write-Host ""
|
||||
Write-Host "To run agents on Windows (PowerShell):"
|
||||
Write-Host ""
|
||||
Write-Host "1. From the project root, set PYTHONPATH:"
|
||||
Write-Host " `$env:PYTHONPATH=`"core;exports`""
|
||||
Write-Host ""
|
||||
Write-Host "2. Run an agent command:"
|
||||
Write-Host " python -m agent_name validate"
|
||||
Write-Host " python -m agent_name info"
|
||||
Write-Host " python -m agent_name run --input '{...}'"
|
||||
Write-Host ""
|
||||
Write-Host "Example (support_ticket_agent):"
|
||||
Write-Host " python -m support_ticket_agent validate"
|
||||
Write-Host " python -m support_ticket_agent info"
|
||||
Write-Host " python -m support_ticket_agent run --input '{""ticket_content"":""..."",""customer_id"":""..."",""ticket_id"":""...""}'"
|
||||
Write-Host ""
|
||||
Write-Host "Notes:"
|
||||
Write-Host " - Ensure the virtual environment is activated (.venv)"
|
||||
Write-Host " - PYTHONPATH must be set in each new PowerShell session"
|
||||
Write-Host ""
|
||||
Write-Host "Documentation:"
|
||||
Write-Host " $PROJECT_ROOT\README.md"
|
||||
Write-Host ""
|
||||
Write-Host "Agent Examples:"
|
||||
Write-Host " $PROJECT_ROOT\exports\"
|
||||
Write-Host ""
|
||||
+1
-12
@@ -68,18 +68,7 @@ from starlette.responses import PlainTextResponse # noqa: E402
|
||||
from aden_tools.credentials import CredentialError, CredentialStoreAdapter # noqa: E402
|
||||
from aden_tools.tools import register_all_tools # noqa: E402
|
||||
|
||||
# Create credential store with access to both env vars AND encrypted store
|
||||
# This allows using Aden-synced credentials from ~/.hive/credentials
|
||||
try:
|
||||
from framework.credentials import CredentialStore
|
||||
|
||||
store = CredentialStore.with_encrypted_storage() # ~/.hive/credentials
|
||||
credentials = CredentialStoreAdapter(store)
|
||||
logger.info("Using CredentialStoreAdapter with encrypted storage")
|
||||
except Exception as e:
|
||||
# Fall back to env-only adapter if encrypted storage fails
|
||||
credentials = CredentialStoreAdapter.with_env_storage()
|
||||
logger.warning(f"Falling back to env-only CredentialStoreAdapter: {e}")
|
||||
credentials = CredentialStoreAdapter.default()
|
||||
|
||||
# Tier 1: Validate startup-required credentials (if any)
|
||||
try:
|
||||
|
||||
@@ -30,6 +30,7 @@ dependencies = [
|
||||
"playwright-stealth>=1.0.5",
|
||||
"litellm>=1.81.0",
|
||||
"resend>=2.0.0",
|
||||
"framework",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -54,6 +55,9 @@ all = [
|
||||
"duckdb>=1.0.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
framework = { workspace = true }
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
# MCP Server
|
||||
fastmcp
|
||||
|
||||
# Tool dependencies
|
||||
diff-match-patch
|
||||
pypdf
|
||||
beautifulsoup4
|
||||
lxml
|
||||
playwright
|
||||
playwright-stealth
|
||||
requests
|
||||
|
||||
# Note: After installing, run `playwright install` to download browser binaries
|
||||
@@ -10,7 +10,7 @@ Usage:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
mcp = FastMCP("my-server")
|
||||
credentials = CredentialStoreAdapter.with_env_storage()
|
||||
credentials = CredentialStoreAdapter.default()
|
||||
register_all_tools(mcp, credentials=credentials)
|
||||
"""
|
||||
|
||||
|
||||
@@ -9,14 +9,14 @@ Philosophy: Google Strictness + Apple UX
|
||||
|
||||
Usage:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
from core.framework.credentials import CredentialStore
|
||||
from framework.credentials import CredentialStore
|
||||
|
||||
# With encrypted storage (production)
|
||||
store = CredentialStore.with_encrypted_storage() # defaults to ~/.hive/credentials
|
||||
credentials = CredentialStoreAdapter(store)
|
||||
|
||||
# With env vars only (simple setup)
|
||||
credentials = CredentialStoreAdapter.with_env_storage()
|
||||
# With composite storage (encrypted primary + env fallback)
|
||||
credentials = CredentialStoreAdapter.default()
|
||||
|
||||
# In agent runner (validate at agent load time)
|
||||
credentials.validate_for_tools(["web_search", "file_read"])
|
||||
|
||||
@@ -70,6 +70,9 @@ class CredentialSpec:
|
||||
credential_key: str = "access_token"
|
||||
"""Key name within the credential (e.g., 'access_token', 'api_key')"""
|
||||
|
||||
credential_group: str = ""
|
||||
"""Group name for credentials that must be configured together (e.g., 'google_custom_search')"""
|
||||
|
||||
|
||||
class CredentialError(Exception):
|
||||
"""Raised when required credentials are missing."""
|
||||
|
||||
@@ -15,5 +15,22 @@ EMAIL_CREDENTIALS = {
|
||||
startup_required=False,
|
||||
help_url="https://resend.com/api-keys",
|
||||
description="API key for Resend email service",
|
||||
# Auth method support
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""To get a Resend API key:
|
||||
1. Go to https://resend.com and create an account (or sign in)
|
||||
2. Navigate to API Keys in the dashboard
|
||||
3. Click "Create API Key"
|
||||
4. Give it a name (e.g., "Hive Agent") and choose permissions:
|
||||
- "Sending access" is sufficient for most use cases
|
||||
- "Full access" if you also need to manage domains
|
||||
5. Copy the API key (starts with re_)
|
||||
6. Store it securely - you won't be able to see it again!
|
||||
7. Note: You'll also need to verify a domain to send emails from custom addresses""",
|
||||
# Health check configuration
|
||||
health_check_endpoint="https://api.resend.com/domains",
|
||||
# Credential store mapping
|
||||
credential_id="resend",
|
||||
credential_key="api_key",
|
||||
),
|
||||
}
|
||||
|
||||
@@ -231,10 +231,213 @@ class GoogleSearchHealthChecker:
|
||||
)
|
||||
|
||||
|
||||
class AnthropicHealthChecker:
    """Health checker for Anthropic API credentials."""

    ENDPOINT = "https://api.anthropic.com/v1/messages"
    TIMEOUT = 10.0

    def check(self, api_key: str) -> HealthCheckResult:
        """
        Validate an Anthropic API key without consuming tokens.

        A deliberately malformed request (empty ``messages``) is POSTed to the
        messages endpoint and the HTTP status is interpreted:

        - 200, 400 or 429: the key is reported valid (400 is the rejected
          payload, 429 is rate limiting).
        - 401: the key is reported invalid.
        - anything else: reported invalid with the status code attached.

        Timeouts and connection failures are reported as invalid results
        rather than raised.
        """
        request_headers = {
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "Content-Type": "application/json",
        }
        # Empty messages triggers 400 (not 200), so no tokens are consumed.
        probe_payload = {
            "model": "claude-sonnet-4-20250514",
            "max_tokens": 1,
            "messages": [],
        }

        try:
            with httpx.Client(timeout=self.TIMEOUT) as client:
                response = client.post(
                    self.ENDPOINT,
                    headers=request_headers,
                    json=probe_payload,
                )
                status = response.status_code

                # A plain success carries no extra details.
                if status == 200:
                    return HealthCheckResult(
                        valid=True,
                        message="Anthropic API key valid",
                    )

                # status -> (valid, message, details)
                known_outcomes = {
                    401: (
                        False,
                        "Anthropic API key is invalid",
                        {"status_code": 401},
                    ),
                    # Rate limited but key is valid
                    429: (
                        True,
                        "Anthropic API key valid (rate limited)",
                        {"status_code": 429, "rate_limited": True},
                    ),
                    # Bad request but key authenticated - key is valid
                    400: (
                        True,
                        "Anthropic API key valid",
                        {"status_code": 400},
                    ),
                }
                fallback = (
                    False,
                    f"Anthropic API returned status {status}",
                    {"status_code": status},
                )
                is_valid, text, extra = known_outcomes.get(status, fallback)
                return HealthCheckResult(
                    valid=is_valid,
                    message=text,
                    details=extra,
                )
        except httpx.TimeoutException:
            return HealthCheckResult(
                valid=False,
                message="Anthropic API request timed out",
                details={"error": "timeout"},
            )
        except httpx.RequestError as e:
            return HealthCheckResult(
                valid=False,
                message=f"Failed to connect to Anthropic API: {e}",
                details={"error": str(e)},
            )
|
||||
|
||||
|
||||
class GitHubHealthChecker:
    """Health checker for GitHub Personal Access Token."""

    ENDPOINT = "https://api.github.com/user"
    TIMEOUT = 10.0

    def check(self, access_token: str) -> HealthCheckResult:
        """
        Validate a GitHub token by fetching the authenticated user.

        On HTTP 200 the result is valid and carries the ``login`` field of the
        response (``"unknown"`` if absent). Any other status, a timeout, or a
        connection failure yields an invalid result describing the cause.
        """
        request_headers = {
            "Authorization": f"Bearer {access_token}",
            "Accept": "application/vnd.github+json",
            "X-GitHub-Api-Version": "2022-11-28",
        }

        try:
            with httpx.Client(timeout=self.TIMEOUT) as client:
                response = client.get(self.ENDPOINT, headers=request_headers)
                status = response.status_code

                if status == 200:
                    username = response.json().get("login", "unknown")
                    return HealthCheckResult(
                        valid=True,
                        message=f"GitHub token valid (authenticated as {username})",
                        details={"username": username},
                    )

                # NOTE(review): GitHub may also answer 403 when a rate limit is
                # hit; confirm whether that case should be reported differently.
                failure_text = {
                    401: "GitHub token is invalid or expired",
                    403: "GitHub token lacks required permissions",
                }.get(status, f"GitHub API returned status {status}")
                return HealthCheckResult(
                    valid=False,
                    message=failure_text,
                    details={"status_code": status},
                )
        except httpx.TimeoutException:
            return HealthCheckResult(
                valid=False,
                message="GitHub API request timed out",
                details={"error": "timeout"},
            )
        except httpx.RequestError as e:
            return HealthCheckResult(
                valid=False,
                message=f"Failed to connect to GitHub API: {e}",
                details={"error": str(e)},
            )
|
||||
|
||||
|
||||
class ResendHealthChecker:
    """Health checker for Resend API credentials."""

    ENDPOINT = "https://api.resend.com/domains"
    TIMEOUT = 10.0

    def check(self, api_key: str) -> HealthCheckResult:
        """
        Validate a Resend API key by listing domains.

        HTTP 200 is reported as a valid key. Any other status, a timeout, or a
        connection failure yields an invalid result describing the cause.
        """
        # NOTE(review): this probe lists domains; confirm that a key created
        # with send-only permissions is not reported as invalid here.
        request_headers = {
            "Authorization": f"Bearer {api_key}",
            "Accept": "application/json",
        }

        try:
            with httpx.Client(timeout=self.TIMEOUT) as client:
                response = client.get(self.ENDPOINT, headers=request_headers)
                status = response.status_code

                # A plain success carries no extra details.
                if status == 200:
                    return HealthCheckResult(
                        valid=True,
                        message="Resend API key valid",
                    )

                failure_text = {
                    401: "Resend API key is invalid",
                    403: "Resend API key lacks required permissions",
                }.get(status, f"Resend API returned status {status}")
                return HealthCheckResult(
                    valid=False,
                    message=failure_text,
                    details={"status_code": status},
                )
        except httpx.TimeoutException:
            return HealthCheckResult(
                valid=False,
                message="Resend API request timed out",
                details={"error": "timeout"},
            )
        except httpx.RequestError as e:
            return HealthCheckResult(
                valid=False,
                message=f"Failed to connect to Resend API: {e}",
                details={"error": str(e)},
            )
|
||||
|
||||
|
||||
# Registry of health checkers
|
||||
HEALTH_CHECKERS: dict[str, CredentialHealthChecker] = {
|
||||
"hubspot": HubSpotHealthChecker(),
|
||||
"brave_search": BraveSearchHealthChecker(),
|
||||
"google_search": GoogleSearchHealthChecker(),
|
||||
"anthropic": AnthropicHealthChecker(),
|
||||
"github": GitHubHealthChecker(),
|
||||
"resend": ResendHealthChecker(),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,47 @@ Contains credentials for third-party service integrations (HubSpot, etc.).
|
||||
from .base import CredentialSpec
|
||||
|
||||
INTEGRATION_CREDENTIALS = {
|
||||
"github": CredentialSpec(
|
||||
env_var="GITHUB_TOKEN",
|
||||
tools=[
|
||||
"github_list_repos",
|
||||
"github_get_repo",
|
||||
"github_search_repos",
|
||||
"github_list_issues",
|
||||
"github_get_issue",
|
||||
"github_create_issue",
|
||||
"github_update_issue",
|
||||
"github_list_pull_requests",
|
||||
"github_get_pull_request",
|
||||
"github_create_pull_request",
|
||||
"github_search_code",
|
||||
"github_list_branches",
|
||||
"github_get_branch",
|
||||
],
|
||||
required=True,
|
||||
startup_required=False,
|
||||
help_url="https://github.com/settings/tokens",
|
||||
description="GitHub Personal Access Token (classic)",
|
||||
# Auth method support
|
||||
aden_supported=False,
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""To get a GitHub Personal Access Token:
|
||||
1. Go to GitHub Settings > Developer settings > Personal access tokens
|
||||
2. Click "Generate new token" > "Generate new token (classic)"
|
||||
3. Give your token a descriptive name (e.g., "Hive Agent")
|
||||
4. Select the following scopes:
|
||||
- repo (Full control of private repositories)
|
||||
- read:org (Read org and team membership - optional)
|
||||
- user (Read user profile data - optional)
|
||||
5. Click "Generate token" and copy the token (starts with ghp_)
|
||||
6. Store it securely - you won't be able to see it again!""",
|
||||
# Health check configuration
|
||||
health_check_endpoint="https://api.github.com/user",
|
||||
health_check_method="GET",
|
||||
# Credential store mapping
|
||||
credential_id="github",
|
||||
credential_key="access_token",
|
||||
),
|
||||
"hubspot": CredentialSpec(
|
||||
env_var="HUBSPOT_ACCESS_TOKEN",
|
||||
tools=[
|
||||
|
||||
@@ -15,6 +15,21 @@ LLM_CREDENTIALS = {
|
||||
startup_required=False, # MCP server doesn't need LLM credentials
|
||||
help_url="https://console.anthropic.com/settings/keys",
|
||||
description="API key for Anthropic Claude models",
|
||||
# Auth method support
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""To get an Anthropic API key:
|
||||
1. Go to https://console.anthropic.com/settings/keys
|
||||
2. Sign in or create an Anthropic account
|
||||
3. Click "Create Key"
|
||||
4. Give your key a descriptive name (e.g., "Hive Agent")
|
||||
5. Copy the API key (starts with sk-ant-)
|
||||
6. Store it securely - you won't be able to see the full key again!""",
|
||||
# Health check configuration
|
||||
health_check_endpoint="https://api.anthropic.com/v1/messages",
|
||||
health_check_method="POST",
|
||||
# Credential store mapping
|
||||
credential_id="anthropic",
|
||||
credential_key="api_key",
|
||||
),
|
||||
# Future LLM providers:
|
||||
# "openai": CredentialSpec(
|
||||
|
||||
@@ -15,6 +15,20 @@ SEARCH_CREDENTIALS = {
|
||||
startup_required=False,
|
||||
help_url="https://brave.com/search/api/",
|
||||
description="API key for Brave Search",
|
||||
# Auth method support
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""To get a Brave Search API key:
|
||||
1. Go to https://brave.com/search/api/
|
||||
2. Create a Brave Search API account (or sign in)
|
||||
3. Choose a plan (Free tier includes 2,000 queries/month)
|
||||
4. Navigate to the API Keys section in your dashboard
|
||||
5. Click "Create API Key" and give it a name
|
||||
6. Copy the API key and store it securely""",
|
||||
# Health check configuration
|
||||
health_check_endpoint="https://api.search.brave.com/res/v1/web/search",
|
||||
# Credential store mapping
|
||||
credential_id="brave_search",
|
||||
credential_key="api_key",
|
||||
),
|
||||
"google_search": CredentialSpec(
|
||||
env_var="GOOGLE_API_KEY",
|
||||
@@ -22,8 +36,24 @@ SEARCH_CREDENTIALS = {
|
||||
node_types=[],
|
||||
required=True,
|
||||
startup_required=False,
|
||||
help_url="https://console.cloud.google.com/",
|
||||
help_url="https://console.cloud.google.com/apis/credentials",
|
||||
description="API key for Google Custom Search",
|
||||
# Auth method support
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""To get a Google Custom Search API key:
|
||||
1. Go to https://console.cloud.google.com/apis/credentials
|
||||
2. Create a new project (or select an existing one)
|
||||
3. Enable the "Custom Search API" from the API Library
|
||||
4. Go to Credentials > Create Credentials > API Key
|
||||
5. Copy the generated API key
|
||||
6. (Recommended) Click "Restrict Key" and limit it to the Custom Search API
|
||||
7. Store the key securely""",
|
||||
# Health check configuration
|
||||
health_check_endpoint="https://www.googleapis.com/customsearch/v1",
|
||||
# Credential store mapping
|
||||
credential_id="google_search",
|
||||
credential_key="api_key",
|
||||
credential_group="google_custom_search",
|
||||
),
|
||||
"google_cse": CredentialSpec(
|
||||
env_var="GOOGLE_CSE_ID",
|
||||
@@ -31,7 +61,22 @@ SEARCH_CREDENTIALS = {
|
||||
node_types=[],
|
||||
required=True,
|
||||
startup_required=False,
|
||||
help_url="https://programmablesearchengine.google.com/",
|
||||
help_url="https://programmablesearchengine.google.com/controlpanel/all",
|
||||
description="Google Custom Search Engine ID",
|
||||
# Auth method support
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""To get a Google Custom Search Engine (CSE) ID:
|
||||
1. Go to https://programmablesearchengine.google.com/controlpanel/all
|
||||
2. Click "Add" to create a new search engine
|
||||
3. Under "What to search", select "Search the entire web"
|
||||
4. Give your search engine a name (e.g., "Hive Agent Search")
|
||||
5. Click "Create"
|
||||
6. Copy the Search Engine ID (cx value) from the overview page""",
|
||||
# Health check configuration
|
||||
health_check_endpoint="https://www.googleapis.com/customsearch/v1",
|
||||
# Credential store mapping
|
||||
credential_id="google_cse",
|
||||
credential_key="api_key",
|
||||
credential_group="google_custom_search",
|
||||
),
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ This provides backward compatibility, allowing existing tools to work unchanged
|
||||
while enabling new features (template resolution, multi-key credentials, etc.).
|
||||
|
||||
Usage:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from framework.credentials import CredentialStore
|
||||
from aden_tools.credentials.store_adapter import CredentialStoreAdapter
|
||||
|
||||
# Create new credential store
|
||||
@@ -31,7 +31,7 @@ from typing import TYPE_CHECKING
|
||||
from .base import CredentialError, CredentialSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from framework.credentials import CredentialStore
|
||||
|
||||
|
||||
class CredentialStoreAdapter:
|
||||
@@ -348,6 +348,41 @@ class CredentialStoreAdapter:
|
||||
|
||||
# --- Factory Methods ---
|
||||
|
||||
@classmethod
def default(
    cls,
    specs: dict[str, CredentialSpec] | None = None,
) -> CredentialStoreAdapter:
    """Build an adapter that prefers encrypted storage and falls back to env vars.

    Args:
        specs: Credential specs to serve. When omitted, the package-level
            ``CREDENTIAL_SPECS`` mapping is used.

    Returns:
        An adapter whose store reads from ``EncryptedFileStorage`` first and
        from ``EnvVarStorage`` second. If assembling that store raises, a
        warning is logged and an env-var-only store is used instead.
    """
    from framework.credentials import CredentialStore
    from framework.credentials.storage import (
        CompositeStorage,
        EncryptedFileStorage,
        EnvVarStorage,
    )

    if specs is None:
        from . import CREDENTIAL_SPECS

        specs = CREDENTIAL_SPECS

    # Credential name -> environment variable that may hold its value.
    env_vars_by_name = {}
    for cred_name, cred_spec in specs.items():
        env_vars_by_name[cred_name] = cred_spec.env_var

    try:
        backing = CompositeStorage(
            primary=EncryptedFileStorage(),
            fallbacks=[EnvVarStorage(env_vars_by_name)],
        )
        credential_store = CredentialStore(storage=backing)
    except Exception as exc:
        # Best effort: any failure building the encrypted store degrades to
        # plain environment-variable lookup rather than aborting.
        import logging

        logging.getLogger(__name__).warning(
            "Encrypted credential storage unavailable, falling back to env vars: %s", exc
        )
        credential_store = CredentialStore.with_env_storage(env_vars_by_name)

    return cls(store=credential_store, specs=specs)
|
||||
|
||||
@classmethod
|
||||
def for_testing(
|
||||
cls,
|
||||
@@ -368,7 +403,7 @@ class CredentialStoreAdapter:
|
||||
credentials = CredentialStoreAdapter.for_testing({"brave_search": "test-key"})
|
||||
assert credentials.get("brave_search") == "test-key"
|
||||
"""
|
||||
from core.framework.credentials import CredentialStore
|
||||
from framework.credentials import CredentialStore
|
||||
|
||||
# Convert to CredentialStore.for_testing format
|
||||
# Simple credentials get a single "api_key" key
|
||||
@@ -395,13 +430,14 @@ class CredentialStoreAdapter:
|
||||
Returns:
|
||||
CredentialStoreAdapter using env vars for storage
|
||||
"""
|
||||
from core.framework.credentials import CredentialStore
|
||||
from framework.credentials import CredentialStore
|
||||
|
||||
# Build env mapping from specs if not provided
|
||||
if env_mapping is None and specs is None:
|
||||
from . import CREDENTIAL_SPECS
|
||||
if env_mapping is None:
|
||||
if specs is None:
|
||||
from . import CREDENTIAL_SPECS
|
||||
|
||||
specs = CREDENTIAL_SPECS
|
||||
specs = CREDENTIAL_SPECS
|
||||
env_mapping = {name: spec.env_var for name, spec in specs.items()}
|
||||
|
||||
store = CredentialStore.with_env_storage(env_mapping)
|
||||
|
||||
@@ -7,7 +7,7 @@ Usage:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
mcp = FastMCP("my-server")
|
||||
credentials = CredentialStoreAdapter.with_env_storage()
|
||||
credentials = CredentialStoreAdapter.default()
|
||||
register_all_tools(mcp, credentials=credentials)
|
||||
"""
|
||||
|
||||
@@ -38,6 +38,7 @@ from .file_system_toolkits.replace_file_content import (
|
||||
# Import file system toolkits
|
||||
from .file_system_toolkits.view_file import register_tools as register_view_file
|
||||
from .file_system_toolkits.write_to_file import register_tools as register_write_to_file
|
||||
from .github_tool import register_tools as register_github
|
||||
from .hubspot_tool import register_tools as register_hubspot
|
||||
from .pdf_read_tool import register_tools as register_pdf_read
|
||||
from .web_scrape_tool import register_tools as register_web_scrape
|
||||
@@ -67,6 +68,7 @@ def register_all_tools(
|
||||
# Tools that need credentials (pass credentials if provided)
|
||||
# web_search supports multiple providers (Google, Brave) with auto-detection
|
||||
register_web_search(mcp, credentials=credentials)
|
||||
register_github(mcp, credentials=credentials)
|
||||
# email supports multiple providers (Resend) with auto-detection
|
||||
register_email(mcp, credentials=credentials)
|
||||
register_hubspot(mcp, credentials=credentials)
|
||||
@@ -100,6 +102,22 @@ def register_all_tools(
|
||||
"csv_append",
|
||||
"csv_info",
|
||||
"csv_sql",
|
||||
"github_list_repos",
|
||||
"github_get_repo",
|
||||
"github_search_repos",
|
||||
"github_list_issues",
|
||||
"github_get_issue",
|
||||
"github_create_issue",
|
||||
"github_update_issue",
|
||||
"github_list_pull_requests",
|
||||
"github_get_pull_request",
|
||||
"github_create_pull_request",
|
||||
"github_search_code",
|
||||
"github_list_branches",
|
||||
"github_get_branch",
|
||||
"github_list_stargazers",
|
||||
"github_get_user_profile",
|
||||
"github_get_user_emails",
|
||||
"send_email",
|
||||
"send_budget_alert_email",
|
||||
"hubspot_search_contacts",
|
||||
|
||||
@@ -113,6 +113,16 @@ def register_tools(
|
||||
cc_list = _normalize_recipients(cc)
|
||||
bcc_list = _normalize_recipients(bcc)
|
||||
|
||||
# Testing override: redirect all recipients to a single address.
|
||||
# Set EMAIL_OVERRIDE_TO=you@example.com to intercept all outbound mail.
|
||||
override_to = os.getenv("EMAIL_OVERRIDE_TO")
|
||||
if override_to:
|
||||
original_to = to_list
|
||||
to_list = [override_to]
|
||||
cc_list = None
|
||||
bcc_list = None
|
||||
subject = f"[TEST -> {', '.join(original_to)}] {subject}"
|
||||
|
||||
creds = _get_credentials()
|
||||
resend_available = bool(creds["resend_api_key"])
|
||||
|
||||
|
||||
@@ -0,0 +1,646 @@
|
||||
# GitHub Tool
|
||||
|
||||
Interact with GitHub repositories, issues, and pull requests within the Aden agent framework.
|
||||
|
||||
## Installation
|
||||
|
||||
The GitHub tool uses `httpx`, which is already included in the base dependencies, so no additional installation is required.
|
||||
|
||||
## Setup
|
||||
|
||||
You need a GitHub Personal Access Token (PAT) to use this tool.
|
||||
|
||||
### Getting a GitHub Token
|
||||
|
||||
1. Go to https://github.com/settings/tokens
|
||||
2. Click "Generate new token" → "Generate new token (classic)"
|
||||
3. Give your token a descriptive name (e.g., "Aden Agent Framework")
|
||||
4. Select the following scopes:
|
||||
- `repo` - Full control of private repositories (includes all repo scopes)
|
||||
- `read:org` - Read org and team membership (optional, for org access)
|
||||
- `user` - Read/write access to profile info, including `user:email` (optional; `read:user` is enough for read-only profile access)
|
||||
5. Click "Generate token"
|
||||
6. Copy the token (starts with `ghp_`)
|
||||
|
||||
**Note:** Keep your token secure! It provides access to your GitHub account.
|
||||
|
||||
### Configuration
|
||||
|
||||
Set the token as an environment variable:
|
||||
|
||||
```bash
|
||||
export GITHUB_TOKEN=ghp_your_token_here
|
||||
```
|
||||
|
||||
Or configure via the credential store (recommended for production).
|
||||
|
||||
## Available Functions
|
||||
|
||||
### Repository Management
|
||||
|
||||
#### `github_list_repos`
|
||||
|
||||
List repositories for a user or the authenticated user.
|
||||
|
||||
**Parameters:**
|
||||
- `username` (str, optional): GitHub username (if None, lists authenticated user's repos)
|
||||
- `visibility` (str, optional): Repository visibility ("all", "public", "private", default "all")
|
||||
- `sort` (str, optional): Sort order ("created", "updated", "pushed", "full_name", default "updated")
|
||||
- `limit` (int, optional): Maximum number of repositories (1-100, default 30)
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": [
|
||||
{
|
||||
"id": 123456,
|
||||
"name": "my-repo",
|
||||
"full_name": "username/my-repo",
|
||||
"description": "A cool project",
|
||||
"private": False,
|
||||
"html_url": "https://github.com/username/my-repo",
|
||||
"stargazers_count": 42,
|
||||
"forks_count": 7
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
# List your repositories
|
||||
result = github_list_repos()
|
||||
|
||||
# List another user's public repositories
|
||||
result = github_list_repos(username="octocat", limit=10)
|
||||
```
|
||||
|
||||
#### `github_get_repo`
|
||||
|
||||
Get detailed information about a specific repository.
|
||||
|
||||
**Parameters:**
|
||||
- `owner` (str): Repository owner (username or organization)
|
||||
- `repo` (str): Repository name
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"id": 123456,
|
||||
"name": "my-repo",
|
||||
"full_name": "owner/my-repo",
|
||||
"description": "Project description",
|
||||
"private": False,
|
||||
"default_branch": "main",
|
||||
"stargazers_count": 100,
|
||||
"forks_count": 25,
|
||||
"language": "Python",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-31T12:00:00Z"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
result = github_get_repo(owner="adenhq", repo="hive")
|
||||
print(f"Stars: {result['data']['stargazers_count']}")
|
||||
```
|
||||
|
||||
#### `github_search_repos`
|
||||
|
||||
Search for repositories on GitHub.
|
||||
|
||||
**Parameters:**
|
||||
- `query` (str): Search query (supports GitHub search syntax)
|
||||
- `sort` (str, optional): Sort field ("stars", "forks", "updated")
|
||||
- `limit` (int, optional): Maximum results (1-100, default 30)
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"total_count": 1000,
|
||||
"items": [
|
||||
{
|
||||
"id": 123,
|
||||
"name": "awesome-python",
|
||||
"full_name": "user/awesome-python",
|
||||
"description": "A curated list",
|
||||
"stargazers_count": 5000
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
# Search for Python repos with many stars
|
||||
result = github_search_repos(
|
||||
query="language:python stars:>1000",
|
||||
sort="stars",
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Search in a specific organization
|
||||
result = github_search_repos(query="org:adenhq agent")
|
||||
```
|
||||
|
||||
### Issue Management
|
||||
|
||||
#### `github_list_issues`
|
||||
|
||||
List issues for a repository.
|
||||
|
||||
**Parameters:**
|
||||
- `owner` (str): Repository owner
|
||||
- `repo` (str): Repository name
|
||||
- `state` (str, optional): Issue state ("open", "closed", "all", default "open")
|
||||
- `limit` (int, optional): Maximum issues (1-100, default 30)
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": [
|
||||
{
|
||||
"number": 42,
|
||||
"title": "Bug in feature X",
|
||||
"state": "open",
|
||||
"user": {"login": "username"},
|
||||
"labels": [{"name": "bug"}],
|
||||
"created_at": "2024-01-30T10:00:00Z",
|
||||
"html_url": "https://github.com/owner/repo/issues/42"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
# List open issues
|
||||
issues = github_list_issues(owner="adenhq", repo="hive", state="open")
|
||||
for issue in issues["data"]:
|
||||
print(f"#{issue['number']}: {issue['title']}")
|
||||
```
|
||||
|
||||
#### `github_get_issue`
|
||||
|
||||
Get a specific issue by number.
|
||||
|
||||
**Parameters:**
|
||||
- `owner` (str): Repository owner
|
||||
- `repo` (str): Repository name
|
||||
- `issue_number` (int): Issue number
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"number": 42,
|
||||
"title": "Issue title",
|
||||
"body": "Detailed description...",
|
||||
"state": "open",
|
||||
"user": {"login": "username"},
|
||||
"assignees": [],
|
||||
"labels": [{"name": "enhancement"}],
|
||||
"comments": 5
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
issue = github_get_issue(owner="adenhq", repo="hive", issue_number=2805)
|
||||
print(issue["data"]["body"])
|
||||
```
|
||||
|
||||
#### `github_create_issue`
|
||||
|
||||
Create a new issue in a repository.
|
||||
|
||||
**Parameters:**
|
||||
- `owner` (str): Repository owner
|
||||
- `repo` (str): Repository name
|
||||
- `title` (str): Issue title
|
||||
- `body` (str, optional): Issue description (supports Markdown)
|
||||
- `labels` (list[str], optional): List of label names
|
||||
- `assignees` (list[str], optional): List of usernames to assign
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"number": 43,
|
||||
"title": "New issue",
|
||||
"html_url": "https://github.com/owner/repo/issues/43"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
result = github_create_issue(
|
||||
owner="myorg",
|
||||
repo="myrepo",
|
||||
title="Add new feature",
|
||||
body="## Description\n\nWe need to add...",
|
||||
labels=["enhancement", "help wanted"],
|
||||
assignees=["developer1"]
|
||||
)
|
||||
print(f"Created issue #{result['data']['number']}")
|
||||
```
|
||||
|
||||
#### `github_update_issue`
|
||||
|
||||
Update an existing issue.
|
||||
|
||||
**Parameters:**
|
||||
- `owner` (str): Repository owner
|
||||
- `repo` (str): Repository name
|
||||
- `issue_number` (int): Issue number
|
||||
- `title` (str, optional): New title
|
||||
- `body` (str, optional): New body (replaces the existing issue description; it does not add a comment)
|
||||
- `state` (str, optional): New state ("open" or "closed")
|
||||
- `labels` (list[str], optional): New list of label names
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"number": 43,
|
||||
"title": "Updated title",
|
||||
"state": "closed"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
# Close an issue
|
||||
result = github_update_issue(
|
||||
owner="myorg",
|
||||
repo="myrepo",
|
||||
issue_number=43,
|
||||
state="closed",
|
||||
body="Fixed in PR #44"
|
||||
)
|
||||
```
|
||||
|
||||
### Pull Request Management
|
||||
|
||||
#### `github_list_pull_requests`
|
||||
|
||||
List pull requests for a repository.
|
||||
|
||||
**Parameters:**
|
||||
- `owner` (str): Repository owner
|
||||
- `repo` (str): Repository name
|
||||
- `state` (str, optional): PR state ("open", "closed", "all", default "open")
|
||||
- `limit` (int, optional): Maximum PRs (1-100, default 30)
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": [
|
||||
{
|
||||
"number": 10,
|
||||
"title": "Add new feature",
|
||||
"state": "open",
|
||||
"user": {"login": "contributor"},
|
||||
"head": {"ref": "feature-branch"},
|
||||
"base": {"ref": "main"},
|
||||
"html_url": "https://github.com/owner/repo/pull/10"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
prs = github_list_pull_requests(owner="adenhq", repo="hive", state="open")
|
||||
for pr in prs["data"]:
|
||||
print(f"PR #{pr['number']}: {pr['title']}")
|
||||
```
|
||||
|
||||
#### `github_get_pull_request`
|
||||
|
||||
Get a specific pull request.
|
||||
|
||||
**Parameters:**
|
||||
- `owner` (str): Repository owner
|
||||
- `repo` (str): Repository name
|
||||
- `pull_number` (int): Pull request number
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"number": 10,
|
||||
"title": "PR title",
|
||||
"body": "Description...",
|
||||
"state": "open",
|
||||
"merged": False,
|
||||
"draft": False,
|
||||
"head": {"ref": "feature"},
|
||||
"base": {"ref": "main"}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
pr = github_get_pull_request(owner="adenhq", repo="hive", pull_number=2814)
|
||||
print(f"PR by {pr['data']['user']['login']}")
|
||||
```
|
||||
|
||||
#### `github_create_pull_request`
|
||||
|
||||
Create a new pull request.
|
||||
|
||||
**Parameters:**
|
||||
- `owner` (str): Repository owner
|
||||
- `repo` (str): Repository name
|
||||
- `title` (str): Pull request title
|
||||
- `head` (str): Branch with your changes (e.g., "my-feature")
|
||||
- `base` (str): Branch to merge into (e.g., "main")
|
||||
- `body` (str, optional): Pull request description (supports Markdown)
|
||||
- `draft` (bool, optional): Create as draft PR (default False)
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"number": 11,
|
||||
"title": "New PR",
|
||||
"html_url": "https://github.com/owner/repo/pull/11"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
result = github_create_pull_request(
|
||||
owner="myorg",
|
||||
repo="myrepo",
|
||||
title="feat: Add GitHub integration tool",
|
||||
head="feature/github-tool",
|
||||
base="main",
|
||||
body="## Summary\n\n- Implements GitHub API integration\n- Adds 30+ tests",
|
||||
draft=False
|
||||
)
|
||||
print(f"Created PR: {result['data']['html_url']}")
|
||||
```
|
||||
|
||||
### Search
|
||||
|
||||
#### `github_search_code`
|
||||
|
||||
Search code across GitHub.
|
||||
|
||||
**Parameters:**
|
||||
- `query` (str): Search query (supports GitHub code search syntax)
|
||||
- `limit` (int, optional): Maximum results (1-100, default 30)
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"total_count": 50,
|
||||
"items": [
|
||||
{
|
||||
"name": "example.py",
|
||||
"path": "src/example.py",
|
||||
"repository": {
|
||||
"full_name": "owner/repo"
|
||||
},
|
||||
"html_url": "https://github.com/owner/repo/blob/main/src/example.py"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
# Search for function usage
|
||||
result = github_search_code(
|
||||
query="register_tools language:python repo:adenhq/hive"
|
||||
)
|
||||
|
||||
# Search for specific code pattern
|
||||
result = github_search_code(query="FastMCP extension:py")
|
||||
```
|
||||
|
||||
### Branch Management
|
||||
|
||||
#### `github_list_branches`
|
||||
|
||||
List branches for a repository.
|
||||
|
||||
**Parameters:**
|
||||
- `owner` (str): Repository owner
|
||||
- `repo` (str): Repository name
|
||||
- `limit` (int, optional): Maximum branches (1-100, default 30)
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": [
|
||||
{
|
||||
"name": "main",
|
||||
"protected": True,
|
||||
"commit": {"sha": "abc123..."}
|
||||
},
|
||||
{
|
||||
"name": "develop",
|
||||
"protected": False
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
branches = github_list_branches(owner="adenhq", repo="hive")
|
||||
for branch in branches["data"]:
|
||||
print(f"Branch: {branch['name']}")
|
||||
```
|
||||
|
||||
#### `github_get_branch`
|
||||
|
||||
Get information about a specific branch.
|
||||
|
||||
**Parameters:**
|
||||
- `owner` (str): Repository owner
|
||||
- `repo` (str): Repository name
|
||||
- `branch` (str): Branch name
|
||||
|
||||
**Returns:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"name": "main",
|
||||
"protected": True,
|
||||
"commit": {
|
||||
"sha": "abc123...",
|
||||
"commit": {
|
||||
"message": "Latest commit message"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
main_branch = github_get_branch(owner="adenhq", repo="hive", branch="main")
|
||||
print(f"Latest commit: {main_branch['data']['commit']['sha']}")
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
All functions return a dict with an `error` key if something goes wrong:
|
||||
|
||||
```python
|
||||
{
    "error": "Resource not found"
}
|
||||
```
|
||||
|
||||
Common errors:
|
||||
- `not configured` - No GitHub token provided
|
||||
- `Invalid or expired GitHub token` - Token authentication failed (401)
|
||||
- `Forbidden` - Insufficient permissions or rate limit exceeded (403)
|
||||
- `Resource not found` - Repository, issue, or PR doesn't exist (404)
|
||||
- `Validation error` - Invalid request parameters (422)
|
||||
- `Request timed out` - Network timeout
|
||||
- `Network error` - Connection issues
|
||||
|
||||
## Security
|
||||
|
||||
- Personal Access Tokens are never logged or exposed
|
||||
- All API calls use HTTPS
|
||||
- Tokens are retrieved from secure credential store or environment variables
|
||||
- Fine-grained permissions can be configured via GitHub token scopes
|
||||
|
||||
## Use Cases
|
||||
|
||||
### Automated Issue Management
|
||||
```python
|
||||
# Create issues from bug reports
|
||||
github_create_issue(
|
||||
owner="myorg",
|
||||
repo="myapp",
|
||||
title="Bug: Login fails on mobile",
|
||||
body="## Steps to reproduce\n1. Open app on mobile...",
|
||||
labels=["bug", "mobile"]
|
||||
)
|
||||
```
|
||||
|
||||
### CI/CD Integration
|
||||
```python
|
||||
# Create PR after automated changes
|
||||
github_create_pull_request(
|
||||
owner="myorg",
|
||||
repo="myrepo",
|
||||
title="chore: Update dependencies",
|
||||
head="bot/update-deps",
|
||||
base="main",
|
||||
body="Automated dependency updates"
|
||||
)
|
||||
```
|
||||
|
||||
### Repository Analytics
|
||||
```python
|
||||
# Analyze repository activity
|
||||
repo = github_get_repo(owner="adenhq", repo="hive")
|
||||
issues = github_list_issues(owner="adenhq", repo="hive", state="open")
|
||||
prs = github_list_pull_requests(owner="adenhq", repo="hive", state="open")
|
||||
|
||||
print(f"Stars: {repo['data']['stargazers_count']}")
|
||||
print(f"Open Issues: {len(issues['data'])}")
|
||||
print(f"Open PRs: {len(prs['data'])}")
|
||||
```
|
||||
|
||||
### Code Discovery
|
||||
```python
|
||||
# Find examples of API usage
|
||||
results = github_search_code(
|
||||
query="register_tools language:python",
|
||||
limit=50
|
||||
)
|
||||
for item in results["data"]["items"]:
|
||||
print(f"Found in: {item['repository']['full_name']}")
|
||||
```
|
||||
|
||||
### Project Automation
|
||||
```python
|
||||
# Auto-close stale issues
|
||||
issues = github_list_issues(owner="myorg", repo="myrepo", state="open")
|
||||
for issue in issues["data"]:
|
||||
# Check if stale (custom logic)
|
||||
if is_stale(issue):
|
||||
github_update_issue(
|
||||
owner="myorg",
|
||||
repo="myrepo",
|
||||
issue_number=issue["number"],
|
||||
state="closed",
|
||||
body="Closing due to inactivity"
|
||||
)
|
||||
```
|
||||
|
||||
## Rate Limits
|
||||
|
||||
GitHub enforces rate limits on API calls:
|
||||
- **Authenticated requests**: 5,000 requests per hour
|
||||
- **Search API**: 30 requests per minute (code search is limited to 10 requests per minute)
|
||||
- **Unauthenticated requests**: 60 requests per hour (not applicable with token)
|
||||
|
||||
The tool handles rate limit errors gracefully with appropriate error messages. Monitor your usage at: https://api.github.com/rate_limit
|
||||
|
||||
## GitHub Search Syntax
|
||||
|
||||
For `github_search_repos` and `github_search_code`, you can use advanced search qualifiers:
|
||||
|
||||
### Repository Search
|
||||
- `language:python` - Filter by language
|
||||
- `stars:>1000` - Repositories with more than 1000 stars
|
||||
- `forks:>100` - Repositories with more than 100 forks
|
||||
- `org:adenhq` - Search within an organization
|
||||
- `topic:machine-learning` - Filter by topic
|
||||
- `created:>2024-01-01` - Created after date
|
||||
|
||||
### Code Search
|
||||
- `repo:owner/repo` - Search in specific repository
|
||||
- `extension:py` - Filter by file extension
|
||||
- `path:src/` - Search in specific path
|
||||
- `language:python` - Filter by language
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Find popular Python ML projects
|
||||
github_search_repos(
|
||||
query="language:python topic:machine-learning stars:>5000",
|
||||
sort="stars"
|
||||
)
|
||||
|
||||
# Find FastMCP usage examples
|
||||
github_search_code(
|
||||
query="FastMCP extension:py"
|
||||
)
|
||||
```
|
||||
@@ -0,0 +1,5 @@
|
||||
"""GitHub Tool package."""
|
||||
|
||||
from .github_tool import register_tools
|
||||
|
||||
__all__ = ["register_tools"]
|
||||
@@ -0,0 +1,991 @@
|
||||
"""
|
||||
GitHub Tool - Interact with GitHub repositories, issues, and pull requests.
|
||||
|
||||
Supports:
|
||||
- Personal Access Tokens (GITHUB_TOKEN / ghp_...)
|
||||
- OAuth tokens via the credential store
|
||||
|
||||
API Reference: https://docs.github.com/en/rest
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
# Root URL of the GitHub REST API; endpoint paths are appended to it.
GITHUB_API_BASE = "https://api.github.com"
|
||||
|
||||
|
||||
def _sanitize_path_param(param: str, param_name: str = "parameter") -> str:
|
||||
"""
|
||||
Sanitize URL path parameters to prevent path traversal.
|
||||
|
||||
Args:
|
||||
param: The parameter value to sanitize
|
||||
param_name: Name of the parameter (for error messages)
|
||||
|
||||
Returns:
|
||||
The sanitized parameter
|
||||
|
||||
Raises:
|
||||
ValueError: If parameter contains invalid characters
|
||||
"""
|
||||
if "/" in param or ".." in param:
|
||||
raise ValueError(f"Invalid {param_name}: cannot contain '/' or '..'")
|
||||
return param
|
||||
|
||||
|
||||
def _sanitize_error_message(error: Exception) -> str:
|
||||
"""
|
||||
Sanitize error messages to prevent token leaks.
|
||||
|
||||
httpx.RequestError can include headers in the exception message,
|
||||
which may expose the Bearer token.
|
||||
|
||||
Args:
|
||||
error: The exception to sanitize
|
||||
|
||||
Returns:
|
||||
A safe error message without sensitive information
|
||||
"""
|
||||
error_str = str(error)
|
||||
# Remove any Authorization headers or Bearer tokens
|
||||
if "Authorization" in error_str or "Bearer" in error_str:
|
||||
return "Network error occurred"
|
||||
return f"Network error: {error_str}"
|
||||
|
||||
|
||||
class _GitHubClient:
|
||||
"""Internal client wrapping GitHub REST API v3 calls."""
|
||||
|
||||
def __init__(self, token: str):
    """Store the token that is sent as the Bearer credential on each request."""
    self._token = token
|
||||
|
||||
@property
|
||||
def _headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self._token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
}
|
||||
|
||||
def _handle_response(self, response: httpx.Response) -> dict[str, Any]:
|
||||
"""Handle GitHub API response format."""
|
||||
if response.status_code == 401:
|
||||
return {"error": "Invalid or expired GitHub token"}
|
||||
if response.status_code == 403:
|
||||
return {"error": "Forbidden - check token permissions or rate limit"}
|
||||
if response.status_code == 404:
|
||||
return {"error": "Resource not found"}
|
||||
if response.status_code == 422:
|
||||
try:
|
||||
detail = response.json().get("message", "Validation failed")
|
||||
except Exception:
|
||||
detail = "Validation failed"
|
||||
return {"error": f"Validation error: {detail}"}
|
||||
if response.status_code >= 400:
|
||||
try:
|
||||
detail = response.json().get("message", response.text)
|
||||
except Exception:
|
||||
detail = response.text
|
||||
return {"error": f"GitHub API error (HTTP {response.status_code}): {detail}"}
|
||||
|
||||
try:
|
||||
return {"success": True, "data": response.json()}
|
||||
except Exception:
|
||||
return {"success": True, "data": {}}
|
||||
|
||||
# --- Repositories ---
|
||||
|
||||
def list_repos(
|
||||
self,
|
||||
username: str | None = None,
|
||||
visibility: str = "all",
|
||||
sort: str = "updated",
|
||||
limit: int = 30,
|
||||
) -> dict[str, Any]:
|
||||
"""List repositories for a user or authenticated user."""
|
||||
if username:
|
||||
username = _sanitize_path_param(username, "username")
|
||||
url = f"{GITHUB_API_BASE}/users/{username}/repos"
|
||||
else:
|
||||
url = f"{GITHUB_API_BASE}/user/repos"
|
||||
|
||||
params = {
|
||||
"visibility": visibility,
|
||||
"sort": sort,
|
||||
"per_page": min(limit, 100),
|
||||
}
|
||||
|
||||
response = httpx.get(
|
||||
url,
|
||||
headers=self._headers,
|
||||
params=params,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
def get_repo(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Get repository information."""
|
||||
owner = _sanitize_path_param(owner, "owner")
|
||||
repo = _sanitize_path_param(repo, "repo")
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/repos/{owner}/{repo}",
|
||||
headers=self._headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
def search_repos(
|
||||
self,
|
||||
query: str,
|
||||
sort: str | None = None,
|
||||
limit: int = 30,
|
||||
) -> dict[str, Any]:
|
||||
"""Search for repositories."""
|
||||
params: dict[str, Any] = {
|
||||
"q": query,
|
||||
"per_page": min(limit, 100),
|
||||
}
|
||||
if sort:
|
||||
params["sort"] = sort
|
||||
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/search/repositories",
|
||||
headers=self._headers,
|
||||
params=params,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
# --- Issues ---
|
||||
|
||||
def list_issues(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
state: str = "open",
|
||||
page: int = 1,
|
||||
limit: int = 30,
|
||||
) -> dict[str, Any]:
|
||||
"""List issues for a repository."""
|
||||
owner = _sanitize_path_param(owner, "owner")
|
||||
repo = _sanitize_path_param(repo, "repo")
|
||||
params = {
|
||||
"state": state,
|
||||
"per_page": min(limit, 100),
|
||||
"page": max(1, page),
|
||||
}
|
||||
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/repos/{owner}/{repo}/issues",
|
||||
headers=self._headers,
|
||||
params=params,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
def get_issue(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
issue_number: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Get a specific issue."""
|
||||
owner = _sanitize_path_param(owner, "owner")
|
||||
repo = _sanitize_path_param(repo, "repo")
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/repos/{owner}/{repo}/issues/{issue_number}",
|
||||
headers=self._headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
def create_issue(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
title: str,
|
||||
body: str | None = None,
|
||||
labels: list[str] | None = None,
|
||||
assignees: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new issue."""
|
||||
owner = _sanitize_path_param(owner, "owner")
|
||||
repo = _sanitize_path_param(repo, "repo")
|
||||
payload: dict[str, Any] = {"title": title}
|
||||
if body:
|
||||
payload["body"] = body
|
||||
if labels:
|
||||
payload["labels"] = labels
|
||||
if assignees:
|
||||
payload["assignees"] = assignees
|
||||
|
||||
response = httpx.post(
|
||||
f"{GITHUB_API_BASE}/repos/{owner}/{repo}/issues",
|
||||
headers=self._headers,
|
||||
json=payload,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
def update_issue(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
issue_number: int,
|
||||
title: str | None = None,
|
||||
body: str | None = None,
|
||||
state: str | None = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing issue."""
|
||||
owner = _sanitize_path_param(owner, "owner")
|
||||
repo = _sanitize_path_param(repo, "repo")
|
||||
payload: dict[str, Any] = {}
|
||||
if title:
|
||||
payload["title"] = title
|
||||
if body is not None:
|
||||
payload["body"] = body
|
||||
if state:
|
||||
payload["state"] = state
|
||||
if labels is not None:
|
||||
payload["labels"] = labels
|
||||
|
||||
response = httpx.patch(
|
||||
f"{GITHUB_API_BASE}/repos/{owner}/{repo}/issues/{issue_number}",
|
||||
headers=self._headers,
|
||||
json=payload,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
# --- Pull Requests ---
|
||||
|
||||
def list_pull_requests(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
state: str = "open",
|
||||
page: int = 1,
|
||||
limit: int = 30,
|
||||
) -> dict[str, Any]:
|
||||
"""List pull requests for a repository."""
|
||||
owner = _sanitize_path_param(owner, "owner")
|
||||
repo = _sanitize_path_param(repo, "repo")
|
||||
params = {
|
||||
"state": state,
|
||||
"per_page": min(limit, 100),
|
||||
"page": max(1, page),
|
||||
}
|
||||
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/repos/{owner}/{repo}/pulls",
|
||||
headers=self._headers,
|
||||
params=params,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
def get_pull_request(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pull_number: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Get a specific pull request."""
|
||||
owner = _sanitize_path_param(owner, "owner")
|
||||
repo = _sanitize_path_param(repo, "repo")
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/repos/{owner}/{repo}/pulls/{pull_number}",
|
||||
headers=self._headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
def create_pull_request(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
title: str,
|
||||
head: str,
|
||||
base: str,
|
||||
body: str | None = None,
|
||||
draft: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new pull request."""
|
||||
owner = _sanitize_path_param(owner, "owner")
|
||||
repo = _sanitize_path_param(repo, "repo")
|
||||
payload: dict[str, Any] = {
|
||||
"title": title,
|
||||
"head": head,
|
||||
"base": base,
|
||||
"draft": draft,
|
||||
}
|
||||
if body:
|
||||
payload["body"] = body
|
||||
|
||||
response = httpx.post(
|
||||
f"{GITHUB_API_BASE}/repos/{owner}/{repo}/pulls",
|
||||
headers=self._headers,
|
||||
json=payload,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
# --- Search ---
|
||||
|
||||
def search_code(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 30,
|
||||
) -> dict[str, Any]:
|
||||
"""Search code across GitHub."""
|
||||
params = {
|
||||
"q": query,
|
||||
"per_page": min(limit, 100),
|
||||
}
|
||||
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/search/code",
|
||||
headers=self._headers,
|
||||
params=params,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
# --- Branches ---
|
||||
|
||||
def list_branches(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
limit: int = 30,
|
||||
) -> dict[str, Any]:
|
||||
"""List branches for a repository."""
|
||||
owner = _sanitize_path_param(owner, "owner")
|
||||
repo = _sanitize_path_param(repo, "repo")
|
||||
params = {
|
||||
"per_page": min(limit, 100),
|
||||
}
|
||||
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/repos/{owner}/{repo}/branches",
|
||||
headers=self._headers,
|
||||
params=params,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
def get_branch(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
branch: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Get a specific branch."""
|
||||
owner = _sanitize_path_param(owner, "owner")
|
||||
repo = _sanitize_path_param(repo, "repo")
|
||||
branch = _sanitize_path_param(branch, "branch")
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/repos/{owner}/{repo}/branches/{branch}",
|
||||
headers=self._headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
# --- Stargazers ---
|
||||
|
||||
def list_stargazers(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
page: int = 1,
|
||||
limit: int = 30,
|
||||
) -> dict[str, Any]:
|
||||
"""List users who starred a repository."""
|
||||
owner = _sanitize_path_param(owner, "owner")
|
||||
repo = _sanitize_path_param(repo, "repo")
|
||||
params = {
|
||||
"per_page": min(limit, 100),
|
||||
"page": max(1, page),
|
||||
}
|
||||
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/repos/{owner}/{repo}/stargazers",
|
||||
headers=self._headers,
|
||||
params=params,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
# --- Users ---
|
||||
|
||||
def get_user_profile(
|
||||
self,
|
||||
username: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Get a user's public profile."""
|
||||
username = _sanitize_path_param(username, "username")
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/users/{username}",
|
||||
headers=self._headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
def get_user_emails(
|
||||
self,
|
||||
username: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Find a user's email addresses from their public activity.
|
||||
|
||||
The /users/{username} endpoint only returns the public email
|
||||
(which most users leave blank). This method also checks the
|
||||
user's recent public events for commit-author emails.
|
||||
"""
|
||||
username = _sanitize_path_param(username, "username")
|
||||
|
||||
emails: dict[str, str] = {} # email -> source
|
||||
|
||||
# 1. Check profile for public email
|
||||
profile = self.get_user_profile(username)
|
||||
if isinstance(profile, dict) and "error" not in profile:
|
||||
if profile.get("email"):
|
||||
emails[profile["email"]] = "profile"
|
||||
|
||||
# 2. Check recent public events for commit emails
|
||||
response = httpx.get(
|
||||
f"{GITHUB_API_BASE}/users/{username}/events/public",
|
||||
headers=self._headers,
|
||||
params={"per_page": 30},
|
||||
timeout=30.0,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
for event in response.json():
|
||||
if event.get("type") != "PushEvent":
|
||||
continue
|
||||
for commit in event.get("payload", {}).get("commits", []):
|
||||
author = commit.get("author", {})
|
||||
email = author.get("email", "")
|
||||
if email and "@" in email and "noreply" not in email.lower():
|
||||
emails[email] = "commit"
|
||||
|
||||
return {
|
||||
"username": username,
|
||||
"emails": [{"email": e, "source": s} for e, s in emails.items()],
|
||||
"total": len(emails),
|
||||
}
|
||||
|
||||
|
||||
def register_tools(
    mcp: FastMCP,
    credentials: CredentialStoreAdapter | None = None,
) -> None:
    """Register GitHub tools with the MCP server.

    Each tool resolves a token (credential store first, then the
    ``GITHUB_TOKEN`` environment variable), builds a ``_GitHubClient`` and
    delegates to the matching client method. Tool docstrings are kept as-is
    because the MCP layer may surface them as tool descriptions.
    """

    def _get_token() -> str | None:
        """Get GitHub token from credential manager or environment."""
        if credentials is None:
            return os.getenv("GITHUB_TOKEN")
        token = credentials.get("github")
        if token is not None and not isinstance(token, str):
            raise TypeError(
                f"Expected string from credentials.get('github'), got {type(token).__name__}"
            )
        return token

    def _get_client() -> _GitHubClient | dict[str, str]:
        """Get a GitHub client, or return an error dict if no credentials."""
        token = _get_token()
        if token:
            return _GitHubClient(token)
        return {
            "error": "GitHub credentials not configured",
            "help": (
                "Set GITHUB_TOKEN environment variable "
                "or configure via credential store. "
                "Get a token at https://github.com/settings/tokens"
            ),
        }

    def _invoke(action) -> dict:
        """Run ``action(client)`` with the shared credential and error handling.

        Returns the credentials error dict when no client can be built, and
        converts httpx timeout/request errors into ``{"error": ...}`` dicts.
        """
        client = _get_client()
        if isinstance(client, dict):
            return client
        try:
            return action(client)
        except httpx.TimeoutException:
            return {"error": "Request timed out"}
        except httpx.RequestError as e:
            return {"error": _sanitize_error_message(e)}

    # --- Repositories ---

    @mcp.tool()
    def github_list_repos(
        username: str | None = None,
        visibility: str = "all",
        sort: str = "updated",
        limit: int = 30,
    ) -> dict:
        """
        List repositories for a user or the authenticated user.

        Args:
            username: GitHub username (if None, lists authenticated user's repos)
            visibility: Repository visibility filter ("all", "public", "private")
            sort: Sort order ("created", "updated", "pushed", "full_name")
            limit: Maximum number of repositories to return (1-100, default 30)

        Returns:
            Dict with list of repositories or error
        """
        return _invoke(lambda gh: gh.list_repos(username, visibility, sort, limit))

    @mcp.tool()
    def github_get_repo(
        owner: str,
        repo: str,
    ) -> dict:
        """
        Get information about a specific repository.

        Args:
            owner: Repository owner (username or organization)
            repo: Repository name

        Returns:
            Dict with repository information or error
        """
        return _invoke(lambda gh: gh.get_repo(owner, repo))

    @mcp.tool()
    def github_search_repos(
        query: str,
        sort: str | None = None,
        limit: int = 30,
    ) -> dict:
        """
        Search for repositories on GitHub.

        Args:
            query: Search query (e.g., "language:python stars:>1000")
            sort: Sort field ("stars", "forks", "updated")
            limit: Maximum number of results (1-100, default 30)

        Returns:
            Dict with search results or error
        """
        return _invoke(lambda gh: gh.search_repos(query, sort, limit))

    # --- Issues ---

    @mcp.tool()
    def github_list_issues(
        owner: str,
        repo: str,
        state: str = "open",
        page: int = 1,
        limit: int = 30,
    ) -> dict:
        """
        List issues for a repository.

        Args:
            owner: Repository owner
            repo: Repository name
            state: Issue state ("open", "closed", "all")
            page: Page number for pagination (1-based, default 1)
            limit: Maximum number of issues per page (1-100, default 30)

        Returns:
            Dict with list of issues or error
        """
        return _invoke(lambda gh: gh.list_issues(owner, repo, state, page, limit))

    @mcp.tool()
    def github_get_issue(
        owner: str,
        repo: str,
        issue_number: int,
    ) -> dict:
        """
        Get a specific issue.

        Args:
            owner: Repository owner
            repo: Repository name
            issue_number: Issue number

        Returns:
            Dict with issue information or error
        """
        return _invoke(lambda gh: gh.get_issue(owner, repo, issue_number))

    @mcp.tool()
    def github_create_issue(
        owner: str,
        repo: str,
        title: str,
        body: str | None = None,
        labels: list[str] | None = None,
        assignees: list[str] | None = None,
    ) -> dict:
        """
        Create a new issue in a repository.

        Args:
            owner: Repository owner
            repo: Repository name
            title: Issue title
            body: Issue body/description (supports Markdown)
            labels: List of label names to apply
            assignees: List of usernames to assign

        Returns:
            Dict with created issue information or error
        """
        return _invoke(lambda gh: gh.create_issue(owner, repo, title, body, labels, assignees))

    @mcp.tool()
    def github_update_issue(
        owner: str,
        repo: str,
        issue_number: int,
        title: str | None = None,
        body: str | None = None,
        state: str | None = None,
        labels: list[str] | None = None,
    ) -> dict:
        """
        Update an existing issue.

        Args:
            owner: Repository owner
            repo: Repository name
            issue_number: Issue number
            title: New issue title
            body: New issue body
            state: New state ("open" or "closed")
            labels: New list of label names

        Returns:
            Dict with updated issue information or error
        """
        return _invoke(
            lambda gh: gh.update_issue(owner, repo, issue_number, title, body, state, labels)
        )

    # --- Pull Requests ---

    @mcp.tool()
    def github_list_pull_requests(
        owner: str,
        repo: str,
        state: str = "open",
        page: int = 1,
        limit: int = 30,
    ) -> dict:
        """
        List pull requests for a repository.

        Args:
            owner: Repository owner
            repo: Repository name
            state: PR state ("open", "closed", "all")
            page: Page number for pagination (1-based, default 1)
            limit: Maximum number of PRs per page (1-100, default 30)

        Returns:
            Dict with list of pull requests or error
        """
        return _invoke(lambda gh: gh.list_pull_requests(owner, repo, state, page, limit))

    @mcp.tool()
    def github_get_pull_request(
        owner: str,
        repo: str,
        pull_number: int,
    ) -> dict:
        """
        Get a specific pull request.

        Args:
            owner: Repository owner
            repo: Repository name
            pull_number: Pull request number

        Returns:
            Dict with pull request information or error
        """
        return _invoke(lambda gh: gh.get_pull_request(owner, repo, pull_number))

    @mcp.tool()
    def github_create_pull_request(
        owner: str,
        repo: str,
        title: str,
        head: str,
        base: str,
        body: str | None = None,
        draft: bool = False,
    ) -> dict:
        """
        Create a new pull request.

        Args:
            owner: Repository owner
            repo: Repository name
            title: Pull request title
            head: The name of the branch where your changes are (e.g., "my-feature")
            base: The name of the branch you want to merge into (e.g., "main")
            body: Pull request description (supports Markdown)
            draft: Whether to create as a draft PR

        Returns:
            Dict with created pull request information or error
        """
        return _invoke(
            lambda gh: gh.create_pull_request(owner, repo, title, head, base, body, draft)
        )

    # --- Search ---

    @mcp.tool()
    def github_search_code(
        query: str,
        limit: int = 30,
    ) -> dict:
        """
        Search code across GitHub.

        Args:
            query: Search query (e.g., "addClass repo:jquery/jquery")
            limit: Maximum number of results (1-100, default 30)

        Returns:
            Dict with search results or error
        """
        return _invoke(lambda gh: gh.search_code(query, limit))

    # --- Branches ---

    @mcp.tool()
    def github_list_branches(
        owner: str,
        repo: str,
        limit: int = 30,
    ) -> dict:
        """
        List branches for a repository.

        Args:
            owner: Repository owner
            repo: Repository name
            limit: Maximum number of branches to return (1-100, default 30)

        Returns:
            Dict with list of branches or error
        """
        return _invoke(lambda gh: gh.list_branches(owner, repo, limit))

    @mcp.tool()
    def github_get_branch(
        owner: str,
        repo: str,
        branch: str,
    ) -> dict:
        """
        Get information about a specific branch.

        Args:
            owner: Repository owner
            repo: Repository name
            branch: Branch name

        Returns:
            Dict with branch information or error
        """
        return _invoke(lambda gh: gh.get_branch(owner, repo, branch))

    # --- Stargazers ---

    @mcp.tool()
    def github_list_stargazers(
        owner: str,
        repo: str,
        page: int = 1,
        limit: int = 30,
    ) -> dict:
        """
        List users who starred a repository.

        Args:
            owner: Repository owner
            repo: Repository name
            page: Page number for pagination (1-based, default 1)
            limit: Maximum number of stargazers per page (1-100, default 30)

        Returns:
            Dict with list of stargazers or error
        """
        return _invoke(lambda gh: gh.list_stargazers(owner, repo, page, limit))

    # --- Users ---

    @mcp.tool()
    def github_get_user_profile(
        username: str,
    ) -> dict:
        """
        Get a GitHub user's public profile including name, bio, company, location, and email.

        Args:
            username: GitHub username

        Returns:
            Dict with user profile information or error
        """
        return _invoke(lambda gh: gh.get_user_profile(username))

    @mcp.tool()
    def github_get_user_emails(
        username: str,
    ) -> dict:
        """
        Find a GitHub user's email addresses from their public activity.

        Checks both the user's profile (public email) and their recent
        push events for commit-author emails. Filters out noreply addresses.

        Args:
            username: GitHub username

        Returns:
            Dict with emails list (each with email and source), total count
        """
        return _invoke(lambda gh: gh.get_user_emails(username))
|
||||
@@ -415,13 +415,13 @@ class TestDealTools:
|
||||
|
||||
class TestHubSpotOAuth2Provider:
|
||||
def test_provider_id(self):
|
||||
from core.framework.credentials.oauth2.hubspot_provider import HubSpotOAuth2Provider
|
||||
from framework.credentials.oauth2.hubspot_provider import HubSpotOAuth2Provider
|
||||
|
||||
provider = HubSpotOAuth2Provider(client_id="cid", client_secret="csecret")
|
||||
assert provider.provider_id == "hubspot_oauth2"
|
||||
|
||||
def test_default_scopes(self):
|
||||
from core.framework.credentials.oauth2.hubspot_provider import (
|
||||
from framework.credentials.oauth2.hubspot_provider import (
|
||||
HUBSPOT_DEFAULT_SCOPES,
|
||||
HubSpotOAuth2Provider,
|
||||
)
|
||||
@@ -430,7 +430,7 @@ class TestHubSpotOAuth2Provider:
|
||||
assert provider.config.default_scopes == HUBSPOT_DEFAULT_SCOPES
|
||||
|
||||
def test_custom_scopes(self):
|
||||
from core.framework.credentials.oauth2.hubspot_provider import HubSpotOAuth2Provider
|
||||
from framework.credentials.oauth2.hubspot_provider import HubSpotOAuth2Provider
|
||||
|
||||
provider = HubSpotOAuth2Provider(
|
||||
client_id="cid",
|
||||
@@ -440,7 +440,7 @@ class TestHubSpotOAuth2Provider:
|
||||
assert provider.config.default_scopes == ["crm.objects.contacts.read"]
|
||||
|
||||
def test_endpoints(self):
|
||||
from core.framework.credentials.oauth2.hubspot_provider import (
|
||||
from framework.credentials.oauth2.hubspot_provider import (
|
||||
HUBSPOT_AUTHORIZATION_URL,
|
||||
HUBSPOT_TOKEN_URL,
|
||||
HubSpotOAuth2Provider,
|
||||
@@ -451,15 +451,15 @@ class TestHubSpotOAuth2Provider:
|
||||
assert provider.config.authorization_url == HUBSPOT_AUTHORIZATION_URL
|
||||
|
||||
def test_supported_types(self):
|
||||
from core.framework.credentials.models import CredentialType
|
||||
from core.framework.credentials.oauth2.hubspot_provider import HubSpotOAuth2Provider
|
||||
from framework.credentials.models import CredentialType
|
||||
from framework.credentials.oauth2.hubspot_provider import HubSpotOAuth2Provider
|
||||
|
||||
provider = HubSpotOAuth2Provider(client_id="cid", client_secret="csecret")
|
||||
assert CredentialType.OAUTH2 in provider.supported_types
|
||||
|
||||
def test_validate_no_access_token(self):
|
||||
from core.framework.credentials.models import CredentialObject
|
||||
from core.framework.credentials.oauth2.hubspot_provider import HubSpotOAuth2Provider
|
||||
from framework.credentials.models import CredentialObject
|
||||
from framework.credentials.oauth2.hubspot_provider import HubSpotOAuth2Provider
|
||||
|
||||
provider = HubSpotOAuth2Provider(client_id="cid", client_secret="csecret")
|
||||
cred = CredentialObject(id="test")
|
||||
|
||||
@@ -25,8 +25,6 @@ pip install playwright playwright-stealth
|
||||
playwright install chromium
|
||||
```
|
||||
|
||||
In Docker, add `RUN playwright install chromium --with-deps` to the Dockerfile.
|
||||
|
||||
## Environment Variables
|
||||
|
||||
This tool does not require any environment variables.
|
||||
|
||||
@@ -10,6 +10,17 @@ from aden_tools.credentials import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _no_dotenv(tmp_path, monkeypatch):
    """Run every test from an empty temporary working directory.

    Per the original note on this fixture, EnvVarStorage falls back to
    Path.cwd()/.env when a key is absent from os.environ. Moving cwd to
    ``tmp_path`` keeps the project's own .env out of reach, so a
    monkeypatch.delenv() call really does simulate a missing credential.
    """
    monkeypatch.chdir(tmp_path)
|
||||
|
||||
|
||||
class TestCredentialStoreAdapter:
|
||||
"""Tests for CredentialStoreAdapter class."""
|
||||
|
||||
@@ -438,3 +449,38 @@ class TestStartupValidation:
|
||||
|
||||
# Should not raise
|
||||
creds.validate_startup()
|
||||
|
||||
|
||||
class TestSpecCompleteness:
    """Sanity checks over the fields populated on each CREDENTIAL_SPECS entry."""

    def test_direct_api_key_specs_have_instructions(self):
        """Specs that accept a direct API key must carry non-blank instructions."""
        direct_key_specs = {
            name: spec for name, spec in CREDENTIAL_SPECS.items() if spec.direct_api_key_supported
        }
        for name, spec in direct_key_specs.items():
            assert spec.api_key_instructions.strip(), (
                f"Credential '{name}' has direct_api_key_supported=True "
                f"but empty api_key_instructions"
            )

    def test_all_specs_have_credential_id(self):
        """Every spec must define a truthy credential_id."""
        for name in CREDENTIAL_SPECS:
            spec = CREDENTIAL_SPECS[name]
            assert spec.credential_id, f"Credential '{name}' is missing credential_id"

    def test_google_search_and_cse_share_credential_group(self):
        """google_search and google_cse both belong to google_custom_search."""
        search_group = CREDENTIAL_SPECS["google_search"].credential_group
        cse_group = CREDENTIAL_SPECS["google_cse"].credential_group

        assert search_group == "google_custom_search"
        assert cse_group == "google_custom_search"
        assert search_group == cse_group

    def test_credential_group_default_empty(self):
        """Every spec other than the two Google search ones has no group."""
        grouped_names = ("google_search", "google_cse")
        for name, spec in CREDENTIAL_SPECS.items():
            if name in grouped_names:
                continue
            assert spec.credential_group == "", (
                f"Credential '{name}' has unexpected credential_group='{spec.credential_group}'"
            )
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
"""Tests for credential health checkers."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
|
||||
from aden_tools.credentials.health_check import (
|
||||
HEALTH_CHECKERS,
|
||||
AnthropicHealthChecker,
|
||||
GitHubHealthChecker,
|
||||
GoogleSearchHealthChecker,
|
||||
ResendHealthChecker,
|
||||
check_credential_health,
|
||||
)
|
||||
|
||||
|
||||
class TestHealthCheckerRegistry:
    """Checks which checker instances the HEALTH_CHECKERS registry exposes."""

    def test_google_search_registered(self):
        """The "google_search" key maps to a GoogleSearchHealthChecker."""
        assert "google_search" in HEALTH_CHECKERS
        registered = HEALTH_CHECKERS["google_search"]
        assert isinstance(registered, GoogleSearchHealthChecker)

    def test_anthropic_registered(self):
        """The "anthropic" key maps to an AnthropicHealthChecker."""
        assert "anthropic" in HEALTH_CHECKERS
        registered = HEALTH_CHECKERS["anthropic"]
        assert isinstance(registered, AnthropicHealthChecker)

    def test_github_registered(self):
        """The "github" key maps to a GitHubHealthChecker."""
        assert "github" in HEALTH_CHECKERS
        registered = HEALTH_CHECKERS["github"]
        assert isinstance(registered, GitHubHealthChecker)

    def test_resend_registered(self):
        """The "resend" key maps to a ResendHealthChecker."""
        assert "resend" in HEALTH_CHECKERS
        registered = HEALTH_CHECKERS["resend"]
        assert isinstance(registered, ResendHealthChecker)

    def test_all_expected_checkers_registered(self):
        """The registry holds exactly the expected set of keys, no more, no fewer."""
        expected = {
            "hubspot",
            "brave_search",
            "google_search",
            "anthropic",
            "github",
            "resend",
        }
        assert set(HEALTH_CHECKERS) == expected
|
||||
|
||||
|
||||
class TestAnthropicHealthChecker:
    """Tests for AnthropicHealthChecker.

    Each test patches ``httpx.Client`` inside the health_check module and
    drives the mocked client's ``post`` method, so no real request is sent.
    """

    def _mock_response(self, status_code, json_data=None):
        """Return a MagicMock specced as ``httpx.Response``.

        Args:
            status_code: Value assigned to ``response.status_code``.
            json_data: Optional payload returned by ``response.json()``.
        """
        response = MagicMock(spec=httpx.Response)
        response.status_code = status_code
        # Compare against None so an explicitly empty payload ({} or [])
        # is still installed instead of being silently ignored.
        if json_data is not None:
            response.json.return_value = json_data
        return response

    def _install_client(self, mock_client_cls):
        """Wire the patched Client class for context-manager use; return the inner mock."""
        mock_client = MagicMock()
        mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
        # __exit__ returns False so exceptions raised inside the block propagate.
        mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
        return mock_client

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_valid_key_200(self, mock_client_cls):
        """A 200 response yields a valid result whose message mentions 'valid'."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.post.return_value = self._mock_response(200)

        result = AnthropicHealthChecker().check("sk-ant-test-key")

        assert result.valid is True
        assert "valid" in result.message.lower()

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_invalid_key_401(self, mock_client_cls):
        """A 401 response yields an invalid result recording the status code."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.post.return_value = self._mock_response(401)

        result = AnthropicHealthChecker().check("invalid-key")

        assert result.valid is False
        assert result.details["status_code"] == 401

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_rate_limited_429(self, mock_client_cls):
        """A 429 response is still valid and flagged as rate_limited."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.post.return_value = self._mock_response(429)

        result = AnthropicHealthChecker().check("sk-ant-test-key")

        assert result.valid is True
        assert result.details.get("rate_limited") is True

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_bad_request_400_still_valid(self, mock_client_cls):
        """A 400 response is still reported as a valid credential."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.post.return_value = self._mock_response(400)

        result = AnthropicHealthChecker().check("sk-ant-test-key")

        assert result.valid is True

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_timeout(self, mock_client_cls):
        """A TimeoutException from post yields invalid with error == 'timeout'."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.post.side_effect = httpx.TimeoutException("timed out")

        result = AnthropicHealthChecker().check("sk-ant-test-key")

        assert result.valid is False
        assert result.details["error"] == "timeout"
|
||||
|
||||
|
||||
class TestGitHubHealthChecker:
    """Tests for GitHubHealthChecker.

    Each test patches ``httpx.Client`` inside the health_check module and
    drives the mocked client's ``get`` method, so no real request is sent.
    """

    def _mock_response(self, status_code, json_data=None):
        """Return a MagicMock specced as ``httpx.Response``.

        Args:
            status_code: Value assigned to ``response.status_code``.
            json_data: Optional payload returned by ``response.json()``.
        """
        response = MagicMock(spec=httpx.Response)
        response.status_code = status_code
        # Compare against None so an explicitly empty payload ({} or [])
        # is still installed instead of being silently ignored.
        if json_data is not None:
            response.json.return_value = json_data
        return response

    def _install_client(self, mock_client_cls):
        """Wire the patched Client class for context-manager use; return the inner mock."""
        mock_client = MagicMock()
        mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
        # __exit__ returns False so exceptions raised inside the block propagate.
        mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
        return mock_client

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_valid_token_200(self, mock_client_cls):
        """A 200 response with a login yields valid and exposes the username."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.get.return_value = self._mock_response(200, {"login": "testuser"})

        result = GitHubHealthChecker().check("ghp_test-token")

        assert result.valid is True
        assert "testuser" in result.message
        assert result.details["username"] == "testuser"

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_invalid_token_401(self, mock_client_cls):
        """A 401 response yields an invalid result recording the status code."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.get.return_value = self._mock_response(401)

        result = GitHubHealthChecker().check("invalid-token")

        assert result.valid is False
        assert result.details["status_code"] == 401

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_forbidden_403(self, mock_client_cls):
        """A 403 response yields an invalid result recording the status code."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.get.return_value = self._mock_response(403)

        result = GitHubHealthChecker().check("ghp_test-token")

        assert result.valid is False
        assert result.details["status_code"] == 403

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_timeout(self, mock_client_cls):
        """A TimeoutException from get yields invalid with error == 'timeout'."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.get.side_effect = httpx.TimeoutException("timed out")

        result = GitHubHealthChecker().check("ghp_test-token")

        assert result.valid is False
        assert result.details["error"] == "timeout"

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_request_error(self, mock_client_cls):
        """A RequestError from get yields invalid and its text appears in details."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.get.side_effect = httpx.RequestError("connection failed")

        result = GitHubHealthChecker().check("ghp_test-token")

        assert result.valid is False
        assert "connection failed" in result.details["error"]
|
||||
|
||||
|
||||
class TestResendHealthChecker:
    """Tests for ResendHealthChecker.

    Each test patches ``httpx.Client`` inside the health_check module and
    drives the mocked client's ``get`` method, so no real request is sent.
    """

    def _mock_response(self, status_code, json_data=None):
        """Return a MagicMock specced as ``httpx.Response``.

        Args:
            status_code: Value assigned to ``response.status_code``.
            json_data: Optional payload returned by ``response.json()``.
        """
        response = MagicMock(spec=httpx.Response)
        response.status_code = status_code
        # Compare against None so an explicitly empty payload ({} or [])
        # is still installed instead of being silently ignored.
        if json_data is not None:
            response.json.return_value = json_data
        return response

    def _install_client(self, mock_client_cls):
        """Wire the patched Client class for context-manager use; return the inner mock."""
        mock_client = MagicMock()
        mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
        # __exit__ returns False so exceptions raised inside the block propagate.
        mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
        return mock_client

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_valid_key_200(self, mock_client_cls):
        """A 200 response yields a valid result whose message mentions 'valid'."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.get.return_value = self._mock_response(200)

        result = ResendHealthChecker().check("re_test-key")

        assert result.valid is True
        assert "valid" in result.message.lower()

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_invalid_key_401(self, mock_client_cls):
        """A 401 response yields an invalid result recording the status code."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.get.return_value = self._mock_response(401)

        result = ResendHealthChecker().check("invalid-key")

        assert result.valid is False
        assert result.details["status_code"] == 401

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_forbidden_403(self, mock_client_cls):
        """A 403 response yields an invalid result recording the status code."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.get.return_value = self._mock_response(403)

        result = ResendHealthChecker().check("re_test-key")

        assert result.valid is False
        assert result.details["status_code"] == 403

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_timeout(self, mock_client_cls):
        """A TimeoutException from get yields invalid with error == 'timeout'."""
        mock_client = self._install_client(mock_client_cls)
        mock_client.get.side_effect = httpx.TimeoutException("timed out")

        result = ResendHealthChecker().check("re_test-key")

        assert result.valid is False
        assert result.details["error"] == "timeout"
|
||||
|
||||
|
||||
class TestCheckCredentialHealthDispatcher:
    """Exercise the module-level check_credential_health() entry point."""

    @staticmethod
    def _client_answering_200(client_cls_mock):
        """Wire the patched Client class so its context yields a client whose get() returns 200."""
        ok_reply = MagicMock(spec=httpx.Response)
        ok_reply.status_code = 200

        fake_client = MagicMock()
        fake_client.get.return_value = ok_reply

        context = client_cls_mock.return_value
        context.__enter__ = MagicMock(return_value=fake_client)
        context.__exit__ = MagicMock(return_value=False)
        return fake_client

    def test_unknown_credential_returns_valid(self):
        """A name with no registered checker is reported valid with no_checker set."""
        outcome = check_credential_health("nonexistent_service", "some-key")

        assert outcome.valid is True
        assert outcome.details.get("no_checker") is True

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_dispatches_to_registered_checker(self, client_cls_mock):
        """A registered name ('brave_search') results in exactly one get() on the client."""
        fake_client = self._client_answering_200(client_cls_mock)

        outcome = check_credential_health("brave_search", "test-key")

        assert outcome.valid is True
        fake_client.get.assert_called_once()

    @patch("aden_tools.credentials.health_check.httpx.Client")
    def test_google_search_with_cse_id(self, client_cls_mock):
        """For google_search, the cse_id keyword is sent as the 'cx' request param."""
        fake_client = self._client_answering_200(client_cls_mock)

        outcome = check_credential_health("google_search", "api-key", cse_id="cse-123")

        assert outcome.valid is True
        # Verify the request included the cse_id as the cx param
        sent_params = fake_client.get.call_args.kwargs["params"]
        assert sent_params["cx"] == "cse-123"

    def test_google_search_without_cse_id(self):
        """Without cse_id, google_search reports valid with partial_check set."""
        outcome = check_credential_health("google_search", "api-key")

        assert outcome.valid is True
        assert outcome.details.get("partial_check") is True
|
||||
@@ -0,0 +1,624 @@
|
||||
"""
|
||||
Tests for GitHub tool.
|
||||
|
||||
Covers:
|
||||
- _GitHubClient methods (repositories, issues, PRs, search, branches)
|
||||
- Error handling (API errors, timeout, network errors)
|
||||
- Credential retrieval (CredentialStoreAdapter vs env var)
|
||||
- All 15 MCP tool functions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from aden_tools.tools.github_tool.github_tool import (
|
||||
_GitHubClient,
|
||||
register_tools,
|
||||
)
|
||||
|
||||
# --- _GitHubClient tests ---
|
||||
|
||||
|
||||
class TestGitHubClient:
    """Unit tests for _GitHubClient with the httpx module functions patched out."""

    def setup_method(self):
        # A fresh client per test, built with a dummy token.
        self.client = _GitHubClient("ghp_test_token")

    @staticmethod
    def _json_reply(payload, status_code=200):
        """Return a MagicMock response with the given status code and JSON payload."""
        reply = MagicMock()
        reply.status_code = status_code
        reply.json.return_value = payload
        return reply

    @staticmethod
    def _status_only_reply(status_code):
        """Return a MagicMock response with only status_code configured."""
        reply = MagicMock()
        reply.status_code = status_code
        return reply

    def test_headers(self):
        """The client's headers carry the bearer token and the GitHub Accept type."""
        sent = self.client._headers
        assert sent["Authorization"] == "Bearer ghp_test_token"
        assert "application/vnd.github+json" in sent["Accept"]

    def test_handle_response_success(self):
        """A 200 response is wrapped as success with the JSON under 'data'."""
        reply = self._json_reply({"id": 123, "name": "test-repo"})
        outcome = self.client._handle_response(reply)
        assert outcome["success"] is True
        assert outcome["data"]["name"] == "test-repo"

    def test_handle_response_401(self):
        """A 401 response produces an 'Invalid or expired' error."""
        outcome = self.client._handle_response(self._status_only_reply(401))
        assert "error" in outcome
        assert "Invalid or expired" in outcome["error"]

    def test_handle_response_403(self):
        """A 403 response produces a 'Forbidden' error."""
        outcome = self.client._handle_response(self._status_only_reply(403))
        assert "error" in outcome
        assert "Forbidden" in outcome["error"]

    def test_handle_response_404(self):
        """A 404 response produces a 'not found' error."""
        outcome = self.client._handle_response(self._status_only_reply(404))
        assert "error" in outcome
        assert "not found" in outcome["error"]

    def test_handle_response_422(self):
        """A 422 response produces an error mentioning 'Validation'."""
        reply = self._json_reply({"message": "Validation failed"}, status_code=422)
        outcome = self.client._handle_response(reply)
        assert "error" in outcome
        assert "Validation" in outcome["error"]

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_list_repos(self, http_get):
        """list_repos for a named user issues one GET and returns both repos."""
        http_get.return_value = self._json_reply(
            [
                {"id": 1, "name": "repo1", "full_name": "user/repo1"},
                {"id": 2, "name": "repo2", "full_name": "user/repo2"},
            ]
        )

        outcome = self.client.list_repos(username="testuser")

        http_get.assert_called_once()
        assert outcome["success"] is True
        assert len(outcome["data"]) == 2

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_list_repos_authenticated_user(self, http_get):
        """With username=None the request URL contains '/user/repos'."""
        http_get.return_value = self._json_reply([])

        self.client.list_repos(username=None)

        requested_url = http_get.call_args.args[0]
        assert "/user/repos" in requested_url

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_get_repo(self, http_get):
        """get_repo returns the repository JSON under 'data'."""
        http_get.return_value = self._json_reply(
            {
                "id": 123,
                "name": "test-repo",
                "full_name": "owner/test-repo",
                "description": "A test repository",
            }
        )

        outcome = self.client.get_repo("owner", "test-repo")

        assert outcome["success"] is True
        assert outcome["data"]["name"] == "test-repo"

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_search_repos(self, http_get):
        """search_repos returns the search payload including 'items'."""
        http_get.return_value = self._json_reply(
            {
                "total_count": 1,
                "items": [{"id": 123, "name": "test-repo"}],
            }
        )

        outcome = self.client.search_repos("language:python")

        assert outcome["success"] is True
        assert "items" in outcome["data"]

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_list_issues(self, http_get):
        """list_issues returns both issues from the mocked response."""
        http_get.return_value = self._json_reply(
            [
                {"number": 1, "title": "Issue 1", "state": "open"},
                {"number": 2, "title": "Issue 2", "state": "open"},
            ]
        )

        outcome = self.client.list_issues("owner", "repo", state="open")

        assert outcome["success"] is True
        assert len(outcome["data"]) == 2

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_get_issue(self, http_get):
        """get_issue returns the issue JSON under 'data'."""
        http_get.return_value = self._json_reply(
            {
                "number": 1,
                "title": "Test Issue",
                "body": "This is a test",
            }
        )

        outcome = self.client.get_issue("owner", "repo", 1)

        assert outcome["success"] is True
        assert outcome["data"]["title"] == "Test Issue"

    @patch("aden_tools.tools.github_tool.github_tool.httpx.post")
    def test_create_issue(self, http_post):
        """create_issue succeeds on 201 and forwards labels in the JSON body."""
        http_post.return_value = self._json_reply(
            {
                "number": 42,
                "title": "New Issue",
                "body": "Description",
            },
            status_code=201,
        )

        outcome = self.client.create_issue(
            "owner", "repo", "New Issue", body="Description", labels=["bug"]
        )

        assert outcome["success"] is True
        assert outcome["data"]["number"] == 42
        posted_json = http_post.call_args.kwargs["json"]
        assert posted_json["labels"] == ["bug"]

    @patch("aden_tools.tools.github_tool.github_tool.httpx.patch")
    def test_update_issue(self, http_patch):
        """update_issue returns the updated issue JSON under 'data'."""
        http_patch.return_value = self._json_reply(
            {
                "number": 1,
                "title": "Updated Title",
                "state": "closed",
            }
        )

        outcome = self.client.update_issue(
            "owner", "repo", 1, title="Updated Title", state="closed"
        )

        assert outcome["success"] is True
        assert outcome["data"]["state"] == "closed"

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_list_pull_requests(self, http_get):
        """list_pull_requests returns the single mocked pull request."""
        http_get.return_value = self._json_reply(
            [
                {"number": 1, "title": "PR 1", "state": "open"},
            ]
        )

        outcome = self.client.list_pull_requests("owner", "repo")

        assert outcome["success"] is True
        assert len(outcome["data"]) == 1

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_get_pull_request(self, http_get):
        """get_pull_request returns the pull request JSON under 'data'."""
        http_get.return_value = self._json_reply(
            {
                "number": 1,
                "title": "Test PR",
                "head": {"ref": "feature"},
                "base": {"ref": "main"},
            }
        )

        outcome = self.client.get_pull_request("owner", "repo", 1)

        assert outcome["success"] is True
        assert outcome["data"]["title"] == "Test PR"

    @patch("aden_tools.tools.github_tool.github_tool.httpx.post")
    def test_create_pull_request(self, http_post):
        """create_pull_request succeeds on 201 and returns the new PR number."""
        http_post.return_value = self._json_reply(
            {
                "number": 10,
                "title": "New PR",
                "draft": False,
            },
            status_code=201,
        )

        outcome = self.client.create_pull_request(
            "owner", "repo", "New PR", "feature-branch", "main", body="PR description"
        )

        assert outcome["success"] is True
        assert outcome["data"]["number"] == 10

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_search_code(self, http_get):
        """search_code returns the search payload including 'items'."""
        http_get.return_value = self._json_reply(
            {
                "total_count": 5,
                "items": [{"name": "file.py", "path": "src/file.py"}],
            }
        )

        outcome = self.client.search_code("addClass repo:jquery/jquery")

        assert outcome["success"] is True
        assert "items" in outcome["data"]

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_list_branches(self, http_get):
        """list_branches returns both mocked branches."""
        http_get.return_value = self._json_reply(
            [
                {"name": "main", "protected": True},
                {"name": "develop", "protected": False},
            ]
        )

        outcome = self.client.list_branches("owner", "repo")

        assert outcome["success"] is True
        assert len(outcome["data"]) == 2

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_get_branch(self, http_get):
        """get_branch returns the branch JSON under 'data'."""
        http_get.return_value = self._json_reply(
            {
                "name": "main",
                "protected": True,
                "commit": {"sha": "abc123"},
            }
        )

        outcome = self.client.get_branch("owner", "repo", "main")

        assert outcome["success"] is True
        assert outcome["data"]["name"] == "main"
|
||||
|
||||
|
||||
# --- Credential retrieval tests ---
|
||||
|
||||
|
||||
class TestCredentialRetrieval:
    """How the registered tools obtain the GitHub token."""

    @pytest.fixture
    def mcp(self):
        return FastMCP("test-server")

    @staticmethod
    def _empty_list_reply():
        """Return a MagicMock 200 response whose JSON body is an empty list."""
        reply = MagicMock()
        reply.status_code = 200
        reply.json.return_value = []
        return reply

    def test_no_credentials_returns_error(self, mcp):
        """When no credentials are configured, tools return helpful error."""
        # Empty environment plus a getenv that always returns None.
        with patch.dict("os.environ", {}, clear=True), patch("os.getenv", return_value=None):
            register_tools(mcp, credentials=None)
            tool_fn = mcp._tool_manager._tools["github_list_repos"].fn

            outcome = tool_fn()

            assert "error" in outcome
            assert "not configured" in outcome["error"]
            assert "help" in outcome

    def test_env_var_token(self, mcp):
        """Token from GITHUB_TOKEN env var is used."""
        getenv_patch = patch("os.getenv", return_value="ghp_env_token")
        http_patch = patch("aden_tools.tools.github_tool.github_tool.httpx.get")
        with getenv_patch, http_patch as http_get:
            http_get.return_value = self._empty_list_reply()

            register_tools(mcp, credentials=None)
            tool_fn = mcp._tool_manager._tools["github_list_repos"].fn

            tool_fn()

            sent_headers = http_get.call_args.kwargs["headers"]
            assert sent_headers["Authorization"] == "Bearer ghp_env_token"

    def test_credential_store_token(self, mcp):
        """Token from CredentialStoreAdapter is preferred."""
        store = MagicMock()
        store.get.return_value = "ghp_store_token"

        with patch("aden_tools.tools.github_tool.github_tool.httpx.get") as http_get:
            http_get.return_value = self._empty_list_reply()

            register_tools(mcp, credentials=store)
            tool_fn = mcp._tool_manager._tools["github_list_repos"].fn

            tool_fn()

            store.get.assert_called_with("github")
            sent_headers = http_get.call_args.kwargs["headers"]
            assert sent_headers["Authorization"] == "Bearer ghp_store_token"
|
||||
|
||||
|
||||
# --- MCP Tool function tests ---
|
||||
|
||||
|
||||
class TestGitHubListRepos:
    """Tests for the registered 'github_list_repos' tool function."""

    @pytest.fixture
    def mcp(self):
        return FastMCP("test-server")

    @staticmethod
    def _run_list_repos(mcp, **tool_kwargs):
        """Register the tools with a patched os.getenv and invoke github_list_repos."""
        with patch("os.getenv", return_value="ghp_test"):
            register_tools(mcp, credentials=None)
            tool_fn = mcp._tool_manager._tools["github_list_repos"].fn
            return tool_fn(**tool_kwargs)

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_list_repos_success(self, http_get, mcp):
        """A 200 response makes the tool report success."""
        reply = MagicMock()
        reply.status_code = 200
        reply.json.return_value = [{"id": 1, "name": "test-repo"}]
        http_get.return_value = reply

        outcome = self._run_list_repos(mcp, username="testuser")

        assert outcome["success"] is True

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_list_repos_timeout(self, http_get, mcp):
        """A TimeoutException is turned into a 'timed out' error result."""
        http_get.side_effect = httpx.TimeoutException("Timeout")

        outcome = self._run_list_repos(mcp)

        assert "error" in outcome
        assert "timed out" in outcome["error"]

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_list_repos_network_error(self, http_get, mcp):
        """A RequestError is turned into an error result carrying its message."""
        http_get.side_effect = httpx.RequestError("Network error")

        outcome = self._run_list_repos(mcp)

        assert "error" in outcome
        assert "Network error" in outcome["error"]
|
||||
|
||||
|
||||
class TestGitHubGetRepo:
    """Tests for the registered 'github_get_repo' tool function."""

    @pytest.fixture
    def mcp(self):
        return FastMCP("test-server")

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_get_repo_success(self, http_get, mcp):
        """A 200 response makes the tool report success."""
        reply = MagicMock()
        reply.status_code = 200
        reply.json.return_value = {"id": 1, "name": "test-repo"}
        http_get.return_value = reply

        # os.getenv is patched so any environment lookup yields the dummy token.
        with patch("os.getenv", return_value="ghp_test"):
            register_tools(mcp, credentials=None)
            tool_fn = mcp._tool_manager._tools["github_get_repo"].fn

            outcome = tool_fn(owner="owner", repo="test-repo")

        assert outcome["success"] is True
|
||||
|
||||
|
||||
class TestGitHubSearchRepos:
    """Tests for the registered 'github_search_repos' tool function."""

    @pytest.fixture
    def mcp(self):
        return FastMCP("test-server")

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_search_repos_success(self, http_get, mcp):
        """A 200 search response makes the tool report success."""
        reply = MagicMock()
        reply.status_code = 200
        reply.json.return_value = {"total_count": 1, "items": []}
        http_get.return_value = reply

        # os.getenv is patched so any environment lookup yields the dummy token.
        with patch("os.getenv", return_value="ghp_test"):
            register_tools(mcp, credentials=None)
            tool_fn = mcp._tool_manager._tools["github_search_repos"].fn

            outcome = tool_fn(query="python")

        assert outcome["success"] is True
|
||||
|
||||
|
||||
class TestGitHubIssues:
    """Tests for the registered issue tools (list, get, create, update)."""

    @pytest.fixture
    def mcp(self):
        return FastMCP("test-server")

    @staticmethod
    def _json_reply(payload, status_code=200):
        """Return a MagicMock response with the given status code and JSON payload."""
        reply = MagicMock()
        reply.status_code = status_code
        reply.json.return_value = payload
        return reply

    @staticmethod
    def _call_tool(mcp, tool_name, **tool_kwargs):
        """Register the tools with a patched os.getenv and invoke the named tool."""
        with patch("os.getenv", return_value="ghp_test"):
            register_tools(mcp, credentials=None)
            tool_fn = mcp._tool_manager._tools[tool_name].fn
            return tool_fn(**tool_kwargs)

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_list_issues_success(self, http_get, mcp):
        """github_list_issues reports success on a 200 response."""
        http_get.return_value = self._json_reply([{"number": 1, "title": "Test Issue"}])

        outcome = self._call_tool(mcp, "github_list_issues", owner="owner", repo="repo")

        assert outcome["success"] is True

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_get_issue_success(self, http_get, mcp):
        """github_get_issue reports success on a 200 response."""
        http_get.return_value = self._json_reply({"number": 1, "title": "Test"})

        outcome = self._call_tool(
            mcp, "github_get_issue", owner="owner", repo="repo", issue_number=1
        )

        assert outcome["success"] is True

    @patch("aden_tools.tools.github_tool.github_tool.httpx.post")
    def test_create_issue_success(self, http_post, mcp):
        """github_create_issue reports success on a 201 response."""
        http_post.return_value = self._json_reply(
            {"number": 1, "title": "New Issue"}, status_code=201
        )

        outcome = self._call_tool(
            mcp, "github_create_issue", owner="owner", repo="repo", title="New Issue"
        )

        assert outcome["success"] is True

    @patch("aden_tools.tools.github_tool.github_tool.httpx.patch")
    def test_update_issue_success(self, http_patch, mcp):
        """github_update_issue reports success on a 200 response."""
        http_patch.return_value = self._json_reply({"number": 1, "state": "closed"})

        outcome = self._call_tool(
            mcp,
            "github_update_issue",
            owner="owner",
            repo="repo",
            issue_number=1,
            state="closed",
        )

        assert outcome["success"] is True
|
||||
|
||||
|
||||
class TestGitHubPullRequests:
    """Tests for the registered pull-request tools (list, get, create)."""

    @pytest.fixture
    def mcp(self):
        return FastMCP("test-server")

    @staticmethod
    def _json_reply(payload, status_code=200):
        """Return a MagicMock response with the given status code and JSON payload."""
        reply = MagicMock()
        reply.status_code = status_code
        reply.json.return_value = payload
        return reply

    @staticmethod
    def _call_tool(mcp, tool_name, **tool_kwargs):
        """Register the tools with a patched os.getenv and invoke the named tool."""
        with patch("os.getenv", return_value="ghp_test"):
            register_tools(mcp, credentials=None)
            tool_fn = mcp._tool_manager._tools[tool_name].fn
            return tool_fn(**tool_kwargs)

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_list_pull_requests_success(self, http_get, mcp):
        """github_list_pull_requests reports success on a 200 response."""
        http_get.return_value = self._json_reply([{"number": 1, "title": "Test PR"}])

        outcome = self._call_tool(
            mcp, "github_list_pull_requests", owner="owner", repo="repo"
        )

        assert outcome["success"] is True

    @patch("aden_tools.tools.github_tool.github_tool.httpx.get")
    def test_get_pull_request_success(self, http_get, mcp):
        """github_get_pull_request reports success on a 200 response."""
        http_get.return_value = self._json_reply({"number": 1, "title": "PR"})

        outcome = self._call_tool(
            mcp, "github_get_pull_request", owner="owner", repo="repo", pull_number=1
        )

        assert outcome["success"] is True

    @patch("aden_tools.tools.github_tool.github_tool.httpx.post")
    def test_create_pull_request_success(self, http_post, mcp):
        """github_create_pull_request reports success on a 201 response."""
        http_post.return_value = self._json_reply(
            {"number": 1, "title": "New PR"}, status_code=201
        )

        outcome = self._call_tool(
            mcp,
            "github_create_pull_request",
            owner="owner",
            repo="repo",
            title="New PR",
            head="feature",
            base="main",
        )

        assert outcome["success"] is True
|
||||
|
||||
|
||||
class TestGitHubSearch:
|
||||
@pytest.fixture
|
||||
def mcp(self):
|
||||
return FastMCP("test-server")
|
||||
|
||||
@patch("aden_tools.tools.github_tool.github_tool.httpx.get")
|
||||
def test_search_code_success(self, mock_get, mcp):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"total_count": 1, "items": []}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with patch("os.getenv", return_value="ghp_test"):
|
||||
register_tools(mcp, credentials=None)
|
||||
search_code = mcp._tool_manager._tools["github_search_code"].fn
|
||||
|
||||
result = search_code(query="addClass")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
class TestGitHubBranches:
|
||||
@pytest.fixture
|
||||
def mcp(self):
|
||||
return FastMCP("test-server")
|
||||
|
||||
@patch("aden_tools.tools.github_tool.github_tool.httpx.get")
|
||||
def test_list_branches_success(self, mock_get, mcp):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = [{"name": "main"}]
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with patch("os.getenv", return_value="ghp_test"):
|
||||
register_tools(mcp, credentials=None)
|
||||
list_branches = mcp._tool_manager._tools["github_list_branches"].fn
|
||||
|
||||
result = list_branches(owner="owner", repo="repo")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
@patch("aden_tools.tools.github_tool.github_tool.httpx.get")
|
||||
def test_get_branch_success(self, mock_get, mcp):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"name": "main", "protected": True}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with patch("os.getenv", return_value="ghp_test"):
|
||||
register_tools(mcp, credentials=None)
|
||||
get_branch = mcp._tool_manager._tools["github_get_branch"].fn
|
||||
|
||||
result = get_branch(owner="owner", repo="repo", branch="main")
|
||||
|
||||
assert result["success"] is True
|
||||
@@ -95,6 +95,7 @@ class TestPdfReadTool:
|
||||
def __init__(self, path: Path) -> None: # noqa: ARG002
|
||||
self.pages = [FakePage(f"Page {i + 1}") for i in range(50)]
|
||||
self.is_encrypted = False
|
||||
self.metadata = None
|
||||
|
||||
# Patch PdfReader used inside the tool so we don't need a real PDF
|
||||
from aden_tools.tools.pdf_read_tool import pdf_read_tool
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for web_scrape tool (FastMCP)."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import FastMCP
|
||||
@@ -15,60 +15,135 @@ def web_scrape_fn(mcp: FastMCP):
|
||||
return mcp._tool_manager._tools["web_scrape"].fn
|
||||
|
||||
|
||||
def _make_playwright_mocks(html, status=200, final_url="https://example.com/page"):
|
||||
"""Build a full playwright mock chain and return (context_manager, response, page)."""
|
||||
mock_response = MagicMock(
|
||||
status=status,
|
||||
url=final_url,
|
||||
headers={"content-type": "text/html; charset=utf-8"},
|
||||
)
|
||||
|
||||
mock_page = AsyncMock()
|
||||
mock_page.goto.return_value = mock_response
|
||||
mock_page.content.return_value = html
|
||||
mock_page.wait_for_timeout.return_value = None
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.new_page.return_value = mock_page
|
||||
|
||||
mock_browser = AsyncMock()
|
||||
mock_browser.new_context.return_value = mock_context
|
||||
|
||||
mock_pw = MagicMock()
|
||||
mock_pw.chromium.launch = AsyncMock(return_value=mock_browser)
|
||||
|
||||
# async context manager for async_playwright()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_pw)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
return mock_cm, mock_response, mock_page
|
||||
|
||||
|
||||
_PW_PATH = "aden_tools.tools.web_scrape_tool.web_scrape_tool.async_playwright"
|
||||
_STEALTH_PATH = "aden_tools.tools.web_scrape_tool.web_scrape_tool.Stealth"
|
||||
|
||||
|
||||
class TestWebScrapeTool:
|
||||
"""Tests for web_scrape tool."""
|
||||
|
||||
def test_url_auto_prefixed_with_https(self, web_scrape_fn):
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_url_auto_prefixed_with_https(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""URLs without scheme get https:// prefix."""
|
||||
# This will fail to connect, but we can verify the behavior
|
||||
result = web_scrape_fn(url="example.com")
|
||||
# Should either succeed or have a network error (not a validation error)
|
||||
assert isinstance(result, dict)
|
||||
html = "<html><body>Hello</body></html>"
|
||||
mock_cm, _, _ = _make_playwright_mocks(html, final_url="https://example.com")
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
def test_max_length_clamped_low(self, web_scrape_fn):
|
||||
result = await web_scrape_fn(url="example.com")
|
||||
assert isinstance(result, dict)
|
||||
assert "error" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_max_length_clamped_low(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""max_length below 1000 is clamped to 1000."""
|
||||
# Test with a very low max_length - implementation clamps to 1000
|
||||
result = web_scrape_fn(url="https://example.com", max_length=500)
|
||||
# Should not error due to invalid max_length
|
||||
assert isinstance(result, dict)
|
||||
html = "<html><body>Hello</body></html>"
|
||||
mock_cm, _, _ = _make_playwright_mocks(html, final_url="https://example.com")
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
def test_max_length_clamped_high(self, web_scrape_fn):
|
||||
result = await web_scrape_fn(url="https://example.com", max_length=500)
|
||||
assert isinstance(result, dict)
|
||||
assert "error" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_max_length_clamped_high(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""max_length above 500000 is clamped to 500000."""
|
||||
# Test with a very high max_length - implementation clamps to 500000
|
||||
result = web_scrape_fn(url="https://example.com", max_length=600000)
|
||||
# Should not error due to invalid max_length
|
||||
assert isinstance(result, dict)
|
||||
html = "<html><body>Hello</body></html>"
|
||||
mock_cm, _, _ = _make_playwright_mocks(html, final_url="https://example.com")
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
def test_valid_max_length_accepted(self, web_scrape_fn):
|
||||
result = await web_scrape_fn(url="https://example.com", max_length=600000)
|
||||
assert isinstance(result, dict)
|
||||
assert "error" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_valid_max_length_accepted(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""Valid max_length values are accepted."""
|
||||
result = web_scrape_fn(url="https://example.com", max_length=10000)
|
||||
assert isinstance(result, dict)
|
||||
html = "<html><body>Hello</body></html>"
|
||||
mock_cm, _, _ = _make_playwright_mocks(html, final_url="https://example.com")
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
def test_include_links_option(self, web_scrape_fn):
|
||||
result = await web_scrape_fn(url="https://example.com", max_length=10000)
|
||||
assert isinstance(result, dict)
|
||||
assert "error" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_include_links_option(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""include_links parameter is accepted."""
|
||||
result = web_scrape_fn(url="https://example.com", include_links=True)
|
||||
assert isinstance(result, dict)
|
||||
html = '<html><body><a href="/link">Link</a></body></html>'
|
||||
mock_cm, _, _ = _make_playwright_mocks(html, final_url="https://example.com")
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
def test_selector_option(self, web_scrape_fn):
|
||||
"""selector parameter is accepted."""
|
||||
result = web_scrape_fn(url="https://example.com", selector=".content")
|
||||
result = await web_scrape_fn(url="https://example.com", include_links=True)
|
||||
assert isinstance(result, dict)
|
||||
assert "error" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_selector_option(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""selector parameter is accepted."""
|
||||
html = '<html><body><div class="content">Content here</div></body></html>'
|
||||
mock_cm, _, _ = _make_playwright_mocks(html, final_url="https://example.com")
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
result = await web_scrape_fn(url="https://example.com", selector=".content")
|
||||
assert isinstance(result, dict)
|
||||
assert "error" not in result
|
||||
|
||||
|
||||
class TestWebScrapeToolLinkConversion:
|
||||
"""Tests for link URL conversion (relative to absolute)."""
|
||||
|
||||
def _mock_response(self, html_content, final_url="https://example.com/page"):
|
||||
"""Create a mock httpx response object."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = html_content
|
||||
mock_response.url = final_url
|
||||
mock_response.headers = {"content-type": "text/html; charset=utf-8"}
|
||||
return mock_response
|
||||
|
||||
@patch("aden_tools.tools.web_scrape_tool.web_scrape_tool.httpx.get")
|
||||
def test_relative_links_converted_to_absolute(self, mock_get, web_scrape_fn):
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_relative_links_converted_to_absolute(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""Relative URLs like ../page are converted to absolute URLs."""
|
||||
html = """
|
||||
<html>
|
||||
@@ -78,9 +153,11 @@ class TestWebScrapeToolLinkConversion:
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
mock_get.return_value = self._mock_response(html, "https://example.com/blog/post")
|
||||
mock_cm, _, _ = _make_playwright_mocks(html, final_url="https://example.com/blog/post")
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
result = web_scrape_fn(url="https://example.com/blog/post", include_links=True)
|
||||
result = await web_scrape_fn(url="https://example.com/blog/post", include_links=True)
|
||||
|
||||
assert "error" not in result
|
||||
assert "links" in result
|
||||
@@ -95,8 +172,10 @@ class TestWebScrapeToolLinkConversion:
|
||||
expected = "https://example.com/blog/page.html"
|
||||
assert hrefs["Next Page"] == expected, f"Got {hrefs['Next Page']}"
|
||||
|
||||
@patch("aden_tools.tools.web_scrape_tool.web_scrape_tool.httpx.get")
|
||||
def test_root_relative_links_converted(self, mock_get, web_scrape_fn):
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_root_relative_links_converted(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""Root-relative URLs like /about are converted to absolute URLs."""
|
||||
html = """
|
||||
<html>
|
||||
@@ -106,9 +185,11 @@ class TestWebScrapeToolLinkConversion:
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
mock_get.return_value = self._mock_response(html, "https://example.com/blog/post")
|
||||
mock_cm, _, _ = _make_playwright_mocks(html, final_url="https://example.com/blog/post")
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
result = web_scrape_fn(url="https://example.com/blog/post", include_links=True)
|
||||
result = await web_scrape_fn(url="https://example.com/blog/post", include_links=True)
|
||||
|
||||
assert "error" not in result
|
||||
assert "links" in result
|
||||
@@ -119,8 +200,10 @@ class TestWebScrapeToolLinkConversion:
|
||||
assert hrefs["About"] == "https://example.com/about"
|
||||
assert hrefs["Contact"] == "https://example.com/contact"
|
||||
|
||||
@patch("aden_tools.tools.web_scrape_tool.web_scrape_tool.httpx.get")
|
||||
def test_absolute_links_unchanged(self, mock_get, web_scrape_fn):
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_absolute_links_unchanged(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""Absolute URLs remain unchanged."""
|
||||
html = """
|
||||
<html>
|
||||
@@ -130,9 +213,11 @@ class TestWebScrapeToolLinkConversion:
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
mock_get.return_value = self._mock_response(html)
|
||||
mock_cm, _, _ = _make_playwright_mocks(html)
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
result = web_scrape_fn(url="https://example.com", include_links=True)
|
||||
result = await web_scrape_fn(url="https://example.com", include_links=True)
|
||||
|
||||
assert "error" not in result
|
||||
assert "links" in result
|
||||
@@ -143,8 +228,10 @@ class TestWebScrapeToolLinkConversion:
|
||||
assert hrefs["Other Site"] == "https://other.com"
|
||||
assert hrefs["Internal"] == "https://example.com/page"
|
||||
|
||||
@patch("aden_tools.tools.web_scrape_tool.web_scrape_tool.httpx.get")
|
||||
def test_links_after_redirects(self, mock_get, web_scrape_fn):
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_links_after_redirects(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""Links are resolved relative to final URL after redirects."""
|
||||
html = """
|
||||
<html>
|
||||
@@ -155,12 +242,14 @@ class TestWebScrapeToolLinkConversion:
|
||||
</html>
|
||||
"""
|
||||
# Mock redirect: request to /old/url redirects to /new/location
|
||||
mock_get.return_value = self._mock_response(
|
||||
mock_cm, _, _ = _make_playwright_mocks(
|
||||
html,
|
||||
final_url="https://example.com/new/location", # Final URL after redirect
|
||||
)
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
result = web_scrape_fn(url="https://example.com/old/url", include_links=True)
|
||||
result = await web_scrape_fn(url="https://example.com/old/url", include_links=True)
|
||||
|
||||
assert "error" not in result
|
||||
assert "links" in result
|
||||
@@ -173,8 +262,10 @@ class TestWebScrapeToolLinkConversion:
|
||||
)
|
||||
assert hrefs["Next"] == "https://example.com/new/next"
|
||||
|
||||
@patch("aden_tools.tools.web_scrape_tool.web_scrape_tool.httpx.get")
|
||||
def test_fragment_links_preserved(self, mock_get, web_scrape_fn):
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_fragment_links_preserved(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""Fragment links (anchors) are preserved."""
|
||||
html = """
|
||||
<html>
|
||||
@@ -184,9 +275,11 @@ class TestWebScrapeToolLinkConversion:
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
mock_get.return_value = self._mock_response(html, "https://example.com/page")
|
||||
mock_cm, _, _ = _make_playwright_mocks(html, final_url="https://example.com/page")
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
result = web_scrape_fn(url="https://example.com/page", include_links=True)
|
||||
result = await web_scrape_fn(url="https://example.com/page", include_links=True)
|
||||
|
||||
assert "error" not in result
|
||||
assert "links" in result
|
||||
@@ -197,8 +290,10 @@ class TestWebScrapeToolLinkConversion:
|
||||
assert hrefs["Section 1"] == "https://example.com/page#section1"
|
||||
assert hrefs["Page Section 2"] == "https://example.com/page#section2"
|
||||
|
||||
@patch("aden_tools.tools.web_scrape_tool.web_scrape_tool.httpx.get")
|
||||
def test_query_parameters_preserved(self, mock_get, web_scrape_fn):
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_query_parameters_preserved(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""Query parameters in URLs are preserved."""
|
||||
html = """
|
||||
<html>
|
||||
@@ -208,9 +303,11 @@ class TestWebScrapeToolLinkConversion:
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
mock_get.return_value = self._mock_response(html, "https://example.com/blog/post")
|
||||
mock_cm, _, _ = _make_playwright_mocks(html, final_url="https://example.com/blog/post")
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
result = web_scrape_fn(url="https://example.com/blog/post", include_links=True)
|
||||
result = await web_scrape_fn(url="https://example.com/blog/post", include_links=True)
|
||||
|
||||
assert "error" not in result
|
||||
assert "links" in result
|
||||
@@ -222,8 +319,10 @@ class TestWebScrapeToolLinkConversion:
|
||||
assert "q=test" in hrefs["Search"]
|
||||
assert "sort=date" in hrefs["Search"]
|
||||
|
||||
@patch("aden_tools.tools.web_scrape_tool.web_scrape_tool.httpx.get")
|
||||
def test_empty_href_skipped(self, mock_get, web_scrape_fn):
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
async def test_empty_href_skipped(self, mock_pw, mock_stealth, web_scrape_fn):
|
||||
"""Links with empty or whitespace text are skipped."""
|
||||
html = """
|
||||
<html>
|
||||
@@ -234,9 +333,11 @@ class TestWebScrapeToolLinkConversion:
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
mock_get.return_value = self._mock_response(html)
|
||||
mock_cm, _, _ = _make_playwright_mocks(html)
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
result = web_scrape_fn(url="https://example.com", include_links=True)
|
||||
result = await web_scrape_fn(url="https://example.com", include_links=True)
|
||||
|
||||
assert "error" not in result
|
||||
assert "links" in result
|
||||
|
||||
Reference in New Issue
Block a user