Compare commits
90 Commits
| SHA1 |
|---|
| 119280da1a |
| 4d49f74d5a |
| 6a42b9c66b |
| fc4a39480a |
| b98afb01c8 |
| ccd6bb7656 |
| ea30e5c631 |
| d16a3c3b22 |
| a03bd78c2e |
| 3cca41aab1 |
| d19aaed946 |
| 9a7db8cf94 |
| f50630c551 |
| 0ef2e64733 |
| 3a8e121d43 |
| 51d341b88c |
| 7dd70b8e31 |
| 84b332d989 |
| bcc6848275 |
| 75dd053a40 |
| 20f2aa09f2 |
| fb8c810b3d |
| ad21cf4243 |
| 1e45cfff67 |
| 0280600a47 |
| 571ad518dc |
| fe37a25cf1 |
| e06138628c |
| 1ed0edd158 |
| 49dbc46082 |
| a16a4adc09 |
| b4ab1cbd56 |
| 6faa63f0d0 |
| 2b44af427f |
| 11f7401bc2 |
| db7b5180dd |
| 5b4e56252c |
| e3c71f77de |
| b09824faec |
| c69bc24598 |
| 0cf17e1c63 |
| feac803491 |
| 4aacec30d8 |
| b459a2f7a9 |
| ca8ede65f0 |
| 9a177c46e1 |
| d49e858d32 |
| 20bea9cd7f |
| d7afa5dcf2 |
| 22e816bf86 |
| a7709d489c |
| 3240616808 |
| 18dfc997b8 |
| 92d0b6addf |
| b9f83d4d61 |
| 9c16826ad3 |
| f305745295 |
| df4d0ad3fd |
| 9034d1dc71 |
| 537172d8ce |
| 20b2e4b3dd |
| fc22586752 |
| 646440eba3 |
| 53e5579326 |
| 29a1630d0f |
| 171f4ab2ae |
| a86043a2ec |
| 3947da2cf1 |
| 17caab6563 |
| a5ae071a03 |
| 94d31743b0 |
| 70db618c6e |
| 363a650dfa |
| b6e2634537 |
| 9f424f2fc0 |
| 0715fc5498 |
| f9fddd6663 |
| 63d017fc21 |
| c52ce6bb49 |
| bcddd4ce77 |
| 017872f71b |
| bfb660275e |
| d6ae48bc58 |
| 7e670ce0a8 |
| 94197cbcb9 |
| 3ee6d98905 |
| a96cd546c8 |
| eb33d4f1c2 |
| 4253956326 |
| d6b05bf337 |
@@ -1,40 +0,0 @@
-{
-  "permissions": {
-    "allow": [
-      "Bash(npm install:*)",
-      "Bash(npm test:*)",
-      "Skill(building-agents-construction)",
-      "Skill(building-agents-construction:*)",
-      "Bash(PYTHONPATH=core:exports pytest:*)",
-      "mcp__agent-builder__create_session",
-      "mcp__agent-builder__get_session_status",
-      "mcp__agent-builder__set_goal",
-      "mcp__agent-builder__list_mcp_servers",
-      "mcp__agent-builder__test_node",
-      "mcp__agent-builder__add_node",
-      "mcp__agent-builder__add_edge",
-      "mcp__agent-builder__validate_graph",
-      "Bash(ruff check:*)",
-      "Bash(PYTHONPATH=core:exports python:*)",
-      "mcp__agent-builder__list_tests",
-      "mcp__agent-builder__generate_constraint_tests",
-      "Bash(python -m agent:*)",
-      "Bash(python agent.py:*)",
-      "Bash(python -c:*)",
-      "Bash(done)",
-      "Bash(xargs cat:*)",
-      "mcp__agent-builder__list_mcp_tools",
-      "mcp__agent-builder__add_mcp_server",
-      "mcp__agent-builder__check_missing_credentials",
-      "mcp__agent-builder__store_credential",
-      "mcp__agent-builder__list_stored_credentials",
-      "mcp__agent-builder__delete_stored_credential",
-      "mcp__agent-builder__verify_credentials",
-      "Bash(PYTHONPATH=/home/timothy/oss/hive/core:/home/timothy/oss/hive/exports python:*)",
-      "Bash(PYTHONPATH=core:exports:tools/src python -m hubspot_input:*)",
-      "mcp__agent-builder__export_graph"
-    ]
-  },
-  "enabledMcpjsonServers": ["agent-builder", "tools"],
-  "enableAllProjectMcpServers": true
-}
+2
-12
@@ -218,19 +218,9 @@ class OnlineResearchAgent:
         tool_registry = ToolRegistry()

         # Load MCP servers (always load, needed for tool validation)
-        agent_dir = Path(__file__).parent
-        mcp_config_path = agent_dir / "mcp_servers.json"
-
+        mcp_config_path = Path(__file__).parent / "mcp_servers.json"
         if mcp_config_path.exists():
-            with open(mcp_config_path) as f:
-                mcp_servers = json.load(f)
-
-            for server_config in mcp_servers.get("servers", []):
-                # Resolve relative cwd paths
-                cwd = server_config.get("cwd")
-                if cwd and not Path(cwd).is_absolute():
-                    server_config["cwd"] = str(agent_dir / cwd)
-                tool_registry.register_mcp_server(server_config)
+            tool_registry.load_mcp_config(mcp_config_path)

         llm = None
        if not mock_mode:

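The hunk above collapses the manual MCP-server loop into a single `load_mcp_config` call. A minimal sketch of what such a helper plausibly does, assuming it keeps the old cwd-resolution behavior (the real framework method may differ):

```python
import json
from pathlib import Path


class ToolRegistry:
    """Sketch only; the real framework class has more responsibilities."""

    def register_mcp_server(self, server_config: dict) -> None:
        ...  # existing per-server registration

    def load_mcp_config(self, config_path: Path) -> None:
        # Resolve relative cwd paths against the config file's directory,
        # mirroring the inline loop removed by the diff above.
        base_dir = config_path.parent
        with open(config_path) as f:
            config = json.load(f)
        for server_config in config.get("servers", []):
            cwd = server_config.get("cwd")
            if cwd and not Path(cwd).is_absolute():
                server_config["cwd"] = str(base_dir / cwd)
            self.register_mcp_server(server_config)
```

Centralizing cwd resolution in the registry removes the duplicated path logic from each agent's constructor, which is the apparent point of the two-line replacement.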
@@ -1,10 +1,10 @@
 ---
 name: setup-credentials
-description: Set up and install credentials for an agent. Detects missing credentials from agent config, collects them from the user, and stores them securely in the encrypted credential store at ~/.hive/credentials.
+description: Set up and install credentials for an agent. Detects missing credentials from agent config, collects them from the user, and stores them securely in the local encrypted store at ~/.hive/credentials.
 license: Apache-2.0
 metadata:
   author: hive
-  version: "2.1"
+  version: "2.2"
   type: utility
 ---

@@ -31,48 +31,96 @@ Determine which agent needs credentials. The user will either:

 Locate the agent's directory under `exports/{agent_name}/`.

-### Step 2: Detect Required Credentials
+### Step 2: Detect Required Credentials (Bash-First)

-Read the agent's configuration to determine which tools and node types it uses:
+Use bash commands to determine what the agent needs and what's already configured. This avoids Python import issues and works even when `HIVE_CREDENTIAL_KEY` is not set.

-```python
-from core.framework.runner import AgentRunner
-
-runner = AgentRunner.load("exports/{agent_name}")
-validation = runner.validate()
-
-# validation.missing_credentials contains env var names
-# validation.warnings contains detailed messages with help URLs
-```
+#### Step 2a: Read Agent Requirements
+
+Extract `required_tools` and node types from the agent config:
+
+```bash
+# Get required tools
+jq -r '.required_tools[]?' exports/{agent_name}/agent.json 2>/dev/null
+
+# Get node types from graph nodes
+jq -r '.graph.nodes[]?.node_type' exports/{agent_name}/agent.json 2>/dev/null | sort -u
+```

-Alternatively, check the credential store directly:
+Map the extracted tools and node types to credentials by reading the spec files directly:

-```python
-from core.framework.credentials import CredentialStore
-
-# Use encrypted storage (default: ~/.hive/credentials)
-store = CredentialStore.with_encrypted_storage()
-
-# Check what's available
-available = store.list_credentials()
-print(f"Available credentials: {available}")
-
-# Check if specific credential exists
-if store.is_available("hubspot"):
-    print("HubSpot credential found")
-else:
-    print("HubSpot credential missing")
-```
+```bash
+# Read all credential specs — each file defines tools, node_types, env_var, and credential_id
+cat tools/src/aden_tools/credentials/llm.py tools/src/aden_tools/credentials/search.py tools/src/aden_tools/credentials/email.py tools/src/aden_tools/credentials/integrations.py
+```

-To see all known credential specs (for help URLs and setup instructions):
+For each `CredentialSpec`, match its `tools` and `node_types` lists against the agent's required tools and node types. Extract the `env_var`, `credential_id`, and `credential_group` for every match. This is the list of needed credentials.

-```python
-from aden_tools.credentials import CREDENTIAL_SPECS
-
-for name, spec in CREDENTIAL_SPECS.items():
-    print(f"{name}: env_var={spec.env_var}, aden={spec.aden_supported}")
-```
+#### Step 2b: Check Existing Credential Sources
+
+For each needed credential, check three sources. A credential is "found" if it exists in ANY of them:
+
+**1. Encrypted store metadata index** (unencrypted JSON — no decryption key needed):
+
+```bash
+cat ~/.hive/credentials/metadata/index.json 2>/dev/null | jq -r '.credentials | keys[]'
+```
+
+If a credential ID appears in this list, it is stored in the encrypted store.
+
+**2. Environment variables:**
+
+```bash
+# Check each needed env var, e.g.:
+printenv ANTHROPIC_API_KEY > /dev/null 2>&1 && echo "ANTHROPIC_API_KEY: set" || echo "ANTHROPIC_API_KEY: not set"
+printenv BRAVE_SEARCH_API_KEY > /dev/null 2>&1 && echo "BRAVE_SEARCH_API_KEY: set" || echo "BRAVE_SEARCH_API_KEY: not set"
+```
+
+**3. Project `.env` file:**
+
+```bash
+# Check each needed env var, e.g.:
+grep -q '^ANTHROPIC_API_KEY=' .env 2>/dev/null && echo "ANTHROPIC_API_KEY: in .env" || echo "ANTHROPIC_API_KEY: not in .env"
+grep -q '^BRAVE_SEARCH_API_KEY=' .env 2>/dev/null && echo "BRAVE_SEARCH_API_KEY: in .env" || echo "BRAVE_SEARCH_API_KEY: not in .env"
+```
+
+#### Step 2c: HIVE_CREDENTIAL_KEY Check
+
+If any credentials were found in the encrypted store metadata index, verify the encryption key is available. The key is typically persisted to shell config by a previous setup-credentials run.
+
+Check both the current session AND shell config files:
+
+```bash
+# Check 1: Current session
+printenv HIVE_CREDENTIAL_KEY > /dev/null 2>&1 && echo "session: set" || echo "session: not set"
+
+# Check 2: Shell config files (where setup-credentials persists it)
+# Note: check each file individually to avoid non-zero exit when one doesn't exist
+for f in ~/.zshrc ~/.bashrc ~/.profile; do [ -f "$f" ] && grep -q 'HIVE_CREDENTIAL_KEY' "$f" && echo "$f"; done
+```
+
+Decision logic:
+- **In current session** — no action needed, credentials in the store are usable
+- **In shell config but NOT in current session** — the key is persisted but this shell hasn't sourced it. Run `source ~/.zshrc` (or `~/.bashrc`), then re-check. Credentials in the store are usable after sourcing.
+- **Not in session AND not in shell config** — the key was never persisted. Warn the user that credentials in the store cannot be decrypted. Help fix the key situation (recover/re-persist), do NOT re-collect credential values that are already stored.
+
+#### Step 2d: Compute Missing & Group
+
+Diff the "needed" credentials against the "found" credentials to get the truly missing list.
+
+Group related credentials by their `credential_group` field from the spec files. Credentials that share the same non-empty `credential_group` value should be presented as a single setup step rather than asking for each one individually.
+
+**If nothing is missing and there's no HIVE_CREDENTIAL_KEY issue:** Report all credentials as configured and skip Steps 3-5. Example:
+
+```
+All required credentials are already configured:
+✓ anthropic (ANTHROPIC_API_KEY) — found in encrypted store
+✓ brave_search (BRAVE_SEARCH_API_KEY) — found in environment
+Your agent is ready to run!
+```
+
+**If credentials are missing:** Continue to Step 3 with only the missing ones.

 ### Step 3: Present Auth Options for Each Missing Credential

 For each missing credential, check what authentication methods are available:
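A minimal Python sketch of the Step 2a-2d flow in the hunk above, using only the spec fields named in the skill text (`credential_id`, `env_var`, `tools`, `node_types`, `credential_group`); the inline `SPECS` entries are hypothetical stand-ins for the real files under `tools/src/aden_tools/credentials/`:

```python
import json
import os
from pathlib import Path

# Hypothetical spec entries for illustration only.
SPECS = [
    {"credential_id": "anthropic", "env_var": "ANTHROPIC_API_KEY",
     "tools": [], "node_types": ["llm_tool_use"], "credential_group": ""},
    {"credential_id": "brave_search", "env_var": "BRAVE_SEARCH_API_KEY",
     "tools": ["web_search"], "node_types": [], "credential_group": ""},
]

agent = json.loads(Path("exports/my-agent/agent.json").read_text())
tools = set(agent.get("required_tools", []))
node_types = {n["node_type"] for n in agent.get("graph", {}).get("nodes", [])}

# Step 2a: map tools/node types to needed credentials.
needed = [s for s in SPECS
          if tools & set(s["tools"]) or node_types & set(s["node_types"])]

# Step 2b: a credential is "found" if it is in the store index or the
# environment (the .env check from the skill is omitted here for brevity).
index = Path.home() / ".hive/credentials/metadata/index.json"
stored = set(json.loads(index.read_text())["credentials"]) if index.exists() else set()
found = {s["credential_id"] for s in needed
         if s["credential_id"] in stored or os.environ.get(s["env_var"])}

# Step 2d: diff needed against found.
missing = [s for s in needed if s["credential_id"] not in found]
print("missing:", [s["credential_id"] for s in missing])
```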
@@ -104,7 +152,7 @@ Present the available options using AskUserQuestion:
 ```
 Choose how to configure HUBSPOT_ACCESS_TOKEN:

-1) Aden Authorization Server (Recommended)
+1) Aden Platform (OAuth) (Recommended)
    Secure OAuth2 flow via integration.adenhq.com
    - Quick setup with automatic token refresh
    - No need to manage API keys manually
@@ -114,7 +162,7 @@ Choose how to configure HUBSPOT_ACCESS_TOKEN:
    - Requires creating a HubSpot Private App
    - Full control over scopes and permissions

-3) Custom Credential Store (Advanced)
+3) Local Credential Setup (Advanced)
    Programmatic configuration for CI/CD
    - For automated deployments
    - Requires manual API calls
@@ -122,7 +170,7 @@ Choose how to configure HUBSPOT_ACCESS_TOKEN:

 ### Step 4: Execute Auth Flow Based on User Choice

-#### Option 1: Aden Authorization Server
+#### Option 1: Aden Platform (OAuth)

 This is the recommended flow for supported integrations (HubSpot, etc.).

@@ -174,7 +222,7 @@ shell_type = detect_shell()  # 'bash', 'zsh', or 'unknown'
 success, config_path = add_env_var_to_shell_config(
     "ADEN_API_KEY",
     user_provided_key,
-    comment="Aden authorization server API key"
+    comment="Aden Platform (OAuth) API key"
 )

 if success:
@@ -313,7 +361,7 @@ if not result.valid:
 # 2. Continue anyway (not recommended)
 ```

-**4.2d. Store in Encrypted Credential Store**
+**4.2d. Store in Local Encrypted Store**

 ```python
 from core.framework.credentials import CredentialStore, CredentialObject, CredentialKey
@@ -340,7 +388,7 @@ store.save_credential(cred)
 export HUBSPOT_ACCESS_TOKEN="the-value"
 ```

-#### Option 3: Custom Credential Store (Advanced)
+#### Option 3: Local Credential Setup (Advanced)

 For programmatic/CI/CD setups.

@@ -408,10 +456,14 @@ Report the result to the user.

 Health checks validate credentials by making lightweight API calls:

-| Credential     | Endpoint                                 | What It Checks                    |
-| -------------- | ---------------------------------------- | --------------------------------- |
-| `hubspot`      | `GET /crm/v3/objects/contacts?limit=1`   | Bearer token validity, CRM scopes |
-| `brave_search` | `GET /res/v1/web/search?q=test&count=1`  | API key validity                  |
+| Credential      | Endpoint                                 | What It Checks                     |
+| --------------- | ---------------------------------------- | ---------------------------------- |
+| `anthropic`     | `POST /v1/messages`                      | API key validity                   |
+| `brave_search`  | `GET /res/v1/web/search?q=test&count=1`  | API key validity                   |
+| `google_search` | `GET /customsearch/v1?q=test&num=1`      | API key + CSE ID validity          |
+| `github`        | `GET /user`                              | Token validity, user identity      |
+| `hubspot`       | `GET /crm/v3/objects/contacts?limit=1`   | Bearer token validity, CRM scopes  |
+| `resend`        | `GET /domains`                           | API key validity                   |

 ```python
 from aden_tools.credentials import check_credential_health, HealthCheckResult
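# --- Editor's illustration (not part of the diff): driving the health-check
# table above in a loop. check_credential_health and result.valid appear in
# this skill; the token values and the .error attribute are assumptions.
for cred_id, token in [("anthropic", "sk-ant-placeholder"), ("resend", "re_placeholder")]:
    result = check_credential_health(cred_id, token)
    status = "valid" if result.valid else f"invalid: {getattr(result, 'error', 'unknown')}"
    print(f"{cred_id}: {status}")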
@@ -424,7 +476,7 @@ result: HealthCheckResult = check_credential_health("hubspot", token_value)

 ## Encryption Key (HIVE_CREDENTIAL_KEY)

-The encrypted credential store requires `HIVE_CREDENTIAL_KEY` to encrypt/decrypt credentials.
+The local encrypted store requires `HIVE_CREDENTIAL_KEY` to encrypt/decrypt credentials.

 - If the user doesn't have one, `EncryptedFileStorage` will auto-generate one and log it
 - The user MUST persist this key (e.g., in `~/.bashrc` or a secrets manager)
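A minimal sketch of persisting the key once it has been generated (variable name and `~/.bashrc` target taken from the bullet above; the placeholder value is an assumption):

```bash
# Append the key to the shell config, then load it into the current session.
echo 'export HIVE_CREDENTIAL_KEY="<key-logged-by-EncryptedFileStorage>"' >> ~/.bashrc
source ~/.bashrc
```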
@@ -443,7 +495,7 @@ If `HIVE_CREDENTIAL_KEY` is not set:
 - **NEVER** store credentials in plaintext files, git-tracked files, or agent configs
 - **NEVER** hardcode credentials in source code
 - **ALWAYS** use `SecretStr` from Pydantic when handling credential values in Python
-- **ALWAYS** use the encrypted credential store (`~/.hive/credentials`) for persistence
+- **ALWAYS** use the local encrypted store (`~/.hive/credentials`) for persistence
 - **ALWAYS** run health checks before storing credentials (when possible)
 - **ALWAYS** verify credentials were stored by re-running validation, not by reading them back
 - When modifying `~/.bashrc` or `~/.zshrc`, confirm with the user first
@@ -456,7 +508,8 @@ All credential specs are defined in `tools/src/aden_tools/credentials/`:
 | ----------------- | ------------- | --------------------------------------------- | -------------- |
 | `llm.py`          | LLM Providers | `anthropic`                                   | No             |
 | `search.py`       | Search Tools  | `brave_search`, `google_search`, `google_cse` | No             |
-| `integrations.py` | Integrations  | `hubspot`                                     | Yes            |
 | `email.py`        | Email         | `resend`                                      | No             |
+| `integrations.py` | Integrations  | `github`, `hubspot`                           | No / Yes       |

 **Note:** Additional LLM providers (Cerebras, Groq, OpenAI) are handled by LiteLLM via environment
+variables (`CEREBRAS_API_KEY`, `GROQ_API_KEY`, `OPENAI_API_KEY`) but are not yet in CREDENTIAL_SPECS.
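For orientation, a hypothetical spec entry shaped like the fields this skill reads (`credential_id`, `env_var`, `tools`, `node_types`, `credential_group`, plus the `aden_supported` flag mentioned earlier); the real dataclass in `aden_tools.credentials` may differ:

```python
from dataclasses import dataclass, field


@dataclass
class CredentialSpec:  # sketch of the fields referenced in this skill
    credential_id: str
    env_var: str
    tools: list[str] = field(default_factory=list)
    node_types: list[str] = field(default_factory=list)
    credential_group: str = ""
    aden_supported: bool = False


HUBSPOT = CredentialSpec(
    credential_id="hubspot",
    env_var="HUBSPOT_ACCESS_TOKEN",
    tools=["hubspot_search"],
    aden_supported=True,  # per the table above
)
```

The empty-string default for `credential_group` matches the grouping rule in Step 2d: only non-empty values tie credentials together.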
@@ -507,66 +560,90 @@ token = store.get_key("hubspot", "access_token")
 ## Example Session

 ```
-User: /setup-credentials for my hubspot-agent
+User: /setup-credentials for my research-agent

-Agent: Let me check what credentials your hubspot-agent needs.
+Agent: Let me check what credentials your research-agent needs.

-[Runs validation, finds ANTHROPIC_API_KEY and HUBSPOT_ACCESS_TOKEN missing]
+[Reads agent config]
+$ jq -r '.required_tools[]?' exports/research-agent/agent.json
+web_search
+google_search

-Agent: Your hubspot-agent requires 2 credentials:
+$ jq -r '.graph.nodes[]?.node_type' exports/research-agent/agent.json | sort -u
+llm_tool_use

-1. ANTHROPIC_API_KEY - Anthropic API key for LLM calls
-2. HUBSPOT_ACCESS_TOKEN - HubSpot access token
+[Maps tools/nodes to credentials using lookup table]
+Needed: anthropic, brave_search, google_search, google_cse

-Let's set these up one at a time.
+[Checks encrypted store metadata index]
+$ cat ~/.hive/credentials/metadata/index.json | jq -r '.credentials | keys[]'
+anthropic
+brave_search

---- Setting up ANTHROPIC_API_KEY ---
+[Checks environment variables]
+$ printenv ANTHROPIC_API_KEY > /dev/null 2>&1 && echo "set" || echo "not set"
+not set
+$ printenv BRAVE_SEARCH_API_KEY > /dev/null 2>&1 && echo "set" || echo "not set"
+not set
+$ printenv GOOGLE_API_KEY > /dev/null 2>&1 && echo "set" || echo "not set"
+not set
+$ printenv GOOGLE_CSE_ID > /dev/null 2>&1 && echo "set" || echo "not set"
+not set

-This credential only supports direct API key entry.
+[Checks HIVE_CREDENTIAL_KEY since credentials found in store]
+$ printenv HIVE_CREDENTIAL_KEY > /dev/null 2>&1 && echo "session: set" || echo "session: not set"
+session: not set
+$ for f in ~/.zshrc ~/.bashrc ~/.profile; do [ -f "$f" ] && grep -q 'HIVE_CREDENTIAL_KEY' "$f" && echo "$f"; done
+/Users/user/.zshrc

-[Shows instructions for getting Anthropic API key]
-[AskUserQuestion: "Please provide your Anthropic API key:"]
-[User provides key]
-[Stores in credential store, exports to env]
+[Key is in shell config but not current session — sourcing it]
+$ source ~/.zshrc

---- Setting up HUBSPOT_ACCESS_TOKEN ---
+[Computes missing credentials]
+Found:
+✓ anthropic (ANTHROPIC_API_KEY) — in encrypted store
+✓ brave_search (BRAVE_SEARCH_API_KEY) — in encrypted store
+Missing:
+✗ google_search (GOOGLE_API_KEY)
+✗ google_cse (GOOGLE_CSE_ID)

-[AskUserQuestion with 3 options]
+Agent: 2 of 4 required credentials are already configured. Only Google Custom
+Search needs setup (2 values as a single group).

-Choose how to configure HUBSPOT_ACCESS_TOKEN:
-1) Aden Authorization Server (Recommended)
-2) Direct API Key
-3) Custom Credential Store (Advanced)
+--- Setting up Google Custom Search (google_search + google_cse) ---

-User: 1 (Aden)
+This requires two values that work together.

-Agent: Great! Let me check if you have an Aden API key configured...
+First, the Google API Key:
+1. Go to https://console.cloud.google.com/apis/credentials
+2. Create a new project (or select an existing one)
+3. Enable the "Custom Search API" from the API Library
+4. Go to Credentials > Create Credentials > API Key
+5. Copy the generated API key

-[Checks for ADEN_API_KEY, not found]
-[Opens browser to integration.adenhq.com/setup]
-
-Agent: I've opened Aden's setup page. Please:
-1. Sign in or create an account
-2. Connect your HubSpot account (OAuth happens on Aden's site)
-3. Copy the API key shown after setup
-
-[AskUserQuestion: "Please provide your Aden API key:"]
+[AskUserQuestion: "Please provide your Google API key:"]
 [User provides key]

-Agent: Would you like me to save this to your shell config (~/.zshrc) for future sessions?
+Now, the Custom Search Engine ID:
+1. Go to https://programmablesearchengine.google.com/controlpanel/all
+2. Click "Add" to create a new search engine
+3. Under "What to search", select "Search the entire web"
+4. Give your search engine a name
+5. Click "Create"
+6. Copy the Search Engine ID (cx value)

-[User approves]
-[Saves to ~/.zshrc and ~/.hive/configuration.json]
+[AskUserQuestion: "Please provide your Google CSE ID:"]
+[User provides ID]

-Agent: Let me sync your credentials from Aden...
+[Runs health check with both values - GET /customsearch/v1?q=test&num=1 → 200 OK]
+[Stores both in local encrypted store, exports to env]

-[Syncs credentials from Aden server - OAuth already done on Aden's side]
-[Runs health check]
-
-Agent: HubSpot credentials validated successfully!
+✓ Google Custom Search credentials valid

 All credentials are now configured:
-- ANTHROPIC_API_KEY: Stored in encrypted credential store
-- HUBSPOT_ACCESS_TOKEN: Synced from Aden (OAuth completed on Aden)
-- Validation passed - your agent is ready to run!
+✓ anthropic (ANTHROPIC_API_KEY) — already in encrypted store
+✓ brave_search (BRAVE_SEARCH_API_KEY) — already in encrypted store
+✓ google_search (GOOGLE_API_KEY) — stored in encrypted store
+✓ google_cse (GOOGLE_CSE_ID) — stored in encrypted store
+Your agent is ready to run!
 ```

@@ -1,9 +1,10 @@
 ---
 name: Bug Report
 about: Report a bug to help us improve
-title: '[Bug]: '
-labels: bug
+title: "[Bug]: "
+labels: bug, enhancement
 assignees: ''

 ---

 ## Describe the Bug

@@ -1,9 +1,10 @@
 ---
 name: Feature Request
 about: Suggest a new feature or enhancement
-title: '[Feature]: '
+title: "[Feature]: "
 labels: enhancement
 assignees: ''

 ---

 ## Problem Statement

@@ -0,0 +1,71 @@
+---
+name: Integration Request
+about: Suggest a new integration
+title: "[Integration]:"
+labels: ''
+assignees: ''
+
+---
+
+## Service
+
+Name and brief description of the service and what it enables agents to do.
+
+**Description:** [e.g., "API key for Slack Bot" — short one-liner for the credential spec]
+
+## Credential Identity
+
+- **credential_id:** [e.g., `slack`]
+- **env_var:** [e.g., `SLACK_BOT_TOKEN`]
+- **credential_key:** [e.g., `access_token`, `api_key`, `bot_token`]
+
+## Tools
+
+Tool function names that require this credential:
+
+- [e.g., `slack_send_message`]
+- [e.g., `slack_list_channels`]
+
+## Auth Methods
+
+- **Direct API key supported:** Yes / No
+- **Aden OAuth supported:** Yes / No
+
+If Aden OAuth is supported, describe the OAuth scopes/permissions required.
+
+## How to Get the Credential
+
+Link where users obtain the key/token:
+
+[e.g., https://api.slack.com/apps]
+
+Step-by-step instructions:
+
+1. Go to ...
+2. Create a ...
+3. Select scopes/permissions: ...
+4. Copy the key/token
+
+## Health Check
+
+A lightweight API call to validate the credential (no writes, no charges).
+
+- **Endpoint:** [e.g., `https://slack.com/api/auth.test`]
+- **Method:** [e.g., `GET` or `POST`]
+- **Auth header:** [e.g., `Authorization: Bearer {token}` or `X-Api-Key: {key}`]
+- **Parameters (if any):** [e.g., `?limit=1`]
+- **200 means:** [e.g., key is valid]
+- **401 means:** [e.g., invalid or expired]
+- **429 means:** [e.g., rate limited but key is valid]
+
+## Credential Group
+
+Does this require multiple credentials configured together? (e.g., Google Custom Search needs
+both an API key and a CSE ID)
+
+- [ ] No, single credential
+- [ ] Yes — list the other credential IDs in the group:
+
+## Additional Context
+
+Links to API docs, rate limits, free tier availability, or anything else relevant.
+40
-19
@@ -21,23 +21,22 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: '3.11'
-          cache: 'pip'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v4

       - name: Install dependencies
-        run: |
-          cd core
-          pip install -e .
-          pip install -r requirements-dev.txt
+        run: uv sync --project core --group dev

       - name: Ruff lint
         run: |
-          ruff check core/
-          ruff check tools/
+          uv run --project core ruff check core/
+          uv run --project core ruff check tools/

       - name: Ruff format
         run: |
-          ruff format --check core/
-          ruff format --check tools/
+          uv run --project core ruff format --check core/
+          uv run --project core ruff format --check tools/

   test:
     name: Test Python Framework
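To reproduce the new lint job locally (same commands as the workflow above, with the two ruff targets combined into single invocations, which ruff accepts):

```bash
uv sync --project core --group dev
uv run --project core ruff check core/ tools/
uv run --project core ruff format --check core/ tools/
```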
@@ -52,23 +51,23 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: '3.11'
-          cache: 'pip'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v4

       - name: Install dependencies
         run: |
           cd core
-          pip install -e .
-          pip install -r requirements-dev.txt
+          uv sync

       - name: Run tests
         run: |
           cd core
-          pytest tests/ -v
+          uv run pytest tests/ -v

-  validate:
-    name: Validate Agent Exports
+  test-tools:
+    name: Test Tools
     runs-on: ubuntu-latest
     needs: [lint, test]
     steps:
       - uses: actions/checkout@v4

@@ -76,13 +75,35 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: '3.11'
-          cache: 'pip'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v4
+
+      - name: Install dependencies and run tests
+        run: |
+          cd tools
+          uv sync --extra dev
+          uv run pytest tests/ -v
+
+  validate:
+    name: Validate Agent Exports
+    runs-on: ubuntu-latest
+    needs: [lint, test, test-tools]
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v4

       - name: Install dependencies
         run: |
           cd core
-          pip install -e .
-          pip install -r requirements-dev.txt
+          uv sync

       - name: Validate exported agents
         run: |

@@ -80,7 +80,13 @@ jobs:
 - help wanted: Extra attention is needed (if issue needs community input)
 - backlog: Tracked for the future, but not currently planned or prioritized

-You may apply multiple labels if appropriate (e.g., "bug" and "help wanted").
+### 6. Estimate size (if NOT a duplicate, spam, or invalid)
+Apply exactly ONE size label to help contributors match their capacity to the task:
+- "size: small": Docs, typos, single-file fixes, config changes
+- "size: medium": Bug fixes with tests, adding a single tool, changes within one package
+- "size: large": Cross-package changes (core + tools), new modules, complex logic, architectural refactors
+
+You may apply multiple labels if appropriate (e.g., "bug", "size: small", and "good first issue").

 ## Tools Available:
 - mcp__github__get_issue: Get issue details
@@ -21,18 +21,19 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: '3.11'
-          cache: 'pip'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v4

       - name: Install dependencies
         run: |
           cd core
-          pip install -e .
-          pip install -r requirements-dev.txt
+          uv sync

       - name: Run tests
         run: |
           cd core
-          pytest tests/ -v
+          uv run pytest tests/ -v

       - name: Generate changelog
         id: changelog

+6
-1
@@ -69,4 +69,9 @@ exports/*

 .agent-builder-sessions/*

-.venv
+.claude/settings.local.json
+
+.venv
+
+docs/github-issues/*
+core/tests/*dumps/*

@@ -1,20 +1,14 @@
 {
   "mcpServers": {
     "agent-builder": {
-      "command": ".venv/bin/python",
-      "args": ["-m", "framework.mcp.agent_builder_server"],
-      "cwd": "core",
-      "env": {
-        "PYTHONPATH": "../tools/src"
-      }
+      "command": "uv",
+      "args": ["run", "-m", "framework.mcp.agent_builder_server"],
+      "cwd": "core"
     },
     "tools": {
-      "command": ".venv/bin/python",
-      "args": ["mcp_server.py", "--stdio"],
-      "cwd": "tools",
-      "env": {
-        "PYTHONPATH": "src:../core"
-      }
+      "command": "uv",
+      "args": ["run", "mcp_server.py", "--stdio"],
+      "cwd": "tools"
     }
   }
 }

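With the `uv`-based entries above, each server can be started by hand from the repo root to verify it launches (these are long-running stdio servers, so interrupt with Ctrl-C; commands and cwd values copied from the JSON):

```bash
(cd core && uv run -m framework.mcp.agent_builder_server)
(cd tools && uv run mcp_server.py --stdio)
```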
+20
-22
@@ -44,7 +44,7 @@ Aden Agent Framework is a Python-based system for building goal-driven, self-imp
 Ensure you have installed:

 - **Python 3.11+** - [Download](https://www.python.org/downloads/) (3.12 or 3.13 recommended)
-- **pip** - Package installer for Python (comes with Python)
+- **uv** - Python package manager ([Install](https://docs.astral.sh/uv/getting-started/installation/))
 - **git** - Version control
 - **Claude Code** - [Install](https://docs.anthropic.com/claude/docs/claude-code) (optional, for using building skills)

@@ -52,7 +52,7 @@ Verify installation:

 ```bash
 python --version  # Should be 3.11+
-pip --version     # Should be latest
+uv --version      # Should be latest
 git --version     # Any recent version
 ```

@@ -128,8 +128,12 @@ hive/                              # Repository root
 │
 ├── .github/                       # GitHub configuration
 │   ├── workflows/
-│   │   ├── ci.yml                 # Runs on every PR
-│   │   └── release.yml            # Runs on tags
+│   │   ├── ci.yml                 # Lint, test, validate on every PR
+│   │   ├── release.yml            # Runs on tags
+│   │   ├── pr-requirements.yml    # PR requirement checks
+│   │   ├── pr-check-command.yml   # PR check commands
+│   │   ├── claude-issue-triage.yml    # Automated issue triage
+│   │   └── auto-close-duplicates.yml  # Close duplicate issues
 │   ├── ISSUE_TEMPLATE/            # Bug report & feature request templates
 │   ├── PULL_REQUEST_TEMPLATE.md   # PR description template
 │   └── CODEOWNERS                 # Auto-assign reviewers
@@ -166,7 +170,6 @@ hive/                              # Repository root
 │   │   ├── testing/               # Testing utilities
 │   │   └── __init__.py
 │   ├── pyproject.toml             # Package metadata and dependencies
-│   ├── requirements.txt           # Python dependencies
 │   ├── README.md                  # Framework documentation
 │   ├── MCP_INTEGRATION_GUIDE.md   # MCP server integration guide
 │   └── docs/                      # Protocol documentation
@@ -182,7 +185,6 @@ hive/                              # Repository root
 │   │   ├── mcp_server.py          # HTTP MCP server
 │   │   └── __init__.py
 │   ├── pyproject.toml             # Package metadata
-│   ├── requirements.txt           # Python dependencies
 │   └── README.md                  # Tools documentation
 │
 ├── exports/                       # AGENT PACKAGES (user-created, gitignored)
@@ -191,14 +193,16 @@ hive/                              # Repository root
 ├── docs/                          # Documentation
 │   ├── getting-started.md         # Quick start guide
 │   ├── configuration.md           # Configuration reference
-│   ├── architecture.md            # System architecture
-│   └── articles/                  # Technical articles
+│   ├── architecture/              # System architecture
+│   ├── articles/                  # Technical articles
+│   ├── quizzes/                   # Developer quizzes
+│   └── i18n/                      # Translations
 │
 ├── scripts/                       # Build & utility scripts
 │   ├── setup-python.sh            # Python environment setup
 │   └── setup.sh                   # Legacy setup script
 │
-├── quickstart.sh                  # Install Claude Code skills
+├── quickstart.sh                  # Interactive setup wizard
+├── ENVIRONMENT_SETUP.md           # Complete Python setup guide
 ├── README.md                      # Project overview
 ├── DEVELOPER.md                   # This file
@@ -375,7 +379,7 @@ def test_ticket_categorization():
 - **PEP 8** - Follow Python style guide
 - **Type hints** - Use for function signatures and class attributes
 - **Docstrings** - Document classes and public functions
-- **Black** - Code formatter (run with `black .`)
+- **Ruff** - Linter and formatter (run with `make check`)

 ```python
 # Good
@@ -509,8 +513,8 @@ chore(deps): update React to 18.2.0

 1. Create a feature branch from `main`
 2. Make your changes with clear commits
-3. Run tests locally: `PYTHONPATH=core:exports python -m pytest`
-4. Run linting: `black --check .`
+3. Run tests locally: `make test`
+4. Run linting: `make check`
 5. Push and create a PR
 6. Fill out the PR template
 7. Request review from CODEOWNERS
@@ -528,16 +532,11 @@ chore(deps): update React to 18.2.0
 ```bash
 # Add to core framework
 cd core
-pip install <package>
-# Then add to requirements.txt or pyproject.toml
+uv add <package>

 # Add to tools package
 cd tools
-pip install <package>
-# Then add to requirements.txt or pyproject.toml
-
-# Reinstall in editable mode
-pip install -e .
+uv add <package>
 ```

 ### Creating a New Agent
@@ -670,9 +669,8 @@ cat .env
 # Or check shell environment
 echo $ANTHROPIC_API_KEY

-# Copy from .env.example if needed
-cp .env.example .env
-# Then edit .env with your API keys
+# Create .env if needed
+# Then add your API keys
 ```

+86
-3
@@ -21,6 +21,43 @@ This will:
 - Fix package compatibility issues (openai + litellm)
 - Verify all installations

+## Quick Setup (Windows – PowerShell)
+
+Windows users can use the native PowerShell setup script.
+
+Before running the script, allow script execution for the current session:
+
+```powershell
+Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
+```
+
+Run setup from the project root:
+
+```powershell
+./scripts/setup-python.ps1
+```
+
+This will:
+
+- Check Python version (requires 3.11+)
+- Create a local `.venv` virtual environment
+- Install the core framework package (`framework`)
+- Install the tools package (`aden_tools`)
+- Fix package compatibility issues (openai + litellm)
+- Verify all installations
+
+After setup, activate the virtual environment:
+
+```powershell
+.\.venv\Scripts\Activate.ps1
+```
+
+Set `PYTHONPATH` (required in every new PowerShell session):
+
+```powershell
+$env:PYTHONPATH="core;exports"
+```
+
 ## Alpine Linux Setup

 If you are using Alpine Linux (e.g., inside a Docker container), you must install system dependencies and use a virtual environment before running the setup script:
@@ -100,6 +137,12 @@ For running agents with real LLMs:
 export ANTHROPIC_API_KEY="your-key-here"
 ```

+Windows (PowerShell):
+
+```powershell
+$env:ANTHROPIC_API_KEY="your-key-here"
+```
+
 ## Running Agents

 All agent commands must be run from the project root with `PYTHONPATH` set:
@@ -109,9 +152,14 @@ All agent commands must be run from the project root with `PYTHONPATH` set:
 PYTHONPATH=core:exports python -m agent_name COMMAND
 ```

-### Example Commands
+Windows (PowerShell):

-After building an agent via `/building-agents-construction`, use these commands:
+```powershell
+$env:PYTHONPATH="core;exports"
+python -m agent_name COMMAND
+```
+
+### Example: Support Ticket Agent

 ```bash
 # Validate agent structure
@@ -248,6 +296,14 @@ source .venv/bin/activate
 PYTHONPATH=core:exports python -m your_agent_name demo
 ```

+### PowerShell: "running scripts is disabled on this system"
+
+Run once per session:
+
+```powershell
+Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
+```
+
 ### "ModuleNotFoundError: No module named 'framework'"

 **Solution:** Install the core package:
@@ -270,6 +326,12 @@ Or run the setup script:
 ./quickstart.sh
 ```

+Windows:
+
+```powershell
+./scripts/setup-python.ps1
+```
+
 ### "ModuleNotFoundError: No module named 'openai.\_models'"

 **Cause:** Outdated `openai` package (0.27.x) incompatible with `litellm`
@@ -284,12 +346,21 @@ pip install --upgrade "openai>=1.0.0"

 **Cause:** Not running from project root, missing PYTHONPATH, or agent not yet created

-**Solution:** Ensure you're in the project root directory, have built an agent, and use:
+**Solution:** Ensure you're in `/hive/` and use:
+
+Linux/macOS:

 ```bash
 PYTHONPATH=core:exports python -m your_agent_name validate
 ```

+Windows:
+
+```powershell
+$env:PYTHONPATH="core;exports"
+python -m support_ticket_agent validate
+```
+
 ### Agent imports fail with "broken installation"

 **Symptom:** `pip list` shows packages pointing to non-existent directories
@@ -304,6 +375,12 @@ pip uninstall -y framework tools
 ./quickstart.sh
 ```

+Windows:
+
+```powershell
+./scripts/setup-python.ps1
+```
+
 ## Package Structure

 The Hive framework consists of three Python packages:
@@ -402,6 +479,12 @@ This design allows agents in `exports/` to be:
 ./quickstart.sh
 ```

+Windows:
+
+```powershell
+./scripts/setup-python.ps1
+```
+
 ### 2. Build Agent (Claude Code)

 ```

@@ -15,7 +15,6 @@

 [](https://github.com/adenhq/hive/blob/main/LICENSE)
 [](https://www.ycombinator.com/companies/aden)
-[](https://hub.docker.com/u/adenhq)
 [](https://discord.com/invite/MXE49hrKDk)
 [](https://x.com/aden_hq)
 [](https://www.linkedin.com/company/teamaden/)
@@ -40,6 +39,31 @@ Build reliable, self-improving AI agents without hardcoding workflows. Define yo

 Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.

+## Who Is Hive For?
+
+Hive is designed for developers and teams who want to build **production-grade AI agents** without manually wiring complex workflows.
+
+Hive is a good fit if you:
+
+- Want AI agents that **execute real business processes**, not demos
+- Prefer **goal-driven development** over hardcoded workflows
+- Need **self-healing and adaptive agents** that improve over time
+- Require **human-in-the-loop control**, observability, and cost limits
+- Plan to run agents in **production environments**
+
+Hive may not be the best fit if you're only experimenting with simple agent chains or one-off scripts.
+
+## When Should You Use Hive?
+
+Use Hive when you need:
+
+- Long-running, autonomous agents
+- Multi-agent coordination
+- Continuous improvement based on failures
+- Strong monitoring, safety, and budget controls
+- A framework that evolves with your goals
+
+
 ## What is Aden

 <p align="center">
@@ -64,11 +88,13 @@ Aden is a platform for building, deploying, operating, and adapting AI agents:

 ## Quick Start

-### Prerequisites
+## Prerequisites

-- [Python 3.11+](https://www.python.org/downloads/) for agent development
+- Python 3.11+ for agent development
 - Claude Code or Cursor for utilizing agent skills

+> **Note for Windows Users:** It is strongly recommended to use **WSL (Windows Subsystem for Linux)** or **Git Bash** to run this framework. Some core automation scripts may not execute correctly in standard Command Prompt or PowerShell.
+
 ### Installation

 ```bash

+1
-1
@@ -268,7 +268,7 @@ classDef done fill:#9e9e9e,color:#fff,stroke:#757575
 - [ ] Wake-up Tool (resume agent tasks)

 ### Deployment (Self-Hosted)
-- [ ] Docker container standardization
+- [ ] Worker agent docker container standardization
 - [ ] Headless backend execution
 - [ ] Exposed API for frontend attachment
 - [ ] Local monitoring & observability

@@ -0,0 +1,740 @@
#!/usr/bin/env python3
"""
EventLoopNode WebSocket Demo

Real LLM, real FileConversationStore, real EventBus.
Streams EventLoopNode execution to a browser via WebSocket.

Usage:
    cd /home/timothy/oss/hive/core
    python demos/event_loop_wss_demo.py

Then open http://localhost:8765 in your browser.
"""

import asyncio
import json
import logging
import sys
import tempfile
from http import HTTPStatus
from pathlib import Path

import httpx
import websockets
from bs4 import BeautifulSoup
from websockets.http11 import Request, Response

# Add core, tools, and hive root to path
_CORE_DIR = Path(__file__).resolve().parent.parent
_HIVE_DIR = _CORE_DIR.parent
sys.path.insert(0, str(_CORE_DIR))  # framework.*
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src"))  # aden_tools.*
sys.path.insert(0, str(_HIVE_DIR))  # core.framework.* (for aden_tools imports)

import os  # noqa: E402

from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter  # noqa: E402
from core.framework.credentials import CredentialStore  # noqa: E402

from framework.credentials.storage import (  # noqa: E402
    CompositeStorage,
    EncryptedFileStorage,
    EnvVarStorage,
)
from framework.graph.event_loop_node import EventLoopNode, LoopConfig  # noqa: E402
from framework.graph.node import NodeContext, NodeSpec, SharedMemory  # noqa: E402
from framework.llm.litellm import LiteLLMProvider  # noqa: E402
from framework.llm.provider import Tool  # noqa: E402
from framework.runner.tool_registry import ToolRegistry  # noqa: E402
from framework.runtime.core import Runtime  # noqa: E402
from framework.runtime.event_bus import EventBus, EventType  # noqa: E402
from framework.storage.conversation_store import FileConversationStore  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
logger = logging.getLogger("demo")

# -------------------------------------------------------------------------
# Persistent state (shared across WebSocket connections)
# -------------------------------------------------------------------------

STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_demo_"))
STORE = FileConversationStore(STORE_DIR / "conversation")
RUNTIME = Runtime(STORE_DIR / "runtime")
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")

# -------------------------------------------------------------------------
# Tool Registry — real tools via ToolRegistry (same pattern as GraphExecutor)
# -------------------------------------------------------------------------

TOOL_REGISTRY = ToolRegistry()

# Credential store: Aden sync (OAuth2 tokens) + encrypted files + env var fallback
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
_local_storage = CompositeStorage(
    primary=EncryptedFileStorage(),
    fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
)

if os.environ.get("ADEN_API_KEY"):
    try:
        from framework.credentials.aden import (  # noqa: E402
            AdenCachedStorage,
            AdenClientConfig,
            AdenCredentialClient,
            AdenSyncProvider,
        )

        _client = AdenCredentialClient(AdenClientConfig(base_url="https://api.adenhq.com"))
        _provider = AdenSyncProvider(client=_client)
        _storage = AdenCachedStorage(
            local_storage=_local_storage,
            aden_provider=_provider,
        )
        _cred_store = CredentialStore(storage=_storage, providers=[_provider], auto_refresh=True)
        _synced = _provider.sync_all(_cred_store)
        logger.info("Synced %d credentials from Aden", _synced)
    except Exception as e:
        logger.warning("Aden sync unavailable: %s", e)
        _cred_store = CredentialStore(storage=_local_storage)
else:
    logger.info("ADEN_API_KEY not set, using local credential storage")
    _cred_store = CredentialStore(storage=_local_storage)

CREDENTIALS = CredentialStoreAdapter(_cred_store)

# Debug: log which credentials resolved
for _name in ["brave_search", "hubspot", "anthropic"]:
    _val = CREDENTIALS.get(_name)
    if _val:
        logger.debug("credential %s: OK (len=%d)", _name, len(_val))
    else:
        logger.debug("credential %s: not found", _name)

# --- web_search (Brave Search API) ---

TOOL_REGISTRY.register(
    name="web_search",
    tool=Tool(
        name="web_search",
        description=(
            "Search the web for current information. "
            "Returns titles, URLs, and snippets from search results."
        ),
        parameters={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query (1-500 characters)",
                },
                "num_results": {
                    "type": "integer",
                    "description": "Number of results to return (1-20, default 10)",
                },
            },
            "required": ["query"],
        },
    ),
    executor=lambda inputs: _exec_web_search(inputs),
)


def _exec_web_search(inputs: dict) -> dict:
    api_key = CREDENTIALS.get("brave_search")
    if not api_key:
        return {"error": "brave_search credential not configured"}
    query = inputs.get("query", "")
    num_results = min(inputs.get("num_results", 10), 20)
    resp = httpx.get(
        "https://api.search.brave.com/res/v1/web/search",
        params={"q": query, "count": num_results},
        headers={"X-Subscription-Token": api_key, "Accept": "application/json"},
        timeout=30.0,
    )
    if resp.status_code != 200:
        return {"error": f"Brave API HTTP {resp.status_code}"}
    data = resp.json()
    results = [
        {
            "title": item.get("title", ""),
            "url": item.get("url", ""),
            "snippet": item.get("description", ""),
        }
        for item in data.get("web", {}).get("results", [])[:num_results]
    ]
    return {"query": query, "results": results, "total": len(results)}


# --- web_scrape (httpx + BeautifulSoup, no playwright for sync compat) ---

TOOL_REGISTRY.register(
    name="web_scrape",
    tool=Tool(
        name="web_scrape",
        description=(
            "Scrape and extract text content from a webpage URL. "
            "Returns the page title and main text content."
        ),
        parameters={
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "URL of the webpage to scrape",
                },
                "max_length": {
                    "type": "integer",
                    "description": "Maximum text length (default 50000)",
                },
            },
            "required": ["url"],
        },
    ),
    executor=lambda inputs: _exec_web_scrape(inputs),
)

_SCRAPE_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/131.0.0.0 Safari/537.36"
    ),
    "Accept": "text/html,application/xhtml+xml",
}


def _exec_web_scrape(inputs: dict) -> dict:
    url = inputs.get("url", "")
    max_length = max(1000, min(inputs.get("max_length", 50000), 500000))
    if not url.startswith(("http://", "https://")):
        url = "https://" + url
    try:
        resp = httpx.get(url, timeout=30.0, follow_redirects=True, headers=_SCRAPE_HEADERS)
        if resp.status_code != 200:
            return {"error": f"HTTP {resp.status_code}"}
        soup = BeautifulSoup(resp.text, "html.parser")
        for tag in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
            tag.decompose()
        title = soup.title.get_text(strip=True) if soup.title else ""
        main = (
            soup.find("article")
            or soup.find("main")
            or soup.find(attrs={"role": "main"})
            or soup.find("body")
        )
        text = main.get_text(separator=" ", strip=True) if main else ""
        text = " ".join(text.split())
        if len(text) > max_length:
            text = text[:max_length] + "..."
        return {"url": url, "title": title, "content": text, "length": len(text)}
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"Scrape failed: {e}"}


# --- HubSpot CRM tools (optional, requires HUBSPOT_ACCESS_TOKEN) ---

_HUBSPOT_API = "https://api.hubapi.com"


def _hubspot_headers() -> dict | None:
    token = CREDENTIALS.get("hubspot")
    if token:
        logger.debug("HubSpot token: %s...%s (len=%d)", token[:8], token[-4:], len(token))
    else:
        logger.debug("HubSpot token: not found")
    if not token:
        return None
    return {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }


def _exec_hubspot_search(inputs: dict) -> dict:
    headers = _hubspot_headers()
    if not headers:
        return {"error": "HUBSPOT_ACCESS_TOKEN not set"}
    object_type = inputs.get("object_type", "contacts")
    query = inputs.get("query", "")
    limit = min(inputs.get("limit", 10), 100)
    body: dict = {"limit": limit}
    if query:
        body["query"] = query
    try:
        resp = httpx.post(
            f"{_HUBSPOT_API}/crm/v3/objects/{object_type}/search",
            headers=headers,
            json=body,
            timeout=30.0,
        )
        if resp.status_code != 200:
            return {"error": f"HubSpot API HTTP {resp.status_code}: {resp.text[:200]}"}
        return resp.json()
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"HubSpot error: {e}"}


TOOL_REGISTRY.register(
    name="hubspot_search",
    tool=Tool(
        name="hubspot_search",
        description=(
            "Search HubSpot CRM objects (contacts, companies, or deals). "
            "Returns matching records with their properties."
        ),
        parameters={
            "type": "object",
            "properties": {
                "object_type": {
                    "type": "string",
                    "description": "CRM object type: 'contacts', 'companies', or 'deals'",
                },
                "query": {
                    "type": "string",
                    "description": "Search query (name, email, domain, etc.)",
                },
                "limit": {
                    "type": "integer",
                    "description": "Max results (1-100, default 10)",
                },
            },
            "required": ["object_type"],
        },
    ),
    executor=lambda inputs: _exec_hubspot_search(inputs),
)

logger.info(
    "ToolRegistry loaded: %s",
    ", ".join(TOOL_REGISTRY.get_registered_names()),
)


# -------------------------------------------------------------------------
# HTML page (embedded)
# -------------------------------------------------------------------------

HTML_PAGE = (  # noqa: E501
    """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>EventLoopNode Live Demo</title>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body {
  font-family: 'SF Mono', 'Fira Code', monospace;
  background: #0d1117; color: #c9d1d9;
  height: 100vh; display: flex; flex-direction: column;
}
header {
  background: #161b22; padding: 12px 20px;
  border-bottom: 1px solid #30363d;
  display: flex; align-items: center; gap: 16px;
}
header h1 { font-size: 16px; color: #58a6ff; font-weight: 600; }
.status {
  font-size: 12px; padding: 3px 10px; border-radius: 12px;
  background: #21262d; color: #8b949e;
}
.status.running { background: #1a4b2e; color: #3fb950; }
.status.done { background: #1a3a5c; color: #58a6ff; }
.status.error { background: #4b1a1a; color: #f85149; }
.chat { flex: 1; overflow-y: auto; padding: 16px; }
.msg {
  margin: 8px 0; padding: 10px 14px; border-radius: 8px;
  line-height: 1.6; white-space: pre-wrap; word-wrap: break-word;
}
.msg.user { background: #1a3a5c; color: #58a6ff; }
.msg.assistant { background: #161b22; color: #c9d1d9; }
.msg.event {
  background: transparent; color: #8b949e; font-size: 11px;
  padding: 4px 14px; border-left: 3px solid #30363d;
}
.msg.event.loop { border-left-color: #58a6ff; }
.msg.event.tool { border-left-color: #d29922; }
.msg.event.stall { border-left-color: #f85149; }
.input-bar {
  padding: 12px 16px; background: #161b22;
  border-top: 1px solid #30363d; display: flex; gap: 8px;
}
.input-bar input {
  flex: 1; background: #0d1117; border: 1px solid #30363d;
  color: #c9d1d9; padding: 8px 12px; border-radius: 6px;
  font-family: inherit; font-size: 14px; outline: none;
}
.input-bar input:focus { border-color: #58a6ff; }
.input-bar button {
  background: #238636; color: #fff; border: none;
  padding: 8px 20px; border-radius: 6px; cursor: pointer;
  font-family: inherit; font-weight: 600;
}
.input-bar button:hover { background: #2ea043; }
.input-bar button:disabled {
  background: #21262d; color: #484f58; cursor: not-allowed;
}
.input-bar button.clear { background: #da3633; }
.input-bar button.clear:hover { background: #f85149; }
</style>
</head>
<body>
<header>
  <h1>EventLoopNode Live</h1>
  <span id="status" class="status">Idle</span>
  <span id="iter" class="status" style="display:none">Step 0</span>
</header>
<div id="chat" class="chat"></div>
<div class="input-bar">
  <input id="input" type="text"
         placeholder="Ask anything..." autofocus />
  <button id="go" onclick="run()">Send</button>
  <button class="clear"
          onclick="clearConversation()">Clear</button>
</div>

<script>
let ws = null;
let currentAssistantEl = null;
let iterCount = 0;
const chat = document.getElementById('chat');
const status = document.getElementById('status');
const iterEl = document.getElementById('iter');
const goBtn = document.getElementById('go');
const inputEl = document.getElementById('input');

inputEl.addEventListener('keydown', e => {
  if (e.key === 'Enter') run();
});

function setStatus(text, cls) {
  status.textContent = text;
  status.className = 'status ' + cls;
}

function addMsg(text, cls) {
  const el = document.createElement('div');
  el.className = 'msg ' + cls;
  el.textContent = text;
  chat.appendChild(el);
  chat.scrollTop = chat.scrollHeight;
  return el;
}

function connect() {
  ws = new WebSocket('ws://' + location.host + '/ws');
  ws.onopen = () => {
    setStatus('Ready', 'done');
    goBtn.disabled = false;
  };
  ws.onmessage = handleEvent;
  ws.onerror = () => { setStatus('Error', 'error'); };
  ws.onclose = () => {
    setStatus('Reconnecting...', '');
    goBtn.disabled = true;
    setTimeout(connect, 2000);
  };
}

function handleEvent(msg) {
  const evt = JSON.parse(msg.data);

  if (evt.type === 'llm_text_delta') {
    if (currentAssistantEl) {
      currentAssistantEl.textContent += evt.content;
      chat.scrollTop = chat.scrollHeight;
    }
  }
  else if (evt.type === 'ready') {
    setStatus('Ready', 'done');
    if (currentAssistantEl && !currentAssistantEl.textContent)
      currentAssistantEl.remove();
    goBtn.disabled = false;
  }
  else if (evt.type === 'node_loop_iteration') {
    iterCount = evt.iteration || (iterCount + 1);
    iterEl.textContent = 'Step ' + iterCount;
    iterEl.style.display = '';
  }
  else if (evt.type === 'tool_call_started') {
    var info = evt.tool_name + '('
        + JSON.stringify(evt.tool_input).slice(0, 120) + ')';
    addMsg('TOOL ' + info, 'event tool');
  }
  else if (evt.type === 'tool_call_completed') {
    var preview = (evt.result || '').slice(0, 200);
    var cls = evt.is_error ? 'stall' : 'tool';
    addMsg('RESULT ' + evt.tool_name + ': ' + preview,
           'event ' + cls);
    currentAssistantEl = addMsg('', 'assistant');
  }
  else if (evt.type === 'result') {
    setStatus('Session ended', evt.success ? 'done' : 'error');
    if (evt.error) addMsg('ERROR ' + evt.error, 'event stall');
    if (currentAssistantEl && !currentAssistantEl.textContent)
      currentAssistantEl.remove();
    goBtn.disabled = false;
  }
  else if (evt.type === 'node_stalled') {
    addMsg('STALLED ' + evt.reason, 'event stall');
  }
  else if (evt.type === 'cleared') {
    chat.innerHTML = '';
    iterCount = 0;
    iterEl.textContent = 'Step 0';
    iterEl.style.display = 'none';
    setStatus('Ready', 'done');
    goBtn.disabled = false;
  }
}

function run() {
  const text = inputEl.value.trim();
  if (!text || !ws || ws.readyState !== 1) return;
  addMsg(text, 'user');
  currentAssistantEl = addMsg('', 'assistant');
|
||||
inputEl.value = '';
|
||||
setStatus('Running', 'running');
|
||||
goBtn.disabled = true;
|
||||
ws.send(JSON.stringify({ topic: text }));
|
||||
}
|
||||
|
||||
function clearConversation() {
|
||||
if (ws && ws.readyState === 1) {
|
||||
ws.send(JSON.stringify({ command: 'clear' }));
|
||||
}
|
||||
}
|
||||
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# WebSocket handler
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_ws(websocket):
|
||||
"""Persistent WebSocket: long-lived EventLoopNode with client_facing blocking."""
|
||||
global STORE
|
||||
|
||||
# -- Event forwarding (WebSocket ← EventBus) ----------------------------
|
||||
bus = EventBus()
|
||||
|
||||
async def forward_event(event):
|
||||
try:
|
||||
payload = {"type": event.type.value, **event.data}
|
||||
if event.node_id:
|
||||
payload["node_id"] = event.node_id
|
||||
await websocket.send(json.dumps(payload))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[
|
||||
EventType.NODE_LOOP_STARTED,
|
||||
EventType.NODE_LOOP_ITERATION,
|
||||
EventType.NODE_LOOP_COMPLETED,
|
||||
EventType.LLM_TEXT_DELTA,
|
||||
EventType.TOOL_CALL_STARTED,
|
||||
EventType.TOOL_CALL_COMPLETED,
|
||||
EventType.NODE_STALLED,
|
||||
],
|
||||
handler=forward_event,
|
||||
)
|
||||
|
||||
# -- Per-connection state -----------------------------------------------
|
||||
node = None
|
||||
loop_task = None
|
||||
|
||||
tools = list(TOOL_REGISTRY.get_tools().values())
|
||||
tool_executor = TOOL_REGISTRY.get_executor()
|
||||
|
||||
node_spec = NodeSpec(
|
||||
id="assistant",
|
||||
name="Chat Assistant",
|
||||
description="A conversational assistant that remembers context across messages",
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
system_prompt=(
|
||||
"You are a helpful assistant with access to tools. "
|
||||
"You can search the web, scrape webpages, and query HubSpot CRM. "
|
||||
"Use tools when the user asks for current information or external data. "
|
||||
"You have full conversation history, so you can reference previous messages."
|
||||
),
|
||||
)
|
||||
|
||||
# -- Ready callback: subscribe to CLIENT_INPUT_REQUESTED on the bus ---
|
||||
async def on_input_requested(event):
|
||||
try:
|
||||
await websocket.send(json.dumps({"type": "ready"}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.CLIENT_INPUT_REQUESTED],
|
||||
handler=on_input_requested,
|
||||
)
|
||||
|
||||
async def start_loop(first_message: str):
|
||||
"""Create an EventLoopNode and run it as a background task."""
|
||||
nonlocal node, loop_task
|
||||
|
||||
memory = SharedMemory()
|
||||
ctx = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="assistant",
|
||||
node_spec=node_spec,
|
||||
memory=memory,
|
||||
input_data={},
|
||||
llm=LLM,
|
||||
available_tools=tools,
|
||||
)
|
||||
node = EventLoopNode(
|
||||
event_bus=bus,
|
||||
config=LoopConfig(max_iterations=10_000, max_history_tokens=32_000),
|
||||
conversation_store=STORE,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
await node.inject_event(first_message)
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
result = await node.execute(ctx)
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "result",
|
||||
"success": result.success,
|
||||
"output": result.output,
|
||||
"error": result.error,
|
||||
"tokens": result.tokens_used,
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"Loop ended: success={result.success}, tokens={result.tokens_used}")
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("Loop stopped: WebSocket closed")
|
||||
except Exception as e:
|
||||
logger.exception("Loop error")
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "result",
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"output": {},
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
loop_task = asyncio.create_task(_run())
|
||||
|
||||
async def stop_loop():
|
||||
"""Signal the node and wait for the loop task to finish."""
|
||||
nonlocal node, loop_task
|
||||
if loop_task and not loop_task.done():
|
||||
if node:
|
||||
node.signal_shutdown()
|
||||
try:
|
||||
await asyncio.wait_for(loop_task, timeout=5.0)
|
||||
except (TimeoutError, asyncio.CancelledError):
|
||||
loop_task.cancel()
|
||||
node = None
|
||||
loop_task = None
|
||||
|
||||
# -- Message loop (runs for the lifetime of this WebSocket) -------------
|
||||
try:
|
||||
async for raw in websocket:
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Clear command
|
||||
if msg.get("command") == "clear":
|
||||
import shutil
|
||||
|
||||
await stop_loop()
|
||||
await STORE.close()
|
||||
conv_dir = STORE_DIR / "conversation"
|
||||
if conv_dir.exists():
|
||||
shutil.rmtree(conv_dir)
|
||||
STORE = FileConversationStore(conv_dir)
|
||||
await websocket.send(json.dumps({"type": "cleared"}))
|
||||
logger.info("Conversation cleared")
|
||||
continue
|
||||
|
||||
topic = msg.get("topic", "")
|
||||
if not topic:
|
||||
continue
|
||||
|
||||
if node is None:
|
||||
# First message — spin up the loop
|
||||
logger.info(f"Starting persistent loop: {topic}")
|
||||
await start_loop(topic)
|
||||
else:
|
||||
# Subsequent message — inject into the running loop
|
||||
logger.info(f"Injecting message: {topic}")
|
||||
await node.inject_event(topic)
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
await stop_loop()
|
||||
logger.info("WebSocket closed, loop stopped")
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP handler for serving the HTML page
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_request(connection, request: Request):
|
||||
"""Serve HTML on GET /, upgrade to WebSocket on /ws."""
|
||||
if request.path == "/ws":
|
||||
return None # let websockets handle the upgrade
|
||||
# Serve the HTML page for any other path
|
||||
return Response(
|
||||
HTTPStatus.OK,
|
||||
"OK",
|
||||
websockets.Headers({"Content-Type": "text/html; charset=utf-8"}),
|
||||
HTML_PAGE.encode(),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Main
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
|
||||
port = 8765
|
||||
async with websockets.serve(
|
||||
handle_ws,
|
||||
"0.0.0.0",
|
||||
port,
|
||||
process_request=process_request,
|
||||
):
|
||||
logger.info(f"Demo running at http://localhost:{port}")
|
||||
logger.info("Open in your browser and enter a topic to research.")
|
||||
await asyncio.Future() # run forever
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
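The wire protocol above is small enough to exercise without the embedded page: the browser sends {"topic": ...} or {"command": "clear"}, and the server streams JSON events keyed by "type". A minimal headless-client sketch, assuming the server is already running on port 8765 (uses the same third-party websockets package as the demo; illustrative, not part of the repo):

#!/usr/bin/env python3
"""Headless client sketch for the demo above (assumes ws://localhost:8765)."""

import asyncio
import json

import websockets


async def chat_once(message: str) -> None:
    async with websockets.connect("ws://localhost:8765/ws") as ws:
        # Same payload shape the browser page sends
        await ws.send(json.dumps({"topic": message}))
        async for raw in ws:
            evt = json.loads(raw)
            if evt["type"] == "llm_text_delta":
                print(evt["content"], end="", flush=True)
            elif evt["type"] == "ready":
                # Node is now blocked waiting for the next user message
                print()
                break
            elif evt["type"] == "result":
                # Loop ended (success or error)
                print(f"\n[done: success={evt.get('success')}]")
                break


if __name__ == "__main__":
    asyncio.run(chat_once("What can you do?"))
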
File diff suppressed because it is too large
@@ -0,0 +1,930 @@
#!/usr/bin/env python3
"""
Two-Node ContextHandoff Demo

Demonstrates ContextHandoff between two EventLoopNode instances:

    Node A (Researcher) → ContextHandoff → Node B (Analyst)

Real LLM, real FileConversationStore, real EventBus.
Streams both nodes to a browser via WebSocket.

Usage:
    cd /home/timothy/oss/hive/core
    python demos/handoff_demo.py

Then open http://localhost:8766 in your browser.
"""

import asyncio
import json
import logging
import sys
import tempfile
from http import HTTPStatus
from pathlib import Path

import httpx
import websockets
from bs4 import BeautifulSoup
from websockets.http11 import Request, Response

# Add core, tools, and hive root to path
_CORE_DIR = Path(__file__).resolve().parent.parent
_HIVE_DIR = _CORE_DIR.parent
sys.path.insert(0, str(_CORE_DIR))  # framework.*
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src"))  # aden_tools.*
sys.path.insert(0, str(_HIVE_DIR))  # core.framework.* (for aden_tools imports)

from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter  # noqa: E402
from core.framework.credentials import CredentialStore  # noqa: E402

from framework.credentials.storage import (  # noqa: E402
    CompositeStorage,
    EncryptedFileStorage,
    EnvVarStorage,
)
from framework.graph.context_handoff import ContextHandoff  # noqa: E402
from framework.graph.conversation import NodeConversation  # noqa: E402
from framework.graph.event_loop_node import EventLoopNode, LoopConfig  # noqa: E402
from framework.graph.node import NodeContext, NodeSpec, SharedMemory  # noqa: E402
from framework.llm.litellm import LiteLLMProvider  # noqa: E402
from framework.llm.provider import Tool  # noqa: E402
from framework.runner.tool_registry import ToolRegistry  # noqa: E402
from framework.runtime.core import Runtime  # noqa: E402
from framework.runtime.event_bus import EventBus, EventType  # noqa: E402
from framework.storage.conversation_store import FileConversationStore  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
logger = logging.getLogger("handoff_demo")

# -------------------------------------------------------------------------
# Persistent state
# -------------------------------------------------------------------------

STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_handoff_"))
RUNTIME = Runtime(STORE_DIR / "runtime")
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")

# -------------------------------------------------------------------------
# Credentials
# -------------------------------------------------------------------------

# Composite credential store: encrypted files (primary) + env vars (fallback)
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
_composite = CompositeStorage(
    primary=EncryptedFileStorage(),
    fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
)
CREDENTIALS = CredentialStoreAdapter(CredentialStore(storage=_composite))

for _name in ["brave_search", "hubspot"]:
    _val = CREDENTIALS.get(_name)
    if _val:
        logger.debug("credential %s: OK (len=%d)", _name, len(_val))
    else:
        logger.debug("credential %s: not found", _name)

# -------------------------------------------------------------------------
# Tool Registry — web_search + web_scrape for Node A (Researcher)
# -------------------------------------------------------------------------

TOOL_REGISTRY = ToolRegistry()


def _exec_web_search(inputs: dict) -> dict:
    api_key = CREDENTIALS.get("brave_search")
    if not api_key:
        return {"error": "brave_search credential not configured"}
    query = inputs.get("query", "")
    num_results = min(inputs.get("num_results", 10), 20)
    resp = httpx.get(
        "https://api.search.brave.com/res/v1/web/search",
        params={"q": query, "count": num_results},
        headers={
            "X-Subscription-Token": api_key,
            "Accept": "application/json",
        },
        timeout=30.0,
    )
    if resp.status_code != 200:
        return {"error": f"Brave API HTTP {resp.status_code}"}
    data = resp.json()
    results = [
        {
            "title": item.get("title", ""),
            "url": item.get("url", ""),
            "snippet": item.get("description", ""),
        }
        for item in data.get("web", {}).get("results", [])[:num_results]
    ]
    return {"query": query, "results": results, "total": len(results)}


TOOL_REGISTRY.register(
    name="web_search",
    tool=Tool(
        name="web_search",
        description=(
            "Search the web for current information. "
            "Returns titles, URLs, and snippets from search results."
        ),
        parameters={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query (1-500 characters)",
                },
                "num_results": {
                    "type": "integer",
                    "description": "Number of results (1-20, default 10)",
                },
            },
            "required": ["query"],
        },
    ),
    executor=lambda inputs: _exec_web_search(inputs),
)

_SCRAPE_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/131.0.0.0 Safari/537.36"
    ),
    "Accept": "text/html,application/xhtml+xml",
}


def _exec_web_scrape(inputs: dict) -> dict:
    url = inputs.get("url", "")
    max_length = max(1000, min(inputs.get("max_length", 50000), 500000))
    if not url.startswith(("http://", "https://")):
        url = "https://" + url
    try:
        resp = httpx.get(
            url,
            timeout=30.0,
            follow_redirects=True,
            headers=_SCRAPE_HEADERS,
        )
        if resp.status_code != 200:
            return {"error": f"HTTP {resp.status_code}"}
        soup = BeautifulSoup(resp.text, "html.parser")
        for tag in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
            tag.decompose()
        title = soup.title.get_text(strip=True) if soup.title else ""
        main = (
            soup.find("article")
            or soup.find("main")
            or soup.find(attrs={"role": "main"})
            or soup.find("body")
        )
        text = main.get_text(separator=" ", strip=True) if main else ""
        text = " ".join(text.split())
        if len(text) > max_length:
            text = text[:max_length] + "..."
        return {
            "url": url,
            "title": title,
            "content": text,
            "length": len(text),
        }
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except Exception as e:
        return {"error": f"Scrape failed: {e}"}


TOOL_REGISTRY.register(
    name="web_scrape",
    tool=Tool(
        name="web_scrape",
        description=(
            "Scrape and extract text content from a webpage URL. "
            "Returns the page title and main text content."
        ),
        parameters={
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "URL of the webpage to scrape",
                },
                "max_length": {
                    "type": "integer",
                    "description": "Maximum text length (default 50000)",
                },
            },
            "required": ["url"],
        },
    ),
    executor=lambda inputs: _exec_web_scrape(inputs),
)

logger.info(
    "ToolRegistry loaded: %s",
    ", ".join(TOOL_REGISTRY.get_registered_names()),
)

# -------------------------------------------------------------------------
# Node Specs
# -------------------------------------------------------------------------

RESEARCHER_SPEC = NodeSpec(
    id="researcher",
    name="Researcher",
    description="Researches a topic using web search and scraping tools",
    node_type="event_loop",
    input_keys=["topic"],
    output_keys=["research_summary"],
    system_prompt=(
        "You are a thorough research assistant. Your job is to research "
        "the given topic using the web_search and web_scrape tools.\n\n"
        "1. Search for relevant information on the topic\n"
        "2. Scrape 1-2 of the most promising URLs for details\n"
        "3. Synthesize your findings into a comprehensive summary\n"
        "4. Use set_output with key='research_summary' to save your "
        "findings\n\n"
        "Be thorough but efficient. Aim for 2-4 search/scrape calls, "
        "then summarize and set_output."
    ),
)

ANALYST_SPEC = NodeSpec(
    id="analyst",
    name="Analyst",
    description="Analyzes research findings and provides insights",
    node_type="event_loop",
    input_keys=["context"],
    output_keys=["analysis"],
    system_prompt=(
        "You are a strategic analyst. You receive research findings from "
        "a previous researcher and must:\n\n"
        "1. Identify key themes and patterns\n"
        "2. Assess the reliability and significance of the findings\n"
        "3. Provide actionable insights and recommendations\n"
        "4. Use set_output with key='analysis' to save your analysis\n\n"
        "Be concise but insightful. Focus on what matters most."
    ),
)


# -------------------------------------------------------------------------
# HTML page
# -------------------------------------------------------------------------

HTML_PAGE = (  # noqa: E501
    """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>ContextHandoff Demo</title>
<style>
* {
    box-sizing: border-box;
    margin: 0;
    padding: 0;
}
body {
    font-family: 'SF Mono', 'Fira Code', monospace;
    background: #0d1117;
    color: #c9d1d9;
    height: 100vh;
    display: flex;
    flex-direction: column;
}
header {
    background: #161b22;
    padding: 12px 20px;
    border-bottom: 1px solid #30363d;
    display: flex;
    align-items: center;
    gap: 16px;
}
header h1 {
    font-size: 16px;
    color: #58a6ff;
    font-weight: 600;
}
.badge {
    font-size: 12px;
    padding: 3px 10px;
    border-radius: 12px;
    background: #21262d;
    color: #8b949e;
}
.badge.researcher {
    background: #1a3a5c;
    color: #58a6ff;
}
.badge.analyst {
    background: #1a4b2e;
    color: #3fb950;
}
.badge.handoff {
    background: #3d1f00;
    color: #d29922;
}
.badge.done {
    background: #21262d;
    color: #8b949e;
}
.badge.error {
    background: #4b1a1a;
    color: #f85149;
}
.chat {
    flex: 1;
    overflow-y: auto;
    padding: 16px;
}
.msg {
    margin: 8px 0;
    padding: 10px 14px;
    border-radius: 8px;
    line-height: 1.6;
    white-space: pre-wrap;
    word-wrap: break-word;
}
.msg.user {
    background: #1a3a5c;
    color: #58a6ff;
}
.msg.assistant {
    background: #161b22;
    color: #c9d1d9;
}
.msg.assistant.analyst-msg {
    border-left: 3px solid #3fb950;
}
.msg.event {
    background: transparent;
    color: #8b949e;
    font-size: 11px;
    padding: 4px 14px;
    border-left: 3px solid #30363d;
}
.msg.event.loop {
    border-left-color: #58a6ff;
}
.msg.event.tool {
    border-left-color: #d29922;
}
.msg.event.stall {
    border-left-color: #f85149;
}
.handoff-banner {
    margin: 16px 0;
    padding: 16px;
    background: #1c1200;
    border: 1px solid #d29922;
    border-radius: 8px;
    text-align: center;
}
.handoff-banner h3 {
    color: #d29922;
    font-size: 14px;
    margin-bottom: 8px;
}
.handoff-banner p, .result-banner p {
    color: #8b949e;
    font-size: 12px;
    line-height: 1.5;
    max-height: 200px;
    overflow-y: auto;
    white-space: pre-wrap;
    text-align: left;
}
.result-banner {
    margin: 16px 0;
    padding: 16px;
    background: #0a2614;
    border: 1px solid #3fb950;
    border-radius: 8px;
}
.result-banner h3 {
    color: #3fb950;
    font-size: 14px;
    margin-bottom: 8px;
    text-align: center;
}
.result-banner .label {
    color: #58a6ff;
    font-size: 11px;
    font-weight: 600;
    margin-top: 10px;
    margin-bottom: 2px;
}
.result-banner .tokens {
    color: #484f58;
    font-size: 11px;
    text-align: center;
    margin-top: 10px;
}
.input-bar {
    padding: 12px 16px;
    background: #161b22;
    border-top: 1px solid #30363d;
    display: flex;
    gap: 8px;
}
.input-bar input {
    flex: 1;
    background: #0d1117;
    border: 1px solid #30363d;
    color: #c9d1d9;
    padding: 8px 12px;
    border-radius: 6px;
    font-family: inherit;
    font-size: 14px;
    outline: none;
}
.input-bar input:focus {
    border-color: #58a6ff;
}
.input-bar button {
    background: #238636;
    color: #fff;
    border: none;
    padding: 8px 20px;
    border-radius: 6px;
    cursor: pointer;
    font-family: inherit;
    font-weight: 600;
}
.input-bar button:hover {
    background: #2ea043;
}
.input-bar button:disabled {
    background: #21262d;
    color: #484f58;
    cursor: not-allowed;
}
</style>
</head>
<body>
<header>
    <h1>ContextHandoff Demo</h1>
    <span id="phase" class="badge">Idle</span>
    <span id="iter" class="badge" style="display:none">Step 0</span>
</header>
<div id="chat" class="chat"></div>
<div class="input-bar">
    <input id="input" type="text"
           placeholder="Enter a research topic..." autofocus />
    <button id="go" onclick="run()">Research</button>
</div>

<script>
let ws = null;
let currentAssistantEl = null;
let iterCount = 0;
let currentPhase = 'idle';
const chat = document.getElementById('chat');
const phase = document.getElementById('phase');
const iterEl = document.getElementById('iter');
const goBtn = document.getElementById('go');
const inputEl = document.getElementById('input');

inputEl.addEventListener('keydown', e => {
    if (e.key === 'Enter') run();
});

function setPhase(text, cls) {
    phase.textContent = text;
    phase.className = 'badge ' + cls;
    currentPhase = cls;
}

function addMsg(text, cls) {
    const el = document.createElement('div');
    el.className = 'msg ' + cls;
    el.textContent = text;
    chat.appendChild(el);
    chat.scrollTop = chat.scrollHeight;
    return el;
}

function addHandoffBanner(summary) {
    const banner = document.createElement('div');
    banner.className = 'handoff-banner';
    const h3 = document.createElement('h3');
    h3.textContent = 'Context Handoff: Researcher -> Analyst';
    const p = document.createElement('p');
    p.textContent = summary || 'Passing research context...';
    banner.appendChild(h3);
    banner.appendChild(p);
    chat.appendChild(banner);
    chat.scrollTop = chat.scrollHeight;
}

function addResultBanner(researcher, analyst, tokens) {
    const banner = document.createElement('div');
    banner.className = 'result-banner';
    const h3 = document.createElement('h3');
    h3.textContent = 'Pipeline Complete';
    banner.appendChild(h3);

    if (researcher && researcher.research_summary) {
        const lbl = document.createElement('div');
        lbl.className = 'label';
        lbl.textContent = 'RESEARCH SUMMARY';
        banner.appendChild(lbl);
        const p = document.createElement('p');
        p.textContent = researcher.research_summary;
        banner.appendChild(p);
    }

    if (analyst && analyst.analysis) {
        const lbl = document.createElement('div');
        lbl.className = 'label';
        lbl.textContent = 'ANALYSIS';
        lbl.style.color = '#3fb950';
        banner.appendChild(lbl);
        const p = document.createElement('p');
        p.textContent = analyst.analysis;
        banner.appendChild(p);
    }

    if (tokens) {
        const t = document.createElement('div');
        t.className = 'tokens';
        t.textContent = 'Total tokens: ' + tokens.toLocaleString();
        banner.appendChild(t);
    }

    chat.appendChild(banner);
    chat.scrollTop = chat.scrollHeight;
}

function connect() {
    ws = new WebSocket('ws://' + location.host + '/ws');
    ws.onopen = () => {
        setPhase('Ready', 'done');
        goBtn.disabled = false;
    };
    ws.onmessage = handleEvent;
    ws.onerror = () => { setPhase('Error', 'error'); };
    ws.onclose = () => {
        setPhase('Reconnecting...', '');
        goBtn.disabled = true;
        setTimeout(connect, 2000);
    };
}

function handleEvent(msg) {
    const evt = JSON.parse(msg.data);

    if (evt.type === 'phase') {
        if (evt.phase === 'researcher') {
            setPhase('Researcher', 'researcher');
        } else if (evt.phase === 'handoff') {
            setPhase('Handoff', 'handoff');
        } else if (evt.phase === 'analyst') {
            setPhase('Analyst', 'analyst');
        }
        iterCount = 0;
        iterEl.style.display = 'none';
    }
    else if (evt.type === 'llm_text_delta') {
        if (currentAssistantEl) {
            currentAssistantEl.textContent += evt.content;
            chat.scrollTop = chat.scrollHeight;
        }
    }
    else if (evt.type === 'node_loop_iteration') {
        iterCount = evt.iteration || (iterCount + 1);
        iterEl.textContent = 'Step ' + iterCount;
        iterEl.style.display = '';
    }
    else if (evt.type === 'tool_call_started') {
        var info = evt.tool_name + '('
            + JSON.stringify(evt.tool_input).slice(0, 120) + ')';
        addMsg('TOOL ' + info, 'event tool');
    }
    else if (evt.type === 'tool_call_completed') {
        var preview = (evt.result || '').slice(0, 200);
        var cls = evt.is_error ? 'stall' : 'tool';
        addMsg(
            'RESULT ' + evt.tool_name + ': ' + preview,
            'event ' + cls
        );
        var assistCls = currentPhase === 'analyst'
            ? 'assistant analyst-msg' : 'assistant';
        currentAssistantEl = addMsg('', assistCls);
    }
    else if (evt.type === 'handoff_context') {
        addHandoffBanner(evt.summary);
        var assistCls = 'assistant analyst-msg';
        currentAssistantEl = addMsg('', assistCls);
    }
    else if (evt.type === 'node_result') {
        if (evt.node_id === 'researcher') {
            if (currentAssistantEl
                && !currentAssistantEl.textContent) {
                currentAssistantEl.remove();
            }
        }
    }
    else if (evt.type === 'done') {
        setPhase('Done', 'done');
        iterEl.style.display = 'none';
        if (currentAssistantEl
            && !currentAssistantEl.textContent) {
            currentAssistantEl.remove();
        }
        currentAssistantEl = null;
        addResultBanner(
            evt.researcher, evt.analyst, evt.total_tokens
        );
        goBtn.disabled = false;
        inputEl.placeholder = 'Enter another topic...';
    }
    else if (evt.type === 'error') {
        setPhase('Error', 'error');
        addMsg('ERROR ' + evt.message, 'event stall');
        goBtn.disabled = false;
    }
    else if (evt.type === 'node_stalled') {
        addMsg('STALLED ' + evt.reason, 'event stall');
    }
}

function run() {
    const text = inputEl.value.trim();
    if (!text || !ws || ws.readyState !== 1) return;
    chat.innerHTML = '';
    addMsg(text, 'user');
    currentAssistantEl = addMsg('', 'assistant');
    inputEl.value = '';
    goBtn.disabled = true;
    ws.send(JSON.stringify({ topic: text }));
}

connect();
</script>
</body>
</html>"""
)


# -------------------------------------------------------------------------
# WebSocket handler — sequential Node A → Handoff → Node B
# -------------------------------------------------------------------------


async def handle_ws(websocket):
    """Run the two-node handoff pipeline per user message."""
    try:
        async for raw in websocket:
            try:
                msg = json.loads(raw)
            except Exception:
                continue

            topic = msg.get("topic", "")
            if not topic:
                continue

            logger.info(f"Starting handoff pipeline for: {topic}")

            try:
                await _run_pipeline(websocket, topic)
            except websockets.exceptions.ConnectionClosed:
                logger.info("WebSocket closed during pipeline")
                return
            except Exception as e:
                logger.exception("Pipeline error")
                try:
                    await websocket.send(json.dumps({"type": "error", "message": str(e)}))
                except Exception:
                    pass

    except websockets.exceptions.ConnectionClosed:
        pass


async def _run_pipeline(websocket, topic: str):
    """Execute: Node A (research) → ContextHandoff → Node B (analysis)."""
    import shutil

    # Fresh stores for each run
    run_dir = Path(tempfile.mkdtemp(prefix="hive_run_", dir=STORE_DIR))
    store_a = FileConversationStore(run_dir / "node_a")
    store_b = FileConversationStore(run_dir / "node_b")

    # Shared event bus
    bus = EventBus()

    async def forward_event(event):
        try:
            payload = {"type": event.type.value, **event.data}
            if event.node_id:
                payload["node_id"] = event.node_id
            await websocket.send(json.dumps(payload))
        except Exception:
            pass

    bus.subscribe(
        event_types=[
            EventType.NODE_LOOP_STARTED,
            EventType.NODE_LOOP_ITERATION,
            EventType.NODE_LOOP_COMPLETED,
            EventType.LLM_TEXT_DELTA,
            EventType.TOOL_CALL_STARTED,
            EventType.TOOL_CALL_COMPLETED,
            EventType.NODE_STALLED,
        ],
        handler=forward_event,
    )

    tools = list(TOOL_REGISTRY.get_tools().values())
    tool_executor = TOOL_REGISTRY.get_executor()

    # ---- Phase 1: Researcher ------------------------------------------------
    await websocket.send(json.dumps({"type": "phase", "phase": "researcher"}))

    node_a = EventLoopNode(
        event_bus=bus,
        judge=None,  # implicit judge: accept when output_keys filled
        config=LoopConfig(
            max_iterations=20,
            max_tool_calls_per_turn=10,
            max_history_tokens=32_000,
        ),
        conversation_store=store_a,
        tool_executor=tool_executor,
    )

    ctx_a = NodeContext(
        runtime=RUNTIME,
        node_id="researcher",
        node_spec=RESEARCHER_SPEC,
        memory=SharedMemory(),
        input_data={"topic": topic},
        llm=LLM,
        available_tools=tools,
    )

    result_a = await node_a.execute(ctx_a)
    logger.info(
        "Researcher done: success=%s, tokens=%s",
        result_a.success,
        result_a.tokens_used,
    )

    await websocket.send(
        json.dumps(
            {
                "type": "node_result",
                "node_id": "researcher",
                "success": result_a.success,
                "output": result_a.output,
            }
        )
    )

    if not result_a.success:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": f"Researcher failed: {result_a.error}",
                }
            )
        )
        return

    # ---- Phase 2: Context Handoff -------------------------------------------
    await websocket.send(json.dumps({"type": "phase", "phase": "handoff"}))

    # Restore the researcher's conversation from store
    conversation_a = await NodeConversation.restore(store_a)
    if conversation_a is None:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "Failed to restore researcher conversation",
                }
            )
        )
        return

    handoff_engine = ContextHandoff(llm=LLM)
    handoff_context = handoff_engine.summarize_conversation(
        conversation=conversation_a,
        node_id="researcher",
        output_keys=["research_summary"],
    )

    formatted_handoff = ContextHandoff.format_as_input(handoff_context)
    logger.info(
        "Handoff: %d turns, ~%d tokens, keys=%s",
        handoff_context.turn_count,
        handoff_context.total_tokens_used,
        list(handoff_context.key_outputs.keys()),
    )

    # Send handoff context to browser
    await websocket.send(
        json.dumps(
            {
                "type": "handoff_context",
                "summary": handoff_context.summary[:500],
                "turn_count": handoff_context.turn_count,
                "tokens": handoff_context.total_tokens_used,
                "key_outputs": handoff_context.key_outputs,
            }
        )
    )

    # ---- Phase 3: Analyst ---------------------------------------------------
    await websocket.send(json.dumps({"type": "phase", "phase": "analyst"}))

    node_b = EventLoopNode(
        event_bus=bus,
        judge=None,  # implicit judge
        config=LoopConfig(
            max_iterations=10,
            max_tool_calls_per_turn=5,
            max_history_tokens=32_000,
        ),
        conversation_store=store_b,
    )

    ctx_b = NodeContext(
        runtime=RUNTIME,
        node_id="analyst",
        node_spec=ANALYST_SPEC,
        memory=SharedMemory(),
        input_data={"context": formatted_handoff},
        llm=LLM,
        available_tools=[],
    )

    result_b = await node_b.execute(ctx_b)
    logger.info(
        "Analyst done: success=%s, tokens=%s",
        result_b.success,
        result_b.tokens_used,
    )

    # ---- Done ---------------------------------------------------------------
    await websocket.send(
        json.dumps(
            {
                "type": "done",
                "researcher": result_a.output,
                "analyst": result_b.output,
                "total_tokens": ((result_a.tokens_used or 0) + (result_b.tokens_used or 0)),
            }
        )
    )

    # Clean up temp stores
    try:
        shutil.rmtree(run_dir)
    except Exception:
        pass


# -------------------------------------------------------------------------
# HTTP handler
# -------------------------------------------------------------------------


async def process_request(connection, request: Request):
    """Serve HTML on GET /, upgrade to WebSocket on /ws."""
    if request.path == "/ws":
        return None
    return Response(
        HTTPStatus.OK,
        "OK",
        websockets.Headers({"Content-Type": "text/html; charset=utf-8"}),
        HTML_PAGE.encode(),
    )


# -------------------------------------------------------------------------
# Main
# -------------------------------------------------------------------------


async def main():
    port = 8766
    async with websockets.serve(
        handle_ws,
        "0.0.0.0",
        port,
        process_request=process_request,
    ):
        logger.info(f"Handoff demo at http://localhost:{port}")
        logger.info("Enter a research topic to start the pipeline.")
        await asyncio.Future()


if __name__ == "__main__":
    asyncio.run(main())
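Before the analyst phase starts, the browser receives a single handoff_context event summarizing the researcher's state. The keys below mirror exactly what _run_pipeline sends; the values are made up for illustration:

# Illustrative payload of the "handoff_context" event (values are invented;
# the keys match what _run_pipeline sends to the browser).
handoff_event = {
    "type": "handoff_context",
    "summary": "Searched 3 sources on the topic; key findings were...",
    "turn_count": 7,
    "tokens": 14532,
    "key_outputs": {"research_summary": "..."},
}
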
File diff suppressed because it is too large
@@ -15,7 +15,7 @@ You cannot skip steps or bypass validation.

from collections.abc import Callable
from datetime import datetime
from enum import Enum
from enum import StrEnum
from pathlib import Path
from typing import Any

@@ -26,7 +26,7 @@ from framework.graph.goal import Goal
from framework.graph.node import NodeSpec


class BuildPhase(str, Enum):
class BuildPhase(StrEnum):
    """Current phase of the build process."""

    INIT = "init"  # Just started
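The swap from (str, Enum) to StrEnum here (and in the later hunks) requires Python 3.11+. Equality semantics are unchanged; what changes is that str() on a member now yields the bare value. A quick sketch of the difference, using hypothetical OldPhase/NewPhase names rather than anything from this codebase:

from enum import Enum, StrEnum  # StrEnum is new in Python 3.11


class OldPhase(str, Enum):  # the pattern being replaced
    INIT = "init"


class NewPhase(StrEnum):  # the replacement
    INIT = "init"


# Equality against plain strings works the same either way
assert OldPhase.INIT == "init" == NewPhase.INIT

# str() differs: StrEnum gives the bare value, the old mixin gives the member name
assert str(NewPhase.INIT) == "init"
assert str(OldPhase.INIT) == "OldPhase.INIT"
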
@@ -64,6 +64,8 @@ class AdenCachedStorage(CredentialStorage):
    - **Reads**: Try local cache first, fallback to Aden if stale/missing
    - **Writes**: Always write to local cache
    - **Offline resilience**: Uses cached credentials when Aden is unreachable
    - **Provider-based lookup**: Match credentials by provider name (e.g., "hubspot")
      when direct ID lookup fails, since Aden uses hash-based IDs internally.

    The cache TTL determines how long to trust local credentials before
    checking with the Aden server for updates. This balances:
@@ -85,6 +87,7 @@ class AdenCachedStorage(CredentialStorage):

        # First access fetches from Aden
        # Subsequent accesses use cache until TTL expires
        # Can look up by provider name OR credential ID
        token = store.get_key("hubspot", "access_token")
    """
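In practice the provider index gives AdenCachedStorage two ways to resolve the same credential. A hedged fragment, assuming `store` is an already-configured AdenCachedStorage instance (its full constructor is not shown in these hunks, so construction is omitted); only methods that appear in this diff are used:

# `store` is assumed to be a configured AdenCachedStorage (setup not shown).
# Aden-synced credentials use opaque hash IDs, so a direct lookup needs the ID:
cred = store.load("aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1")  # direct hash-ID lookup

# With the provider index, the human-friendly provider name resolves to that ID:
same_cred = store.load("hubspot")  # resolved via the provider index
assert store.exists("hubspot")     # exists() consults the index too

# After a cold start, the in-memory index can be reconstructed from disk:
n = store.rebuild_provider_index()  # returns the number of mappings indexed
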
@@ -111,21 +114,24 @@ class AdenCachedStorage(CredentialStorage):
|
||||
self._cache_ttl = timedelta(seconds=cache_ttl_seconds)
|
||||
self._prefer_local = prefer_local
|
||||
self._cache_timestamps: dict[str, datetime] = {}
|
||||
# Index: provider name (e.g., "hubspot") -> credential hash ID
|
||||
self._provider_index: dict[str, str] = {}
|
||||
|
||||
def save(self, credential: CredentialObject) -> None:
|
||||
"""
|
||||
Save credential to local cache.
|
||||
Save credential to local cache and update provider index.
|
||||
|
||||
Args:
|
||||
credential: The credential to save.
|
||||
"""
|
||||
self._local.save(credential)
|
||||
self._cache_timestamps[credential.id] = datetime.now(UTC)
|
||||
self._index_provider(credential)
|
||||
logger.debug(f"Cached credential '{credential.id}'")
|
||||
|
||||
def load(self, credential_id: str) -> CredentialObject | None:
|
||||
"""
|
||||
Load credential from cache, with Aden fallback.
|
||||
Load credential from cache, with Aden fallback and provider-based lookup.
|
||||
|
||||
The loading strategy depends on the `prefer_local` setting:
|
||||
|
||||
@@ -141,8 +147,37 @@ class AdenCachedStorage(CredentialStorage):
|
||||
2. Update local cache with response
|
||||
3. Fall back to local cache only if Aden fails
|
||||
|
||||
Provider-based lookup:
|
||||
When a provider index mapping exists for the credential_id (e.g.,
|
||||
"hubspot" → hash ID), the Aden-synced credential is loaded first.
|
||||
This ensures fresh OAuth tokens from Aden take priority over stale
|
||||
local credentials (env vars, old encrypted files).
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier.
|
||||
credential_id: The credential identifier or provider name.
|
||||
|
||||
Returns:
|
||||
CredentialObject if found, None otherwise.
|
||||
"""
|
||||
# Check provider index first — Aden-synced credentials take priority
|
||||
resolved_id = self._provider_index.get(credential_id)
|
||||
if resolved_id and resolved_id != credential_id:
|
||||
result = self._load_by_id(resolved_id)
|
||||
if result is not None:
|
||||
logger.info(
|
||||
f"Loaded credential '{credential_id}' via provider index (id='{resolved_id}')"
|
||||
)
|
||||
return result
|
||||
|
||||
# Direct lookup (exact credential_id match)
|
||||
return self._load_by_id(credential_id)
|
||||
|
||||
def _load_by_id(self, credential_id: str) -> CredentialObject | None:
|
||||
"""
|
||||
Load credential by exact ID from cache, with Aden fallback.
|
||||
|
||||
Args:
|
||||
credential_id: The exact credential identifier.
|
||||
|
||||
Returns:
|
||||
CredentialObject if found, None otherwise.
|
||||
@@ -200,15 +235,21 @@ class AdenCachedStorage(CredentialStorage):
|
||||
|
||||
def exists(self, credential_id: str) -> bool:
|
||||
"""
|
||||
Check if credential exists in local cache.
|
||||
Check if credential exists in local cache (by ID or provider name).
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier.
|
||||
credential_id: The credential identifier or provider name.
|
||||
|
||||
Returns:
|
||||
True if credential exists locally.
|
||||
"""
|
||||
return self._local.exists(credential_id)
|
||||
if self._local.exists(credential_id):
|
||||
return True
|
||||
# Check provider index
|
||||
resolved_id = self._provider_index.get(credential_id)
|
||||
if resolved_id and resolved_id != credential_id:
|
||||
return self._local.exists(resolved_id)
|
||||
return False
|
||||
|
||||
def _is_cache_fresh(self, credential_id: str) -> bool:
|
||||
"""
|
||||
@@ -242,6 +283,47 @@ class AdenCachedStorage(CredentialStorage):
|
||||
self._cache_timestamps.clear()
|
||||
logger.debug("Invalidated all cache entries")
|
||||
|
||||
def _index_provider(self, credential: CredentialObject) -> None:
|
||||
"""
|
||||
Index a credential by its provider/integration type.
|
||||
|
||||
Aden credentials carry an ``_integration_type`` key whose value is
|
||||
the provider name (e.g., ``hubspot``). This method maps that
|
||||
provider name to the credential's hash ID so that subsequent
|
||||
``load("hubspot")`` calls resolve to the correct credential.
|
||||
|
||||
Args:
|
||||
credential: The credential to index.
|
||||
"""
|
||||
integration_type_key = credential.keys.get("_integration_type")
|
||||
if integration_type_key is None:
|
||||
return
|
||||
provider_name = integration_type_key.value.get_secret_value()
|
||||
if provider_name:
|
||||
self._provider_index[provider_name] = credential.id
|
||||
logger.debug(f"Indexed provider '{provider_name}' -> '{credential.id}'")
|
||||
|
||||
def rebuild_provider_index(self) -> int:
|
||||
"""
|
||||
Rebuild the provider index from all locally cached credentials.
|
||||
|
||||
Useful after loading from disk when the in-memory index is empty.
|
||||
|
||||
Returns:
|
||||
Number of provider mappings indexed.
|
||||
"""
|
||||
self._provider_index.clear()
|
||||
indexed = 0
|
||||
for cred_id in self._local.list_all():
|
||||
cred = self._local.load(cred_id)
|
||||
if cred:
|
||||
before = len(self._provider_index)
|
||||
self._index_provider(cred)
|
||||
if len(self._provider_index) > before:
|
||||
indexed += 1
|
||||
logger.debug(f"Rebuilt provider index with {indexed} mappings")
|
||||
return indexed
|
||||
|
||||
def sync_all_from_aden(self) -> int:
|
||||
"""
|
||||
Sync all credentials from Aden server to local cache.
|
||||
|
||||
@@ -589,6 +589,149 @@ class TestAdenCachedStorage:
|
||||
assert info["stale"]["is_fresh"] is False
|
||||
assert info["stale"]["ttl_remaining_seconds"] == 0
|
||||
|
||||
def test_save_indexes_provider(self, cached_storage):
|
||||
"""Test save builds the provider index from _integration_type key."""
|
||||
cred = CredentialObject(
|
||||
id="aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1",
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"access_token": CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr("token-value"),
|
||||
),
|
||||
"_integration_type": CredentialKey(
|
||||
name="_integration_type",
|
||||
value=SecretStr("hubspot"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cached_storage.save(cred)
|
||||
|
||||
assert cached_storage._provider_index["hubspot"] == "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
|
||||
|
||||
def test_load_by_provider_name(self, cached_storage):
|
||||
"""Test load resolves provider name to hash-based credential ID."""
|
||||
hash_id = "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
|
||||
cred = CredentialObject(
|
||||
id=hash_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"access_token": CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr("hubspot-token"),
|
||||
),
|
||||
"_integration_type": CredentialKey(
|
||||
name="_integration_type",
|
||||
value=SecretStr("hubspot"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
# Save builds the index
|
||||
cached_storage.save(cred)
|
||||
|
||||
# Load by provider name should resolve to the hash ID
|
||||
loaded = cached_storage.load("hubspot")
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.id == hash_id
|
||||
assert loaded.keys["access_token"].value.get_secret_value() == "hubspot-token"
|
||||
|
||||
def test_load_by_direct_id_still_works(self, cached_storage):
|
||||
"""Test load by direct hash ID still works as before."""
|
||||
hash_id = "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
|
||||
cred = CredentialObject(
|
||||
id=hash_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"access_token": CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr("token"),
|
||||
),
|
||||
"_integration_type": CredentialKey(
|
||||
name="_integration_type",
|
||||
value=SecretStr("hubspot"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cached_storage.save(cred)
|
||||
|
||||
# Direct ID lookup should still work
|
||||
loaded = cached_storage.load(hash_id)
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.id == hash_id
|
||||
|
||||
def test_exists_by_provider_name(self, cached_storage):
|
||||
"""Test exists resolves provider name to hash-based credential ID."""
|
||||
hash_id = "c2xhY2s6dGVzdDo5OTk="
|
||||
cred = CredentialObject(
|
||||
id=hash_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"access_token": CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr("slack-token"),
|
||||
),
|
||||
"_integration_type": CredentialKey(
|
||||
name="_integration_type",
|
||||
value=SecretStr("slack"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cached_storage.save(cred)
|
||||
|
||||
assert cached_storage.exists("slack") is True
|
||||
assert cached_storage.exists(hash_id) is True
|
||||
assert cached_storage.exists("nonexistent") is False
|
||||
|
||||
def test_rebuild_provider_index(self, cached_storage, local_storage):
|
||||
"""Test rebuild_provider_index reconstructs from local storage."""
|
||||
# Manually save credentials to local storage (bypassing cached_storage.save)
|
||||
for provider_name, hash_id in [("hubspot", "hash_hub"), ("slack", "hash_slack")]:
|
||||
cred = CredentialObject(
|
||||
id=hash_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"_integration_type": CredentialKey(
|
||||
name="_integration_type",
|
||||
value=SecretStr(provider_name),
|
||||
),
|
||||
},
|
||||
)
|
||||
local_storage.save(cred)
|
||||
|
||||
# Index should be empty (we bypassed save)
|
||||
assert len(cached_storage._provider_index) == 0
|
||||
|
||||
# Rebuild
|
||||
indexed = cached_storage.rebuild_provider_index()
|
||||
|
||||
assert indexed == 2
|
||||
assert cached_storage._provider_index["hubspot"] == "hash_hub"
|
||||
assert cached_storage._provider_index["slack"] == "hash_slack"
|
||||
|
||||
def test_save_without_integration_type_no_index(self, cached_storage):
|
||||
"""Test save does not index credentials without _integration_type key."""
|
||||
cred = CredentialObject(
|
||||
id="plain-cred",
|
||||
credential_type=CredentialType.API_KEY,
|
||||
keys={
|
||||
"api_key": CredentialKey(
|
||||
name="api_key",
|
||||
value=SecretStr("key-value"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cached_storage.save(cred)
|
||||
|
||||
assert "plain-cred" not in cached_storage._provider_index
|
||||
assert len(cached_storage._provider_index) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests
|
||||
|
||||
@@ -8,7 +8,7 @@ containing one or more keys (e.g., api_key, access_token, refresh_token).
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
@@ -19,7 +19,7 @@ def _utc_now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class CredentialType(str, Enum):
|
||||
class CredentialType(StrEnum):
|
||||
"""Types of credentials the store can manage."""
|
||||
|
||||
API_KEY = "api_key"
|
||||
|
||||
@@ -11,11 +11,11 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class TokenPlacement(str, Enum):
|
||||
class TokenPlacement(StrEnum):
|
||||
"""Where to place the access token in HTTP requests."""
|
||||
|
||||
HEADER_BEARER = "header_bearer"
|
||||
|
||||
@@ -1,8 +1,22 @@
|
||||
"""Graph structures: Goals, Nodes, Edges, and Flexible Execution."""
|
||||
|
||||
from framework.graph.client_io import (
|
||||
ActiveNodeClientIO,
|
||||
ClientIOGateway,
|
||||
InertNodeClientIO,
|
||||
NodeClientIO,
|
||||
)
|
||||
from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec
|
||||
from framework.graph.context_handoff import ContextHandoff, HandoffContext
|
||||
from framework.graph.conversation import ConversationStore, Message, NodeConversation
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.event_loop_node import (
|
||||
EventLoopNode,
|
||||
JudgeProtocol,
|
||||
JudgeVerdict,
|
||||
LoopConfig,
|
||||
OutputAccumulator,
|
||||
)
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.flexible_executor import ExecutorConfig, FlexibleGraphExecutor
|
||||
from framework.graph.goal import Constraint, Goal, GoalStatus, SuccessCriterion
|
||||
@@ -77,4 +91,18 @@ __all__ = [
|
||||
"NodeConversation",
|
||||
"ConversationStore",
|
||||
"Message",
|
||||
# Event Loop
|
||||
"EventLoopNode",
|
||||
"LoopConfig",
|
||||
"OutputAccumulator",
|
||||
"JudgeProtocol",
|
||||
"JudgeVerdict",
|
||||
# Context Handoff
|
||||
"ContextHandoff",
|
||||
"HandoffContext",
|
||||
# Client I/O
|
||||
"NodeClientIO",
|
||||
"ActiveNodeClientIO",
|
||||
"InertNodeClientIO",
|
||||
"ClientIOGateway",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,170 @@
"""
Client I/O gateway for graph nodes.

Provides the bridge between node code and external clients:
- ActiveNodeClientIO: for client_facing=True nodes (streams output, accepts input)
- InertNodeClientIO: for client_facing=False nodes (logs internally, redirects input)
- ClientIOGateway: factory that creates the right variant per node
"""

from __future__ import annotations

import asyncio
import logging
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from framework.runtime.event_bus import EventBus

logger = logging.getLogger(__name__)


class NodeClientIO(ABC):
    """Abstract base for node client I/O."""

    @abstractmethod
    async def emit_output(self, content: str, is_final: bool = False) -> None:
        """Emit output content. If is_final=True, signal end of stream."""

    @abstractmethod
    async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
        """Request input. Behavior depends on whether the node is client-facing."""


class ActiveNodeClientIO(NodeClientIO):
    """
    Client I/O for client_facing=True nodes.

    - emit_output() queues content and publishes CLIENT_OUTPUT_DELTA.
    - request_input() publishes CLIENT_INPUT_REQUESTED, then awaits provide_input().
    - output_stream() yields queued content until the final sentinel.
    """

    def __init__(
        self,
        node_id: str,
        event_bus: EventBus | None = None,
    ) -> None:
        self.node_id = node_id
        self._event_bus = event_bus

        self._output_queue: asyncio.Queue[str | None] = asyncio.Queue()
        self._output_snapshot = ""

        self._input_event: asyncio.Event | None = None
        self._input_result: str | None = None

    async def emit_output(self, content: str, is_final: bool = False) -> None:
        self._output_snapshot += content
        await self._output_queue.put(content)

        if self._event_bus is not None:
            await self._event_bus.emit_client_output_delta(
                stream_id=self.node_id,
                node_id=self.node_id,
                content=content,
                snapshot=self._output_snapshot,
            )

        if is_final:
            await self._output_queue.put(None)

    async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
        if self._input_event is not None:
            raise RuntimeError("request_input already pending for this node")

        self._input_event = asyncio.Event()
        self._input_result = None

        if self._event_bus is not None:
            await self._event_bus.emit_client_input_requested(
                stream_id=self.node_id,
                node_id=self.node_id,
                prompt=prompt,
            )

        try:
            if timeout is not None:
                await asyncio.wait_for(self._input_event.wait(), timeout=timeout)
            else:
                await self._input_event.wait()
        finally:
            self._input_event = None

        if self._input_result is None:
            raise RuntimeError("input event was set but no input was provided")
        result = self._input_result
        self._input_result = None
        return result

    async def provide_input(self, content: str) -> None:
        """Called externally to fulfill a pending request_input()."""
        if self._input_event is None:
            raise RuntimeError("no pending request_input to fulfill")
        self._input_result = content
        self._input_event.set()

    async def output_stream(self) -> AsyncIterator[str]:
        """Async iterator that yields output chunks until the final sentinel."""
        while True:
            chunk = await self._output_queue.get()
            if chunk is None:
                break
            yield chunk


class InertNodeClientIO(NodeClientIO):
    """
    Client I/O for client_facing=False nodes.

    - emit_output() publishes NODE_INTERNAL_OUTPUT (content is not discarded).
    - request_input() publishes NODE_INPUT_BLOCKED and returns a redirect string.
    """

    def __init__(
        self,
        node_id: str,
        event_bus: EventBus | None = None,
    ) -> None:
        self.node_id = node_id
        self._event_bus = event_bus

    async def emit_output(self, content: str, is_final: bool = False) -> None:
        if self._event_bus is not None:
            await self._event_bus.emit_node_internal_output(
                stream_id=self.node_id,
                node_id=self.node_id,
                content=content,
            )

    async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
        if self._event_bus is not None:
            await self._event_bus.emit_node_input_blocked(
                stream_id=self.node_id,
                node_id=self.node_id,
                prompt=prompt,
            )
        return (
            "You are an internal processing node. There is no user to interact with."
            " Work with the data provided in your inputs to complete your task."
        )


class ClientIOGateway:
    """Factory that creates the appropriate NodeClientIO for a node."""

    def __init__(self, event_bus: EventBus | None = None) -> None:
        self._event_bus = event_bus

    def create_io(self, node_id: str, client_facing: bool) -> NodeClientIO:
        if client_facing:
            return ActiveNodeClientIO(
                node_id=node_id,
                event_bus=self._event_bus,
            )
        return InertNodeClientIO(
            node_id=node_id,
            event_bus=self._event_bus,
        )
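A minimal usage sketch of the gateway above, assuming an asyncio context and no event bus (`event_bus=None`, so all `emit_*` publishing is skipped); the node/runtime wiring here is hypothetical, not part of this changeset:

```python
import asyncio

async def demo() -> None:
    gateway = ClientIOGateway(event_bus=None)
    io = gateway.create_io(node_id="chat", client_facing=True)  # ActiveNodeClientIO

    async def node_work() -> None:
        await io.emit_output("Working... ")
        answer = await io.request_input(prompt="Continue? (y/n)")
        await io.emit_output(f"Got: {answer}", is_final=True)  # pushes the None sentinel

    async def client_side() -> None:
        await asyncio.sleep(0.1)      # let request_input() register first
        await io.provide_input("y")   # fulfills the pending request

    work = asyncio.gather(node_work(), client_side())
    async for chunk in io.output_stream():  # drains until the sentinel
        print(chunk, end="")
    await work

asyncio.run(demo())
```

An `InertNodeClientIO` created with `client_facing=False` would run the same node code unchanged, but `request_input()` would immediately return the redirect string instead of blocking.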
@@ -0,0 +1,191 @@
"""Context handoff: summarize a completed NodeConversation for the next graph node."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from framework.graph.conversation import _try_extract_key

if TYPE_CHECKING:
    from framework.graph.conversation import NodeConversation
    from framework.llm.provider import LLMProvider

logger = logging.getLogger(__name__)

_TRUNCATE_CHARS = 500


# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------


@dataclass
class HandoffContext:
    """Structured summary of a completed node conversation."""

    source_node_id: str
    summary: str
    key_outputs: dict[str, Any]
    turn_count: int
    total_tokens_used: int


# ---------------------------------------------------------------------------
# ContextHandoff
# ---------------------------------------------------------------------------


class ContextHandoff:
    """Summarize a completed NodeConversation into a HandoffContext.

    Parameters
    ----------
    llm : LLMProvider | None
        Optional LLM provider for abstractive summarization.
        When *None*, all summarization uses the extractive fallback.
    """

    def __init__(self, llm: LLMProvider | None = None) -> None:
        self.llm = llm

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def summarize_conversation(
        self,
        conversation: NodeConversation,
        node_id: str,
        output_keys: list[str] | None = None,
    ) -> HandoffContext:
        """Produce a HandoffContext from *conversation*.

        1. Extracts turn_count & total_tokens_used (sync properties).
        2. Extracts key_outputs by scanning assistant messages most-recent-first.
        3. Builds a summary via the LLM (if available) or extractive fallback.
        """
        turn_count = conversation.turn_count
        total_tokens_used = conversation.estimate_tokens()
        messages = conversation.messages  # defensive copy

        # --- key outputs ---------------------------------------------------
        key_outputs: dict[str, Any] = {}
        if output_keys:
            remaining = set(output_keys)
            for msg in reversed(messages):
                if msg.role != "assistant" or not remaining:
                    continue
                for key in list(remaining):
                    value = _try_extract_key(msg.content, key)
                    if value is not None:
                        key_outputs[key] = value
                        remaining.discard(key)

        # --- summary -------------------------------------------------------
        if self.llm is not None:
            try:
                summary = self._llm_summary(messages, output_keys or [])
            except Exception:
                logger.warning(
                    "LLM summarization failed; falling back to extractive.",
                    exc_info=True,
                )
                summary = self._extractive_summary(messages)
        else:
            summary = self._extractive_summary(messages)

        return HandoffContext(
            source_node_id=node_id,
            summary=summary,
            key_outputs=key_outputs,
            turn_count=turn_count,
            total_tokens_used=total_tokens_used,
        )

    @staticmethod
    def format_as_input(handoff: HandoffContext) -> str:
        """Render *handoff* as structured plain text for the next node's input."""
        header = (
            f"--- CONTEXT FROM: {handoff.source_node_id} "
            f"({handoff.turn_count} turns, ~{handoff.total_tokens_used} tokens) ---"
        )

        sections: list[str] = [header, ""]

        if handoff.key_outputs:
            sections.append("KEY OUTPUTS:")
            for k, v in handoff.key_outputs.items():
                sections.append(f"- {k}: {v}")
            sections.append("")

        summary_text = handoff.summary or "No summary available."
        sections.append("SUMMARY:")
        sections.append(summary_text)
        sections.append("")
        sections.append("--- END CONTEXT ---")

        return "\n".join(sections)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _extractive_summary(messages: list) -> str:
        """Build a summary from key assistant messages without an LLM.

        Strategy:
        - Include the first assistant message (initial assessment).
        - Include the last assistant message (final conclusion).
        - Truncate each to ~500 chars.
        """
        if not messages:
            return "Empty conversation."

        assistant_msgs = [m for m in messages if m.role == "assistant"]
        if not assistant_msgs:
            return "No assistant responses."

        parts: list[str] = []

        first = assistant_msgs[0].content
        parts.append(first[:_TRUNCATE_CHARS])

        if len(assistant_msgs) > 1:
            last = assistant_msgs[-1].content
            parts.append(last[:_TRUNCATE_CHARS])

        return "\n\n".join(parts)

    def _llm_summary(self, messages: list, output_keys: list[str]) -> str:
        """Produce a summary by calling the LLM provider."""
        if self.llm is None:
            raise ValueError("_llm_summary called without an LLM provider")

        conversation_text = "\n".join(f"[{m.role}]: {m.content}" for m in messages)

        key_hint = ""
        if output_keys:
            key_hint = (
                "\nThe following output keys are especially important: "
                + ", ".join(output_keys)
                + ".\n"
            )

        system_prompt = (
            "You are a concise summarizer. Given the conversation below, "
            "produce a brief summary (at most ~500 tokens) that captures the "
            "key decisions, findings, and outcomes. Focus on what was concluded "
            "rather than the back-and-forth process." + key_hint
        )

        response = self.llm.complete(
            messages=[{"role": "user", "content": conversation_text}],
            system=system_prompt,
            max_tokens=500,
        )

        return response.content.strip()
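A sketch of the extractive path (no LLM), assuming the framework package is importable; the `SimpleNamespace` objects stand in for a real `NodeConversation` and its messages, exposing only the attributes `summarize_conversation()` reads:

```python
from types import SimpleNamespace

msgs = [
    SimpleNamespace(role="user", content="Who owns the repo?"),
    SimpleNamespace(role="assistant", content='{"owner": "octocat"}'),
]
conv = SimpleNamespace(turn_count=1, estimate_tokens=lambda: 42, messages=msgs)

handoff = ContextHandoff(llm=None).summarize_conversation(
    conv, node_id="research", output_keys=["owner"]
)
print(ContextHandoff.format_as_input(handoff))
# --- CONTEXT FROM: research (1 turns, ~42 tokens) ---
#
# KEY OUTPUTS:
# - owner: octocat
#
# SUMMARY:
# {"owner": "octocat"}
#
# --- END CONTEXT ---
```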
@@ -75,6 +75,16 @@ class Message:
        )


def _extract_spillover_filename(content: str) -> str | None:
    """Extract spillover filename from a truncated tool result.

    Matches the pattern produced by EventLoopNode._truncate_tool_result():
    "saved to 'tool_github_list_stargazers_abc123.txt'"
    """
    match = re.search(r"saved to '([^']+)'", content)
    return match.group(1) if match else None


# ---------------------------------------------------------------------------
# ConversationStore protocol (Phase 2)
# ---------------------------------------------------------------------------
@@ -108,6 +118,50 @@ class ConversationStore(Protocol):
# ---------------------------------------------------------------------------


def _try_extract_key(content: str, key: str) -> str | None:
    """Try 4 strategies to extract a *key*'s value from message content.

    Strategies (in order):
    1. Whole message is JSON — ``json.loads``, check for key.
    2. Embedded JSON via ``find_json_object`` helper.
    3. Colon format: ``key: value``.
    4. Equals format: ``key = value``.
    """
    from framework.graph.node import find_json_object

    # 1. Whole message is JSON
    try:
        parsed = json.loads(content)
        if isinstance(parsed, dict) and key in parsed:
            val = parsed[key]
            return json.dumps(val) if not isinstance(val, str) else val
    except (json.JSONDecodeError, TypeError):
        pass

    # 2. Embedded JSON via find_json_object
    json_str = find_json_object(content)
    if json_str:
        try:
            parsed = json.loads(json_str)
            if isinstance(parsed, dict) and key in parsed:
                val = parsed[key]
                return json.dumps(val) if not isinstance(val, str) else val
        except (json.JSONDecodeError, TypeError):
            pass

    # 3. Colon format: key: value
    match = re.search(rf"\b{re.escape(key)}\s*:\s*(.+)", content)
    if match:
        return match.group(1).strip()

    # 4. Equals format: key = value
    match = re.search(rf"\b{re.escape(key)}\s*=\s*(.+)", content)
    if match:
        return match.group(1).strip()

    return None

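Each extraction strategy in one line apiece (values come back as strings; non-string JSON values are `json.dumps`-encoded):

```python
_try_extract_key('{"score": 7}', "score")        # 1. whole message is JSON -> "7"
_try_extract_key('Done. {"score": 7}', "score")  # 2. embedded JSON object  -> "7"
_try_extract_key("score: 7", "score")            # 3. colon format          -> "7"
_try_extract_key("score = 7", "score")           # 4. equals format         -> "7"
_try_extract_key("no match here", "score")       # nothing found            -> None
```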
class NodeConversation:
    """Message history for a graph node with optional write-through persistence.

@@ -133,6 +187,7 @@ class NodeConversation:
        self._messages: list[Message] = []
        self._next_seq: int = 0
        self._meta_persisted: bool = False
        self._last_api_input_tokens: int | None = None

    # --- Properties --------------------------------------------------------

@@ -205,14 +260,78 @@ class NodeConversation:
    # --- Query -------------------------------------------------------------

    def to_llm_messages(self) -> list[dict[str, Any]]:
        """Return messages as OpenAI-format dicts (system prompt excluded)."""
        return [m.to_llm_dict() for m in self._messages]
        """Return messages as OpenAI-format dicts (system prompt excluded).

        Automatically repairs orphaned tool_use blocks (assistant messages
        with tool_calls that lack corresponding tool-result messages). This
        can happen when a loop is cancelled mid-tool-execution.
        """
        msgs = [m.to_llm_dict() for m in self._messages]
        return self._repair_orphaned_tool_calls(msgs)

    @staticmethod
    def _repair_orphaned_tool_calls(
        msgs: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Ensure every tool_call has a matching tool-result message."""
        repaired: list[dict[str, Any]] = []
        for i, m in enumerate(msgs):
            repaired.append(m)
            tool_calls = m.get("tool_calls")
            if m.get("role") != "assistant" or not tool_calls:
                continue
            # Collect IDs of tool results that follow this assistant message
            answered: set[str] = set()
            for j in range(i + 1, len(msgs)):
                if msgs[j].get("role") == "tool":
                    tid = msgs[j].get("tool_call_id")
                    if tid:
                        answered.add(tid)
                else:
                    break  # stop at first non-tool message
            # Patch any missing results
            for tc in tool_calls:
                tc_id = tc.get("id")
                if tc_id and tc_id not in answered:
                    repaired.append(
                        {
                            "role": "tool",
                            "tool_call_id": tc_id,
                            "content": "ERROR: Tool execution was interrupted.",
                        }
                    )
        return repaired

    def estimate_tokens(self) -> int:
        """Rough token estimate: total characters / 4."""
        """Best available token estimate.

        Uses actual API input token count when available (set via
        :meth:`update_token_count`), otherwise falls back to the rough
        ``total_chars / 4`` heuristic.
        """
        if self._last_api_input_tokens is not None:
            return self._last_api_input_tokens
        total_chars = sum(len(m.content) for m in self._messages)
        return total_chars // 4

    def update_token_count(self, actual_input_tokens: int) -> None:
        """Store actual API input token count for more accurate compaction.

        Called by EventLoopNode after each LLM call with the ``input_tokens``
        value from the API response. This value includes system prompt and
        tool definitions, so it may be higher than a message-only estimate.
        """
        self._last_api_input_tokens = actual_input_tokens

    def usage_ratio(self) -> float:
        """Current token usage as a fraction of *max_history_tokens*.

        Returns 0.0 when ``max_history_tokens`` is zero (unlimited).
        """
        if self._max_history_tokens <= 0:
            return 0.0
        return self.estimate_tokens() / self._max_history_tokens

    def needs_compaction(self) -> bool:
        return self.estimate_tokens() >= self._max_history_tokens * self._compaction_threshold
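How the three methods above interact, as a hypothetical flow for a conversation built with `max_history_tokens=100_000` and a 0.8 compaction threshold (both constructor parameters assumed from context):

```python
conv.estimate_tokens()           # chars // 4 heuristic until the API reports usage
conv.update_token_count(84_000)  # input_tokens from the last LLM response
conv.estimate_tokens()           # now returns 84_000 exactly
conv.usage_ratio()               # 0.84
conv.needs_compaction()          # True: 84_000 >= 100_000 * 0.8
```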
@@ -244,42 +363,89 @@ class NodeConversation:

    def _try_extract_key(self, content: str, key: str) -> str | None:
        """Try 4 strategies to extract a key's value from message content."""
        from framework.graph.node import find_json_object

        # 1. Whole message is JSON
        try:
            parsed = json.loads(content)
            if isinstance(parsed, dict) and key in parsed:
                val = parsed[key]
                return json.dumps(val) if not isinstance(val, str) else val
        except (json.JSONDecodeError, TypeError):
            pass

        # 2. Embedded JSON via find_json_object
        json_str = find_json_object(content)
        if json_str:
            try:
                parsed = json.loads(json_str)
                if isinstance(parsed, dict) and key in parsed:
                    val = parsed[key]
                    return json.dumps(val) if not isinstance(val, str) else val
            except (json.JSONDecodeError, TypeError):
                pass

        # 3. Colon format: key: value
        match = re.search(rf"\b{re.escape(key)}\s*:\s*(.+)", content)
        if match:
            return match.group(1).strip()

        # 4. Equals format: key = value
        match = re.search(rf"\b{re.escape(key)}\s*=\s*(.+)", content)
        if match:
            return match.group(1).strip()

        return None
        return _try_extract_key(content, key)

    # --- Lifecycle ---------------------------------------------------------

    async def prune_old_tool_results(
        self,
        protect_tokens: int = 5000,
        min_prune_tokens: int = 2000,
    ) -> int:
        """Replace old tool result content with compact placeholders.

        Walks backward through messages. Recent tool results (within
        *protect_tokens*) are kept intact. Older tool results have their
        content replaced with a ~100-char placeholder that preserves the
        spillover filename reference (if any). Message structure (role,
        seq, tool_use_id) stays valid for the LLM API.

        Error tool results are never pruned — they prevent re-calling
        failing tools.

        Returns the number of messages pruned (0 if nothing was pruned).
        """
        if not self._messages:
            return 0

        # Phase 1: Walk backward, classify tool results as protected vs pruneable
        protected_tokens = 0
        pruneable: list[int] = []  # indices into self._messages
        pruneable_tokens = 0

        for i in range(len(self._messages) - 1, -1, -1):
            msg = self._messages[i]
            if msg.role != "tool":
                continue
            if msg.is_error:
                continue  # never prune errors
            if msg.content.startswith("[Pruned tool result"):
                continue  # already pruned

            est = len(msg.content) // 4
            if protected_tokens < protect_tokens:
                protected_tokens += est
            else:
                pruneable.append(i)
                pruneable_tokens += est

        # Phase 2: Only prune if enough to be worthwhile
        if pruneable_tokens < min_prune_tokens:
            return 0

        # Phase 3: Replace content with compact placeholder
        count = 0
        for i in pruneable:
            msg = self._messages[i]
            orig_len = len(msg.content)
            spillover = _extract_spillover_filename(msg.content)

            if spillover:
                placeholder = (
                    f"[Pruned tool result: {orig_len} chars. "
                    f"Full data in '{spillover}'. "
                    f"Use load_data('{spillover}') to retrieve.]"
                )
            else:
                placeholder = f"[Pruned tool result: {orig_len} chars cleared from context.]"

            self._messages[i] = Message(
                seq=msg.seq,
                role=msg.role,
                content=placeholder,
                tool_use_id=msg.tool_use_id,
                tool_calls=msg.tool_calls,
                is_error=msg.is_error,
            )
            count += 1

            if self._store:
                await self._store.write_part(msg.seq, self._messages[i].to_storage_dict())

        # Reset token estimate — content lengths changed
        self._last_api_input_tokens = None
        return count

    async def compact(self, summary: str, keep_recent: int = 2) -> None:
        """Replace old messages with a summary, optionally keeping recent ones.

@@ -294,12 +460,18 @@ class NodeConversation:
        # Clamp: must discard at least 1 message
        keep_recent = max(0, min(keep_recent, len(self._messages) - 1))

        if keep_recent > 0:
            old_messages = self._messages[:-keep_recent]
            recent_messages = self._messages[-keep_recent:]
        else:
            old_messages = self._messages
            recent_messages = []
        total = len(self._messages)
        split = total - keep_recent if keep_recent > 0 else total

        # Advance split past orphaned tool results at the boundary.
        # Tool-role messages reference a tool_use from the preceding
        # assistant message; if that assistant message falls into the
        # compacted (old) portion the tool_result becomes invalid.
        while split < total and self._messages[split].role == "tool":
            split += 1

        old_messages = list(self._messages[:split])
        recent_messages = list(self._messages[split:])

        # Extract protected values from messages being discarded
        if self._output_keys:
@@ -330,6 +502,7 @@ class NodeConversation:
            await self._store.write_cursor({"next_seq": self._next_seq})

        self._messages = [summary_msg] + recent_messages
        self._last_api_input_tokens = None  # reset; next LLM call will recalibrate

    async def clear(self) -> None:
        """Remove all messages, keep system prompt, preserve ``_next_seq``."""
@@ -337,6 +510,7 @@ class NodeConversation:
            await self._store.delete_parts_before(self._next_seq)
            await self._store.write_cursor({"next_seq": self._next_seq})
        self._messages.clear()
        self._last_api_input_tokens = None

    def export_summary(self) -> str:
        """Structured summary with [STATS], [CONFIG], [RECENT_MESSAGES] sections."""
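What the orphan repair in `to_llm_messages()` does to a history that was cancelled mid-tool-execution, as a self-contained sketch:

```python
msgs = [
    {"role": "assistant", "tool_calls": [{"id": "call_1"}]},
    # cancelled here: no {"role": "tool", "tool_call_id": "call_1"} follows
]
repaired = NodeConversation._repair_orphaned_tool_calls(msgs)
assert repaired[-1] == {
    "role": "tool",
    "tool_call_id": "call_1",
    "content": "ERROR: Tool execution was interrupted.",
}
```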
@@ -21,7 +21,7 @@ allowing the LLM to evaluate whether proceeding along an edge makes sense
given the current goal, context, and execution state.
"""

from enum import Enum
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field
@@ -29,7 +29,7 @@ from pydantic import BaseModel, Field
from framework.graph.safe_eval import safe_eval


class EdgeCondition(str, Enum):
class EdgeCondition(StrEnum):
    """When an edge should be traversed."""

    ALWAYS = "always"  # Always after source completes
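The `(str, Enum)` to `StrEnum` migration (here and in the goal, HITL, and plan modules below) requires Python 3.11+; it keeps string comparisons working while fixing `str()`:

```python
from enum import StrEnum

class EdgeCondition(StrEnum):
    ALWAYS = "always"

assert EdgeCondition.ALWAYS == "always"       # same as with (str, Enum)
assert str(EdgeCondition.ALWAYS) == "always"  # (str, Enum) would give
                                              # "EdgeCondition.ALWAYS" here
```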
File diff suppressed because it is too large
@@ -11,11 +11,12 @@ The executor:

import asyncio
import logging
import warnings
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any

from framework.graph.edge import EdgeSpec, GraphSpec
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.goal import Goal
from framework.graph.node import (
    FunctionNode,
@@ -54,6 +55,9 @@ class ExecutionResult:
    had_partial_failures: bool = False  # True if any node failed but eventually succeeded
    execution_quality: str = "clean"  # "clean", "degraded", or "failed"

    # Visit tracking (for feedback/callback edges)
    node_visit_counts: dict[str, int] = field(default_factory=dict)  # {node_id: visit_count}

    @property
    def is_clean_success(self) -> bool:
        """True only if execution succeeded with no retries or failures."""
@@ -250,6 +254,8 @@ class GraphExecutor:
        total_tokens = 0
        total_latency = 0
        node_retry_counts: dict[str, int] = {}  # Track retries per node
        node_visit_counts: dict[str, int] = {}  # Track visits for feedback loops
        _is_retry = False  # True when looping back for a retry (not a new visit)

        # Determine entry point (may differ if resuming)
        current_node_id = graph.get_entry_point(session_state)
@@ -278,6 +284,34 @@ class GraphExecutor:
            if node_spec is None:
                raise RuntimeError(f"Node not found: {current_node_id}")

            # Enforce max_node_visits (feedback/callback edge support)
            # Don't increment visit count on retries — retries are not new visits
            if not _is_retry:
                cnt = node_visit_counts.get(current_node_id, 0) + 1
                node_visit_counts[current_node_id] = cnt
            _is_retry = False
            max_visits = getattr(node_spec, "max_node_visits", 1)
            if max_visits > 0 and node_visit_counts[current_node_id] > max_visits:
                self.logger.warning(
                    f" ⊘ Node '{node_spec.name}' visit limit reached "
                    f"({node_visit_counts[current_node_id]}/{max_visits}), skipping"
                )
                # Skip execution — follow outgoing edges using current memory
                skip_result = NodeResult(success=True, output=memory.read_all())
                next_node = self._follow_edges(
                    graph=graph,
                    goal=goal,
                    current_node_id=current_node_id,
                    current_node_spec=node_spec,
                    result=skip_result,
                    memory=memory,
                )
                if next_node is None:
                    self.logger.info(" → No more edges after visit limit, ending")
                    break
                current_node_id = next_node
                continue

            path.append(current_node_id)

            # Check if pause (HITL) before execution
@@ -380,6 +414,15 @@ class GraphExecutor:
                # [CORRECTED] Use node_spec.max_retries instead of hardcoded 3
                max_retries = getattr(node_spec, "max_retries", 3)

                # Event loop nodes handle retry internally via judge —
                # executor retry is catastrophic (retry multiplication)
                if node_spec.node_type == "event_loop" and max_retries > 0:
                    self.logger.warning(
                        f"EventLoopNode '{node_spec.id}' has max_retries={max_retries}. "
                        "Overriding to 0 — event loop nodes handle retry internally via judge."
                    )
                    max_retries = 0

                if node_retry_counts[current_node_id] < max_retries:
                    # Retry - don't increment steps for retries
                    steps -= 1
@@ -395,6 +438,7 @@ class GraphExecutor:
                    self.logger.info(
                        f" ↻ Retrying ({node_retry_counts[current_node_id]}/{max_retries})..."
                    )
                    _is_retry = True
                    continue
                else:
                    # Max retries exceeded - fail the execution
@@ -437,6 +481,7 @@ class GraphExecutor:
                    retry_details=dict(node_retry_counts),
                    had_partial_failures=len(nodes_failed) > 0,
                    execution_quality="failed",
                    node_visit_counts=dict(node_visit_counts),
                )

            # Check if we just executed a pause node - if so, save state and return
@@ -476,6 +521,7 @@ class GraphExecutor:
                    retry_details=dict(node_retry_counts),
                    had_partial_failures=len(nodes_failed) > 0,
                    execution_quality=exec_quality,
                    node_visit_counts=dict(node_visit_counts),
                )

            # Check if this is a terminal node - if so, we're done
@@ -596,6 +642,7 @@ class GraphExecutor:
                retry_details=dict(node_retry_counts),
                had_partial_failures=len(nodes_failed) > 0,
                execution_quality=exec_quality,
                node_visit_counts=dict(node_visit_counts),
            )

        except Exception as e:
@@ -622,6 +669,7 @@ class GraphExecutor:
                retry_details=dict(node_retry_counts),
                had_partial_failures=len(nodes_failed) > 0,
                execution_quality="failed",
                node_visit_counts=dict(node_visit_counts),
            )

    def _build_context(
@@ -658,7 +706,15 @@ class GraphExecutor:
        )

    # Valid node types - no ambiguous "llm" type allowed
    VALID_NODE_TYPES = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
    VALID_NODE_TYPES = {
        "llm_tool_use",
        "llm_generate",
        "router",
        "function",
        "human_input",
        "event_loop",
    }
    DEPRECATED_NODE_TYPES = {"llm_tool_use": "event_loop", "llm_generate": "event_loop"}

    def _get_node_implementation(
        self, node_spec: NodeSpec, cleanup_llm_model: str | None = None
@@ -676,6 +732,17 @@ class GraphExecutor:
                f"Use 'llm_tool_use' for nodes that call tools, 'llm_generate' for text generation."
            )

        # Warn on deprecated node types
        if node_spec.node_type in self.DEPRECATED_NODE_TYPES:
            replacement = self.DEPRECATED_NODE_TYPES[node_spec.node_type]
            warnings.warn(
                f"Node type '{node_spec.node_type}' is deprecated. "
                f"Use '{replacement}' instead. "
                f"Node: '{node_spec.id}'",
                DeprecationWarning,
                stacklevel=2,
            )

        # Create based on type
        if node_spec.node_type == "llm_tool_use":
            if not node_spec.tools:
@@ -713,6 +780,13 @@ class GraphExecutor:
                cleanup_llm_model=cleanup_llm_model,
            )

        if node_spec.node_type == "event_loop":
            # Event loop nodes must be pre-registered (like function nodes)
            raise RuntimeError(
                f"EventLoopNode '{node_spec.id}' not found in registry. "
                "Register it with executor.register_node() before execution."
            )

        # Should never reach here due to validation above
        raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")

@@ -823,6 +897,21 @@ class GraphExecutor:
            ):
                traversable.append(edge)

        # Priority filtering for CONDITIONAL edges:
        # When multiple CONDITIONAL edges match, keep only the highest-priority
        # group. This prevents mutually-exclusive conditional branches (e.g.
        # forward vs. feedback) from incorrectly triggering fan-out.
        # ON_SUCCESS / other edge types are unaffected.
        if len(traversable) > 1:
            conditionals = [e for e in traversable if e.condition == EdgeCondition.CONDITIONAL]
            if len(conditionals) > 1:
                max_prio = max(e.priority for e in conditionals)
                traversable = [
                    e
                    for e in traversable
                    if e.condition != EdgeCondition.CONDITIONAL or e.priority == max_prio
                ]

        return traversable
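The filtering rule above, traced on three hypothetical edges (the `"on_success"` condition value is an assumed sibling of `CONDITIONAL`, not shown in this hunk):

```python
from types import SimpleNamespace as Edge

fwd  = Edge(condition=EdgeCondition.CONDITIONAL, priority=1)  # forward branch
back = Edge(condition=EdgeCondition.CONDITIONAL, priority=0)  # feedback branch
ok   = Edge(condition="on_success", priority=0)               # unaffected

traversable = [fwd, back, ok]
# Two CONDITIONAL edges matched, so only the highest-priority group
# (priority=1) survives; the non-CONDITIONAL edge passes through:
# result after filtering -> [fwd, ok]
```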
    def _find_convergence_node(
@@ -909,6 +998,17 @@ class GraphExecutor:
            branch.status = "failed"
            branch.error = f"Node {branch.node_id} not found in graph"
            return branch, RuntimeError(branch.error)

        effective_max_retries = node_spec.max_retries
        if node_spec.node_type == "event_loop":
            if effective_max_retries > 1:
                self.logger.warning(
                    f"EventLoopNode '{node_spec.id}' has "
                    f"max_retries={effective_max_retries}. Overriding "
                    "to 1 — event loop nodes handle retry internally."
                )
            effective_max_retries = 1

        branch.status = "running"

        try:
@@ -942,7 +1042,7 @@ class GraphExecutor:

            # Execute with retries
            last_result = None
            for attempt in range(node_spec.max_retries):
            for attempt in range(effective_max_retries):
                branch.retry_count = attempt

                # Build context for this branch
@@ -970,7 +1070,7 @@ class GraphExecutor:

                    self.logger.warning(
                        f" ↻ Branch {node_spec.name}: "
                        f"retry {attempt + 1}/{node_spec.max_retries}"
                        f"retry {attempt + 1}/{effective_max_retries}"
                    )

            # All retries exhausted
@@ -979,7 +1079,7 @@ class GraphExecutor:
            branch.result = last_result
            self.logger.error(
                f" ✗ Branch {node_spec.name}: "
                f"failed after {node_spec.max_retries} attempts"
                f"failed after {effective_max_retries} attempts"
            )
            return branch, last_result

@@ -12,13 +12,13 @@ Goals are:
"""

from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field


class GoalStatus(str, Enum):
class GoalStatus(StrEnum):
    """Lifecycle status of a goal."""

    DRAFT = "draft"  # Being defined
@@ -6,11 +6,11 @@ where agents need to gather input from humans.
"""

from dataclasses import dataclass, field
from enum import Enum
from enum import StrEnum
from typing import Any


class HITLInputType(str, Enum):
class HITLInputType(StrEnum):
    """Type of input expected from human."""

    FREE_TEXT = "free_text"  # Open-ended text response
@@ -16,10 +16,12 @@ Protocol:
"""

import asyncio
import inspect
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import UTC
from typing import Any

from pydantic import BaseModel, Field
@@ -153,7 +155,10 @@ class NodeSpec(BaseModel):
    # Node behavior type
    node_type: str = Field(
        default="llm_tool_use",
        description="Type: 'llm_tool_use', 'llm_generate', 'function', 'router', 'human_input'",
        description=(
            "Type: 'event_loop', 'function', 'router', 'human_input'. "
            "Deprecated: 'llm_tool_use', 'llm_generate' (use 'event_loop' instead)."
        ),
    )

    # Data flow
@@ -205,6 +210,15 @@ class NodeSpec(BaseModel):
    max_retries: int = Field(default=3)
    retry_on: list[str] = Field(default_factory=list, description="Error types to retry on")

    # Visit limits (for feedback/callback edges)
    max_node_visits: int = Field(
        default=1,
        description=(
            "Max times this node executes in one graph run. "
            "Set >1 for feedback loops. 0 = unlimited (max_steps guards)."
        ),
    )

    # Pydantic model for output validation
    output_model: type[BaseModel] | None = Field(
        default=None,
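A hypothetical feedback loop using the new field (constructor arguments beyond those visible in this diff are assumed):

```python
writer = NodeSpec(
    id="writer",
    name="Writer",
    node_type="event_loop",
    max_node_visits=3,  # allow the reviewer to send work back twice
)
# With a CONDITIONAL feedback edge reviewer -> writer, the executor counts
# each non-retry entry into "writer"; on the 4th visit it skips the node
# and follows its outgoing edges instead of executing it again.
```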
@@ -218,6 +232,12 @@ class NodeSpec(BaseModel):
        description="Maximum retries when Pydantic validation fails (with feedback to LLM)",
    )

    # Client-facing behavior
    client_facing: bool = Field(
        default=False,
        description="If True, this node streams output to the end user and can request input.",
    )

    model_config = {"extra": "allow", "arbitrary_types_allowed": True}


@@ -1348,7 +1368,9 @@ Expected output keys: {output_keys}
LLM Response:
{raw_response}

Output ONLY the JSON object, nothing else."""
Output ONLY the JSON object, nothing else.
If no valid JSON object exists in the response, output exactly: {{"error": "NO_JSON_FOUND"}}
Do NOT fabricate data or return empty objects."""

        try:
            result = cleaner_llm.complete(
@@ -1395,6 +1417,14 @@ Output ONLY the JSON object, nothing else."""
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            parsed = json.loads(_fix_unescaped_newlines_in_json(cleaned))

        # Validate LLM didn't return empty or fabricated data
        if parsed.get("error") == "NO_JSON_FOUND":
            raise ValueError("Cannot parse JSON from response")
        if not parsed or parsed == {}:
            raise ValueError("Cannot parse JSON from response")
        if all(v is None for v in parsed.values()):
            raise ValueError("Cannot parse JSON from response")
        logger.info(" ✓ LLM cleaned JSON output")
        return parsed
@@ -1504,6 +1534,8 @@ Output ONLY the JSON object, nothing else."""

    def _build_system_prompt(self, ctx: NodeContext) -> str:
        """Build the system prompt."""
        from datetime import datetime

        parts = []

        if ctx.node_spec.system_prompt:
@@ -1526,6 +1558,15 @@ Output ONLY the JSON object, nothing else."""

            parts.append(prompt)

        # Inject current datetime so LLM knows "now"
        utc_dt = datetime.now(UTC)
        local_dt = datetime.now().astimezone()
        local_tz_name = local_dt.tzname() or "Unknown"
        parts.append("\n## Runtime Context")
        parts.append(f"- Current Date/Time (UTC): {utc_dt.isoformat()}")
        parts.append(f"- Local Timezone: {local_tz_name}")
        parts.append(f"- Current Date/Time (Local): {local_dt.isoformat()}")

        if ctx.goal_context:
            parts.append("\n# Goal Context")
            parts.append(ctx.goal_context)
@@ -1727,8 +1768,19 @@ class FunctionNode(NodeProtocol):
        start = time.time()

        try:
            # Call the function
            result = self.func(**ctx.input_data)
            # Filter input_data to only declared input_keys to prevent
            # leaking extra memory keys from upstream nodes.
            if ctx.node_spec.input_keys:
                filtered = {
                    k: v for k, v in ctx.input_data.items() if k in ctx.node_spec.input_keys
                }
            else:
                filtered = ctx.input_data

            # Call the function (supports both sync and async)
            result = self.func(**filtered)
            if inspect.isawaitable(result):
                result = await result

            latency_ms = int((time.time() - start) * 1000)

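The effect of the filtering and async support added above, sketched for a function node whose spec declares `input_keys=["a", "b"]`:

```python
def add(a: int, b: int) -> int:
    return a + b

# Upstream memory {"a": 1, "b": 2, "note": "extra"} previously produced
# add(a=1, b=2, note="extra") -> TypeError; with filtering it is add(a=1, b=2).

async def add_async(a: int, b: int) -> int:  # async functions now work too:
    return a + b                             # the awaitable result is awaited
```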
@@ -11,13 +11,13 @@ The Plan is the contract between the external planner and the executor:
"""

from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field


class ActionType(str, Enum):
class ActionType(StrEnum):
    """Types of actions a PlanStep can perform."""

    LLM_CALL = "llm_call"  # Call LLM for generation
@@ -27,7 +27,7 @@ class ActionType(str, Enum):
    CODE_EXECUTION = "code_execution"  # Execute dynamic code (sandboxed)


class StepStatus(str, Enum):
class StepStatus(StrEnum):
    """Status of a plan step."""

    PENDING = "pending"
@@ -56,7 +56,7 @@ class StepStatus(str, Enum):
        return self == StepStatus.COMPLETED


class ApprovalDecision(str, Enum):
class ApprovalDecision(StrEnum):
    """Human decision on a step requiring approval."""

    APPROVE = "approve"  # Execute as planned
@@ -91,7 +91,7 @@ class ApprovalResult(BaseModel):
    model_config = {"extra": "allow"}


class JudgmentAction(str, Enum):
class JudgmentAction(StrEnum):
    """Actions the judge can take after evaluating a step."""

    ACCEPT = "accept"  # Step completed successfully, continue
@@ -423,7 +423,7 @@ class Plan(BaseModel):
    }


class ExecutionStatus(str, Enum):
class ExecutionStatus(StrEnum):
    """Status of plan execution."""

    COMPLETED = "completed"
@@ -75,16 +75,6 @@ class SafeEvalVisitor(ast.NodeVisitor):
    def visit_Constant(self, node: ast.Constant) -> Any:
        return node.value

    # --- Number/String/Bytes/NameConstant (Python < 3.8 compat if needed) ---
    def visit_Num(self, node: ast.Num) -> Any:
        return node.n

    def visit_Str(self, node: ast.Str) -> Any:
        return node.s

    def visit_NameConstant(self, node: ast.NameConstant) -> Any:
        return node.value

    # --- Data Structures ---
    def visit_List(self, node: ast.List) -> list:
        return [self.visit(elt) for elt in node.elts]
@@ -126,14 +126,16 @@ class OutputValidator:

        for key in expected_keys:
            if key not in output:
                errors.append(f"Missing required output key: '{key}'")
                if key not in nullable_keys:
                    errors.append(f"Missing required output key: '{key}'")
            elif not allow_empty:
                value = output[key]
                if value is None:
                    if key not in nullable_keys:
                        errors.append(f"Output key '{key}' is None")
                elif isinstance(value, str) and len(value.strip()) == 0:
                    errors.append(f"Output key '{key}' is empty string")
                    if key not in nullable_keys:
                        errors.append(f"Output key '{key}' is empty string")

        return ValidationResult(success=len(errors) == 0, errors=errors)
@@ -1,8 +1,31 @@
"""LLM provider abstraction."""

from framework.llm.provider import LLMProvider, LLMResponse
from framework.llm.stream_events import (
    FinishEvent,
    ReasoningDeltaEvent,
    ReasoningStartEvent,
    StreamErrorEvent,
    StreamEvent,
    TextDeltaEvent,
    TextEndEvent,
    ToolCallEvent,
    ToolResultEvent,
)

__all__ = ["LLMProvider", "LLMResponse"]
__all__ = [
    "LLMProvider",
    "LLMResponse",
    "StreamEvent",
    "TextDeltaEvent",
    "TextEndEvent",
    "ToolCallEvent",
    "ToolResultEvent",
    "ReasoningStartEvent",
    "ReasoningDeltaEvent",
    "FinishEvent",
    "StreamErrorEvent",
]

try:
    from framework.llm.anthropic import AnthropicProvider  # noqa: F401
@@ -18,7 +18,7 @@ def _get_api_key_from_credential_store() -> str | None:
    try:
        from aden_tools.credentials import CredentialStoreAdapter

        creds = CredentialStoreAdapter.with_env_storage()
        creds = CredentialStoreAdapter.default()
        if creds.is_available("anthropic"):
            return creds.get("anthropic")
    except ImportError:
@@ -7,10 +7,11 @@ Groq, and local models.
See: https://docs.litellm.ai/docs/providers
"""

import asyncio
import json
import logging
import time
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from datetime import datetime
from pathlib import Path
from typing import Any
@@ -23,6 +24,7 @@ except ImportError:
    RateLimitError = Exception  # type: ignore[assignment, misc]

from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import StreamEvent

logger = logging.getLogger(__name__)

@@ -161,11 +163,24 @@ class LiteLLMProvider(LLMProvider):
                content = response.choices[0].message.content if response.choices else None
                has_tool_calls = bool(response.choices and response.choices[0].message.tool_calls)
                if not content and not has_tool_calls:
                    # If the conversation ends with an assistant message,
                    # an empty response is expected — don't retry.
                    messages = kwargs.get("messages", [])
                    last_role = next(
                        (m["role"] for m in reversed(messages) if m.get("role") != "system"),
                        None,
                    )
                    if last_role == "assistant":
                        logger.debug(
                            "[retry] Empty response after assistant message — "
                            "expected, not retrying."
                        )
                        return response

                    finish_reason = (
                        response.choices[0].finish_reason if response.choices else "unknown"
                    )
                    # Dump full request to file for debugging
                    messages = kwargs.get("messages", [])
                    token_count, token_method = _estimate_tokens(model, messages)
                    dump_path = _dump_failed_request(
                        model=model,
@@ -378,11 +393,18 @@ class LiteLLMProvider(LLMProvider):

            # Execute tools and add results.
            for tool_call in message.tool_calls:
                # Parse arguments
                try:
                    args = json.loads(tool_call.function.arguments)
                except json.JSONDecodeError:
                    args = {}
                    # Surface error to LLM and skip tool execution
                    current_messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": "Invalid JSON arguments provided to tool.",
                        }
                    )
                    continue

                tool_use = ToolUse(
                    id=tool_call.id,
@@ -425,3 +447,185 @@ class LiteLLMProvider(LLMProvider):
                },
            },
        }

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        """Stream a completion via litellm.acompletion(stream=True).

        Yields StreamEvent objects as chunks arrive from the provider.
        Tool call arguments are accumulated across chunks and yielded as
        a single ToolCallEvent with fully parsed JSON when complete.

        Empty responses (e.g. Gemini stealth rate-limits that return 200
        with no content) are retried with exponential backoff, mirroring
        the retry behaviour of ``_completion_with_rate_limit_retry``.
        """
        from framework.llm.stream_events import (
            FinishEvent,
            StreamErrorEvent,
            TextDeltaEvent,
            TextEndEvent,
            ToolCallEvent,
        )

        full_messages: list[dict[str, Any]] = []
        if system:
            full_messages.append({"role": "system", "content": system})
        full_messages.extend(messages)

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": full_messages,
            "max_tokens": max_tokens,
            "stream": True,
            "stream_options": {"include_usage": True},
            **self.extra_kwargs,
        }
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.api_base:
            kwargs["api_base"] = self.api_base
        if tools:
            kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]

        for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
            # Post-stream events (ToolCall, TextEnd, Finish) are buffered
            # because they depend on the full stream. TextDeltaEvents are
            # yielded immediately so callers see tokens in real time.
            tail_events: list[StreamEvent] = []
            accumulated_text = ""
            tool_calls_acc: dict[int, dict[str, str]] = {}
            input_tokens = 0
            output_tokens = 0

            try:
                response = await litellm.acompletion(**kwargs)  # type: ignore[union-attr]

                async for chunk in response:
                    choice = chunk.choices[0] if chunk.choices else None
                    if not choice:
                        continue

                    delta = choice.delta

                    # --- Text content — yield immediately for real-time streaming ---
                    if delta and delta.content:
                        accumulated_text += delta.content
                        yield TextDeltaEvent(
                            content=delta.content,
                            snapshot=accumulated_text,
                        )

                    # --- Tool calls (accumulate across chunks) ---
                    if delta and delta.tool_calls:
                        for tc in delta.tool_calls:
                            idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
                            if idx not in tool_calls_acc:
                                tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
                            if tc.id:
                                tool_calls_acc[idx]["id"] = tc.id
                            if tc.function:
                                if tc.function.name:
                                    tool_calls_acc[idx]["name"] = tc.function.name
                                if tc.function.arguments:
                                    tool_calls_acc[idx]["arguments"] += tc.function.arguments

                    # --- Finish ---
                    if choice.finish_reason:
                        for _idx, tc_data in sorted(tool_calls_acc.items()):
                            try:
                                parsed_args = json.loads(tc_data["arguments"])
                            except (json.JSONDecodeError, KeyError):
                                parsed_args = {"_raw": tc_data.get("arguments", "")}
                            tail_events.append(
                                ToolCallEvent(
                                    tool_use_id=tc_data["id"],
                                    tool_name=tc_data["name"],
                                    tool_input=parsed_args,
                                )
                            )

                        if accumulated_text:
                            tail_events.append(TextEndEvent(full_text=accumulated_text))

                        usage = getattr(chunk, "usage", None)
                        if usage:
                            input_tokens = getattr(usage, "prompt_tokens", 0) or 0
                            output_tokens = getattr(usage, "completion_tokens", 0) or 0

                        tail_events.append(
                            FinishEvent(
                                stop_reason=choice.finish_reason,
                                input_tokens=input_tokens,
                                output_tokens=output_tokens,
                                model=self.model,
                            )
                        )

                # Check whether the stream produced any real content.
                # (If text deltas were yielded above, has_content is True
                # and we skip the retry path — nothing was yielded in vain.)
                has_content = accumulated_text or tool_calls_acc
                if not has_content and attempt < RATE_LIMIT_MAX_RETRIES:
                    # If the conversation ends with an assistant message,
                    # an empty stream is expected (nothing new to say).
                    # Don't retry — just flush whatever we have.
                    last_role = next(
                        (m["role"] for m in reversed(full_messages) if m.get("role") != "system"),
                        None,
                    )
                    if last_role == "assistant":
                        logger.debug(
                            "[stream] Empty response after assistant message — "
                            "expected, not retrying."
                        )
                        for event in tail_events:
                            yield event
                        return
                    wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                    token_count, token_method = _estimate_tokens(
                        self.model,
                        full_messages,
                    )
                    dump_path = _dump_failed_request(
                        model=self.model,
                        kwargs=kwargs,
                        error_type="empty_stream",
                        attempt=attempt,
                    )
                    logger.warning(
                        f"[stream-retry] {self.model} returned empty stream — "
                        f"~{token_count} tokens ({token_method}). "
                        f"Request dumped to: {dump_path}. "
                        f"Retrying in {wait}s "
                        f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                    )
                    await asyncio.sleep(wait)
                    continue

                # Success (or final attempt) — flush remaining events.
                for event in tail_events:
                    yield event
                return

            except RateLimitError as e:
                if attempt < RATE_LIMIT_MAX_RETRIES:
                    wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                    logger.warning(
                        f"[stream-retry] {self.model} rate limited (429): {e!s}. "
                        f"Retrying in {wait}s "
                        f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                    )
                    await asyncio.sleep(wait)
                    continue
                yield StreamErrorEvent(error=str(e), recoverable=False)
                return

            except Exception as e:
                yield StreamErrorEvent(error=str(e), recoverable=False)
                return
@@ -2,10 +2,16 @@

import json
import re
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from typing import Any

from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import (
    FinishEvent,
    StreamEvent,
    TextDeltaEvent,
    TextEndEvent,
)


class MockLLMProvider(LLMProvider):
@@ -175,3 +181,28 @@ class MockLLMProvider(LLMProvider):
            output_tokens=0,
            stop_reason="mock_complete",
        )

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        """Stream a mock completion as word-level TextDeltaEvents.

        Splits the mock response into words and yields each as a separate
        TextDeltaEvent with an accumulating snapshot, exercising the full
        streaming pipeline without any API calls.
        """
        content = self._generate_mock_response(system=system, json_mode=False)
        words = content.split(" ")
        accumulated = ""

        for i, word in enumerate(words):
            chunk = word if i == 0 else " " + word
            accumulated += chunk
            yield TextDeltaEvent(content=chunk, snapshot=accumulated)

        yield TextEndEvent(full_text=accumulated)
        yield FinishEvent(stop_reason="mock_complete", model=self.model)
@@ -1,7 +1,7 @@
"""LLM Provider abstraction for pluggable LLM backends."""

from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from dataclasses import dataclass, field
from typing import Any

@@ -108,3 +108,45 @@ class LLMProvider(ABC):
            Final LLMResponse after tool use completes
        """
        pass

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator["StreamEvent"]:
        """
        Stream a completion as an async iterator of StreamEvents.

        Default implementation wraps complete() with synthetic events.
        Subclasses SHOULD override for true streaming.

        Tool orchestration is the CALLER's responsibility:
        - Caller detects ToolCallEvent, executes tool, adds result
          to messages, calls stream() again.
        """
        from framework.llm.stream_events import (
            FinishEvent,
            TextDeltaEvent,
            TextEndEvent,
        )

        response = self.complete(
            messages=messages,
            system=system,
            tools=tools,
            max_tokens=max_tokens,
        )
        yield TextDeltaEvent(content=response.content, snapshot=response.content)
        yield TextEndEvent(full_text=response.content)
        yield FinishEvent(
            stop_reason=response.stop_reason,
            input_tokens=response.input_tokens,
            output_tokens=response.output_tokens,
            model=response.model,
        )


# Deferred import target for type annotation
from framework.llm.stream_events import StreamEvent as StreamEvent  # noqa: E402, F401
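A sketch of the caller-side loop that docstring describes; `execute_tool()` is a hypothetical helper, and re-appending the assistant message that carried the tool_calls (which most providers require) is elided here:

```python
from framework.llm.stream_events import TextDeltaEvent, ToolCallEvent

async def run_with_tools(provider, messages, tools):
    while True:
        calls = []
        async for event in provider.stream(messages=messages, tools=tools):
            if isinstance(event, TextDeltaEvent):
                print(event.content, end="", flush=True)
            elif isinstance(event, ToolCallEvent):
                calls.append(event)
        if not calls:
            return  # finished without requesting tools
        for call in calls:
            result = await execute_tool(call.tool_name, call.tool_input)
            messages.append(
                {"role": "tool", "tool_call_id": call.tool_use_id, "content": result}
            )
```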
@@ -0,0 +1,96 @@
"""Stream event types for LLM streaming responses.

Defines a discriminated union of frozen dataclasses representing every event
a streaming LLM call can produce. These types form the contract between the
LLM provider layer, EventLoopNode, event bus, persistence, and monitoring.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal


@dataclass(frozen=True)
class TextDeltaEvent:
    """A chunk of text produced by the LLM."""

    type: Literal["text_delta"] = "text_delta"
    content: str = ""  # this chunk's text
    snapshot: str = ""  # accumulated text so far


@dataclass(frozen=True)
class TextEndEvent:
    """Signals that text generation is complete."""

    type: Literal["text_end"] = "text_end"
    full_text: str = ""


@dataclass(frozen=True)
class ToolCallEvent:
    """The LLM has requested a tool call."""

    type: Literal["tool_call"] = "tool_call"
    tool_use_id: str = ""
    tool_name: str = ""
    tool_input: dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True)
class ToolResultEvent:
    """Result of executing a tool call."""

    type: Literal["tool_result"] = "tool_result"
    tool_use_id: str = ""
    content: str = ""
    is_error: bool = False


@dataclass(frozen=True)
class ReasoningStartEvent:
    """The LLM has started a reasoning/thinking block."""

    type: Literal["reasoning_start"] = "reasoning_start"


@dataclass(frozen=True)
class ReasoningDeltaEvent:
    """A chunk of reasoning/thinking content."""

    type: Literal["reasoning_delta"] = "reasoning_delta"
    content: str = ""


@dataclass(frozen=True)
class FinishEvent:
    """The LLM has finished generating."""

    type: Literal["finish"] = "finish"
    stop_reason: str = ""
    input_tokens: int = 0
    output_tokens: int = 0
    model: str = ""


@dataclass(frozen=True)
class StreamErrorEvent:
    """An error occurred during streaming."""

    type: Literal["error"] = "error"
    error: str = ""
    recoverable: bool = False


# Discriminated union of all stream event types
StreamEvent = (
    TextDeltaEvent
    | TextEndEvent
    | ToolCallEvent
    | ToolResultEvent
    | ReasoningStartEvent
    | ReasoningDeltaEvent
    | FinishEvent
    | StreamErrorEvent
)
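Because every variant is a frozen dataclass with a `type` discriminator, consumers can branch with `isinstance` checks or, on Python 3.10+, structural pattern matching:

```python
from collections.abc import AsyncIterator

async def render(events: AsyncIterator[StreamEvent]) -> None:
    async for ev in events:
        match ev:
            case TextDeltaEvent(content=c):
                print(c, end="")
            case ToolCallEvent(tool_name=name):
                print(f"\n[tool: {name}]")
            case FinishEvent(input_tokens=i, output_tokens=o):
                print(f"\n[done: {i}+{o} tokens]")
            case StreamErrorEvent(error=err):
                print(f"\n[error: {err}]")
```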
@@ -516,6 +516,36 @@ def _validate_tool_credentials(tools_list: list[str]) -> dict | None:
|
||||
return None
|
||||
|
||||
|
||||
def _validate_agent_path(agent_path: str) -> tuple[Path | None, str | None]:
|
||||
"""
|
||||
Validate and normalize agent_path.
|
||||
|
||||
Returns:
|
||||
(Path, None) if valid
|
||||
(None, error_json) if invalid
|
||||
"""
|
||||
if not agent_path:
|
||||
return None, json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "agent_path is required (e.g., 'exports/my_agent')",
|
||||
}
|
||||
)
|
||||
|
||||
path = Path(agent_path)
|
||||
|
||||
if not path.exists():
|
||||
return None, json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Agent path not found: {path}",
|
||||
"hint": "Run export_graph to create an agent in exports/ first",
|
||||
}
|
||||
)
|
||||
|
||||
return path, None


@mcp.tool()
def add_node(
    node_id: Annotated[str, "Unique identifier for the node"],

@@ -2597,10 +2627,11 @@ def generate_constraint_tests(
    if not agent_path and _session:
        agent_path = f"exports/{_session.name}"

-    if not agent_path:
-        return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
+    path, err = _validate_agent_path(agent_path)
+    if err:
+        return err

-    agent_module = _get_agent_module_from_path(agent_path)
+    agent_module = _get_agent_module_from_path(path)

    # Format constraints for display
    constraints_formatted = (

@@ -2619,9 +2650,9 @@ def generate_constraint_tests(
    return json.dumps(
        {
            "goal_id": goal_id,
-            "agent_path": agent_path,
+            "agent_path": str(path),
            "agent_module": agent_module,
-            "output_file": f"{agent_path}/tests/test_constraints.py",
+            "output_file": f"{str(path)}/tests/test_constraints.py",
            "constraints": [c.model_dump() for c in goal.constraints] if goal.constraints else [],
            "constraints_formatted": constraints_formatted,
            "test_guidelines": {

@@ -2677,10 +2708,11 @@ def generate_success_tests(
    if not agent_path and _session:
        agent_path = f"exports/{_session.name}"

-    if not agent_path:
-        return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
+    path, err = _validate_agent_path(agent_path)
+    if err:
+        return err

-    agent_module = _get_agent_module_from_path(agent_path)
+    agent_module = _get_agent_module_from_path(path)

    # Parse node/tool names for context
    nodes = [n.strip() for n in node_names.split(",") if n.strip()]

@@ -2705,9 +2737,9 @@ def generate_success_tests(
    return json.dumps(
        {
            "goal_id": goal_id,
-            "agent_path": agent_path,
+            "agent_path": str(path),
            "agent_module": agent_module,
-            "output_file": f"{agent_path}/tests/test_success_criteria.py",
+            "output_file": f"{str(path)}/tests/test_success_criteria.py",
            "success_criteria": [c.model_dump() for c in goal.success_criteria]
            if goal.success_criteria
            else [],

@@ -2766,7 +2798,11 @@ def run_tests(
    import re
    import subprocess

-    tests_dir = Path(agent_path) / "tests"
+    path, err = _validate_agent_path(agent_path)
+    if err:
+        return err
+
+    tests_dir = path / "tests"

    if not tests_dir.exists():
        return json.dumps(

@@ -2957,10 +2993,11 @@ def debug_test(
    if not agent_path and _session:
        agent_path = f"exports/{_session.name}"

-    if not agent_path:
-        return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
+    path, err = _validate_agent_path(agent_path)
+    if err:
+        return err

-    tests_dir = Path(agent_path) / "tests"
+    tests_dir = path / "tests"

    if not tests_dir.exists():
        return json.dumps(

@@ -3101,10 +3138,11 @@ def list_tests(
    if not agent_path and _session:
        agent_path = f"exports/{_session.name}"

-    if not agent_path:
-        return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
+    path, err = _validate_agent_path(agent_path)
+    if err:
+        return err

-    tests_dir = Path(agent_path) / "tests"
+    tests_dir = path / "tests"

    if not tests_dir.exists():
        return json.dumps(
@@ -3379,7 +3417,7 @@ def store_credential(
    display_name: Annotated[str, "Human-readable name (e.g., 'HubSpot Access Token')"] = "",
) -> str:
    """
-    Store a credential securely in the encrypted credential store at ~/.hive/credentials.
+    Store a credential securely in the local encrypted store at ~/.hive/credentials.

    Uses Fernet encryption (AES-128-CBC + HMAC). Requires HIVE_CREDENTIAL_KEY env var.
    """

@@ -3421,7 +3459,7 @@ def store_credential(
@mcp.tool()
def list_stored_credentials() -> str:
    """
-    List all credentials currently stored in the encrypted credential store.
+    List all credentials currently stored in the local encrypted store.

    Returns credential IDs and metadata (never returns secret values).
    """

@@ -3461,7 +3499,7 @@ def delete_stored_credential(
    credential_name: Annotated[str, "Logical credential name to delete (e.g., 'hubspot')"],
) -> str:
    """
-    Delete a credential from the encrypted credential store.
+    Delete a credential from the local encrypted store.
    """
    try:
        store = _get_credential_store()

@@ -411,25 +411,8 @@ class AgentRunner:
        return self._tool_registry.register_mcp_server(server_config)

    def _load_mcp_servers_from_config(self, config_path: Path) -> None:
-        """
-        Load and register MCP servers from a configuration file.
-
-        Args:
-            config_path: Path to mcp_servers.json file
-        """
-        try:
-            with open(config_path) as f:
-                config = json.load(f)
-
-            servers = config.get("servers", [])
-            for server_config in servers:
-                try:
-                    self._tool_registry.register_mcp_server(server_config)
-                except Exception as e:
-                    server_name = server_config.get("name", "unknown")
-                    logger.warning(f"Failed to register MCP server '{server_name}': {e}")
-        except Exception as e:
-            logger.warning(f"Failed to load MCP servers config from {config_path}: {e}")
+        """Load and register MCP servers from a configuration file."""
+        self._tool_registry.load_mcp_config(config_path)

    def set_approval_callback(self, callback: Callable) -> None:
        """

@@ -257,6 +257,34 @@ class ToolRegistry:
        """
        self._session_context.update(context)

    def load_mcp_config(self, config_path: Path) -> None:
        """
        Load and register MCP servers from a config file.

        Resolves relative ``cwd`` paths against the config file's parent
        directory so callers never need to handle path resolution themselves.

        Args:
            config_path: Path to an ``mcp_servers.json`` file.
        """
        try:
            with open(config_path) as f:
                config = json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load MCP config from {config_path}: {e}")
            return

        base_dir = config_path.parent
        for server_config in config.get("servers", []):
            cwd = server_config.get("cwd")
            if cwd and not Path(cwd).is_absolute():
                server_config["cwd"] = str((base_dir / cwd).resolve())
            try:
                self.register_mcp_server(server_config)
            except Exception as e:
                name = server_config.get("name", "unknown")
                logger.warning(f"Failed to register MCP server '{name}': {e}")

    def register_mcp_server(
        self,
        server_config: dict[str, Any],

@@ -309,11 +337,21 @@ class ToolRegistry:
        tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)

        # Create executor that calls the MCP server
-        def make_mcp_executor(client_ref: MCPClient, tool_name: str, registry_ref):
+        def make_mcp_executor(
+            client_ref: MCPClient,
+            tool_name: str,
+            registry_ref,
+            tool_params: set[str],
+        ):
            def executor(inputs: dict) -> Any:
                try:
-                    # Inject session context for tools that need it
-                    merged_inputs = {**registry_ref._session_context, **inputs}
+                    # Only inject session context params the tool accepts
+                    filtered_context = {
+                        k: v
+                        for k, v in registry_ref._session_context.items()
+                        if k in tool_params
+                    }
+                    merged_inputs = {**filtered_context, **inputs}
                    result = client_ref.call_tool(tool_name, merged_inputs)
                    # MCP tools return content array, extract the result
                    if isinstance(result, list) and len(result) > 0:

@@ -327,10 +365,11 @@ class ToolRegistry:

            return executor

+        tool_params = set(mcp_tool.input_schema.get("properties", {}).keys())
        self.register(
            mcp_tool.name,
            tool,
-            make_mcp_executor(client, mcp_tool.name, self),
+            make_mcp_executor(client, mcp_tool.name, self, tool_params),
        )
        count += 1
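
The effect of the new filter, as a small standalone sketch (the tool name,
context keys, and values are hypothetical; the dict logic mirrors the
executor above):

session_context = {"session_id": "s-1", "user_id": "u-9"}
tool_params = {"query", "session_id"}  # keys from the tool's input_schema properties

filtered = {k: v for k, v in session_context.items() if k in tool_params}
merged = {**filtered, "query": "open deals"}
# merged == {"session_id": "s-1", "query": "open deals"}
# "user_id" is dropped: the tool's schema doesn't declare it, so strict MCP
# servers no longer reject the call for an unexpected argument.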


@@ -12,13 +12,13 @@ import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime
-from enum import Enum
+from enum import StrEnum
from typing import Any

logger = logging.getLogger(__name__)


-class EventType(str, Enum):
+class EventType(StrEnum):
    """Types of events that can be published."""

    # Execution lifecycle

@@ -41,6 +41,28 @@ class EventType(str, Enum):
    STREAM_STARTED = "stream_started"
    STREAM_STOPPED = "stream_stopped"

    # Node event-loop lifecycle
    NODE_LOOP_STARTED = "node_loop_started"
    NODE_LOOP_ITERATION = "node_loop_iteration"
    NODE_LOOP_COMPLETED = "node_loop_completed"

    # LLM streaming observability
    LLM_TEXT_DELTA = "llm_text_delta"
    LLM_REASONING_DELTA = "llm_reasoning_delta"

    # Tool lifecycle
    TOOL_CALL_STARTED = "tool_call_started"
    TOOL_CALL_COMPLETED = "tool_call_completed"

    # Client I/O (client_facing=True nodes only)
    CLIENT_OUTPUT_DELTA = "client_output_delta"
    CLIENT_INPUT_REQUESTED = "client_input_requested"

    # Internal node observability (client_facing=False nodes)
    NODE_INTERNAL_OUTPUT = "node_internal_output"
    NODE_INPUT_BLOCKED = "node_input_blocked"
    NODE_STALLED = "node_stalled"

    # Custom events
    CUSTOM = "custom"

@@ -51,6 +73,7 @@ class AgentEvent:

    type: EventType
    stream_id: str
+    node_id: str | None = None  # Which node emitted this event
    execution_id: str | None = None
    data: dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)

@@ -61,6 +84,7 @@ class AgentEvent:
        return {
            "type": self.type.value,
            "stream_id": self.stream_id,
+            "node_id": self.node_id,
            "execution_id": self.execution_id,
            "data": self.data,
            "timestamp": self.timestamp.isoformat(),

@@ -80,6 +104,7 @@ class Subscription:
    event_types: set[EventType]
    handler: EventHandler
    filter_stream: str | None = None  # Only receive events from this stream
+    filter_node: str | None = None  # Only receive events from this node
    filter_execution: str | None = None  # Only receive events from this execution


@@ -138,6 +163,7 @@ class EventBus:
        event_types: list[EventType],
        handler: EventHandler,
        filter_stream: str | None = None,
+        filter_node: str | None = None,
        filter_execution: str | None = None,
    ) -> str:
        """

@@ -147,6 +173,7 @@ class EventBus:
            event_types: Types of events to receive
            handler: Async function to call when event occurs
            filter_stream: Only receive events from this stream
+            filter_node: Only receive events from this node
            filter_execution: Only receive events from this execution

        Returns:

@@ -160,6 +187,7 @@ class EventBus:
            event_types=set(event_types),
            handler=handler,
            filter_stream=filter_stream,
+            filter_node=filter_node,
            filter_execution=filter_execution,
        )

@@ -218,6 +246,10 @@ class EventBus:
        if subscription.filter_stream and subscription.filter_stream != event.stream_id:
            return False

+        # Check node filter
+        if subscription.filter_node and subscription.filter_node != event.node_id:
+            return False
+
        # Check execution filter
        if subscription.filter_execution and subscription.filter_execution != event.execution_id:
            return False

@@ -359,6 +391,248 @@ class EventBus:
            )
        )

    # === NODE EVENT-LOOP PUBLISHERS ===

    async def emit_node_loop_started(
        self,
        stream_id: str,
        node_id: str,
        execution_id: str | None = None,
        max_iterations: int | None = None,
    ) -> None:
        """Emit node loop started event."""
        await self.publish(
            AgentEvent(
                type=EventType.NODE_LOOP_STARTED,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"max_iterations": max_iterations},
            )
        )

    async def emit_node_loop_iteration(
        self,
        stream_id: str,
        node_id: str,
        iteration: int,
        execution_id: str | None = None,
    ) -> None:
        """Emit node loop iteration event."""
        await self.publish(
            AgentEvent(
                type=EventType.NODE_LOOP_ITERATION,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"iteration": iteration},
            )
        )

    async def emit_node_loop_completed(
        self,
        stream_id: str,
        node_id: str,
        iterations: int,
        execution_id: str | None = None,
    ) -> None:
        """Emit node loop completed event."""
        await self.publish(
            AgentEvent(
                type=EventType.NODE_LOOP_COMPLETED,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"iterations": iterations},
            )
        )

    # === LLM STREAMING PUBLISHERS ===

    async def emit_llm_text_delta(
        self,
        stream_id: str,
        node_id: str,
        content: str,
        snapshot: str,
        execution_id: str | None = None,
    ) -> None:
        """Emit LLM text delta event."""
        await self.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"content": content, "snapshot": snapshot},
            )
        )

    async def emit_llm_reasoning_delta(
        self,
        stream_id: str,
        node_id: str,
        content: str,
        execution_id: str | None = None,
    ) -> None:
        """Emit LLM reasoning delta event."""
        await self.publish(
            AgentEvent(
                type=EventType.LLM_REASONING_DELTA,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"content": content},
            )
        )

    # === TOOL LIFECYCLE PUBLISHERS ===

    async def emit_tool_call_started(
        self,
        stream_id: str,
        node_id: str,
        tool_use_id: str,
        tool_name: str,
        tool_input: dict[str, Any] | None = None,
        execution_id: str | None = None,
    ) -> None:
        """Emit tool call started event."""
        await self.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={
                    "tool_use_id": tool_use_id,
                    "tool_name": tool_name,
                    "tool_input": tool_input or {},
                },
            )
        )

    async def emit_tool_call_completed(
        self,
        stream_id: str,
        node_id: str,
        tool_use_id: str,
        tool_name: str,
        result: str = "",
        is_error: bool = False,
        execution_id: str | None = None,
    ) -> None:
        """Emit tool call completed event."""
        await self.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_COMPLETED,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={
                    "tool_use_id": tool_use_id,
                    "tool_name": tool_name,
                    "result": result,
                    "is_error": is_error,
                },
            )
        )

    # === CLIENT I/O PUBLISHERS ===

    async def emit_client_output_delta(
        self,
        stream_id: str,
        node_id: str,
        content: str,
        snapshot: str,
        execution_id: str | None = None,
    ) -> None:
        """Emit client output delta event (client_facing=True nodes)."""
        await self.publish(
            AgentEvent(
                type=EventType.CLIENT_OUTPUT_DELTA,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"content": content, "snapshot": snapshot},
            )
        )

    async def emit_client_input_requested(
        self,
        stream_id: str,
        node_id: str,
        prompt: str = "",
        execution_id: str | None = None,
    ) -> None:
        """Emit client input requested event (client_facing=True nodes)."""
        await self.publish(
            AgentEvent(
                type=EventType.CLIENT_INPUT_REQUESTED,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"prompt": prompt},
            )
        )

    # === INTERNAL NODE PUBLISHERS ===

    async def emit_node_internal_output(
        self,
        stream_id: str,
        node_id: str,
        content: str,
        execution_id: str | None = None,
    ) -> None:
        """Emit node internal output event (client_facing=False nodes)."""
        await self.publish(
            AgentEvent(
                type=EventType.NODE_INTERNAL_OUTPUT,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"content": content},
            )
        )

    async def emit_node_stalled(
        self,
        stream_id: str,
        node_id: str,
        reason: str = "",
        execution_id: str | None = None,
    ) -> None:
        """Emit node stalled event."""
        await self.publish(
            AgentEvent(
                type=EventType.NODE_STALLED,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"reason": reason},
            )
        )

    async def emit_node_input_blocked(
        self,
        stream_id: str,
        node_id: str,
        prompt: str = "",
        execution_id: str | None = None,
    ) -> None:
        """Emit node input blocked event."""
        await self.publish(
            AgentEvent(
                type=EventType.NODE_INPUT_BLOCKED,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"prompt": prompt},
            )
        )

    # === QUERY OPERATIONS ===

    def get_history(

@@ -410,6 +684,7 @@ class EventBus:
        self,
        event_type: EventType,
        stream_id: str | None = None,
+        node_id: str | None = None,
        execution_id: str | None = None,
        timeout: float | None = None,
    ) -> AgentEvent | None:

@@ -419,6 +694,7 @@ class EventBus:
        Args:
            event_type: Type of event to wait for
            stream_id: Filter by stream
+            node_id: Filter by node
            execution_id: Filter by execution
            timeout: Maximum time to wait (seconds)

@@ -438,6 +714,7 @@ class EventBus:
            event_types=[event_type],
            handler=handler,
            filter_stream=stream_id,
+            filter_node=node_id,
            filter_execution=execution_id,
        )
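
With the node filter threaded through, `wait_for` gives tests and
orchestration code a one-shot await over the same predicates. A sketch (IDs
are illustrative):

# Block until a specific node finishes its loop, or give up after 5 seconds.
event = await bus.wait_for(
    EventType.NODE_LOOP_COMPLETED,
    node_id="research_node",
    timeout=5.0,
)
if event is None:
    raise TimeoutError("node never completed")
print(event.data["iterations"])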


@@ -11,13 +11,13 @@ import asyncio
import logging
import time
from dataclasses import dataclass, field
-from enum import Enum
+from enum import StrEnum
from typing import Any

logger = logging.getLogger(__name__)


-class IsolationLevel(str, Enum):
+class IsolationLevel(StrEnum):
    """State isolation level for concurrent executions."""

    ISOLATED = "isolated"  # Private state per execution

@@ -25,7 +25,7 @@ class IsolationLevel(str, Enum):
    SYNCHRONIZED = "synchronized"  # Shared with write locks (strong consistency)


-class StateScope(str, Enum):
+class StateScope(StrEnum):
    """Scope for state operations."""

    EXECUTION = "execution"  # Local to a single execution

@@ -10,13 +10,13 @@ This is MORE important than actions because:
"""

from datetime import datetime
-from enum import Enum
+from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field, computed_field


-class DecisionType(str, Enum):
+class DecisionType(StrEnum):
    """Types of decisions an agent can make."""

    TOOL_SELECTION = "tool_selection"  # Which tool to use

@@ -6,7 +6,7 @@ summaries and metrics that Builder needs to understand what happened.
"""

from datetime import datetime
-from enum import Enum
+from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field, computed_field

@@ -14,7 +14,7 @@ from pydantic import BaseModel, Field, computed_field
from framework.schemas.decision import Decision, Outcome


-class RunStatus(str, Enum):
+class RunStatus(StrEnum):
    """Status of a run."""

    RUNNING = "running"

@@ -167,14 +167,18 @@ class ConcurrentStorage:
            run: Run to save
            immediate: If True, save immediately (bypasses batching)
        """
+        # Invalidate summary cache since the run data is changing
+        # This ensures load_summary() fetches fresh data after the save
+        self._cache.pop(f"summary:{run.id}", None)
+
        if immediate or not self._running:
            await self._save_run_locked(run)
+            # Update cache only after successful immediate write
+            self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())
        else:
+            # For batched writes, cache will be updated in _flush_batch after successful write
            await self._write_queue.put(("run", run))

-        # Update cache
-        self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())

    async def _save_run_locked(self, run: Run) -> None:
        """Save a run with file locking, including index locks."""
        lock_key = f"run:{run.id}"

@@ -363,8 +367,12 @@ class ConcurrentStorage:
        try:
            if item_type == "run":
                await self._save_run_locked(item)
+                # Update cache only after successful batched write
+                # This fixes the race condition where cache was updated before write completed
+                self._cache[f"run:{item.id}"] = CacheEntry(item, time.time())
        except Exception as e:
            logger.error(f"Failed to save {item_type}: {e}")
+            # Cache is NOT updated on failure - prevents stale/inconsistent cache state

    async def _flush_pending(self) -> None:
        """Flush all pending writes."""

@@ -6,13 +6,13 @@ programmatic/MCP-based approval.
"""

from datetime import datetime
-from enum import Enum
+from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field


-class ApprovalAction(str, Enum):
+class ApprovalAction(StrEnum):
    """Actions a user can take on a generated test."""

    APPROVE = "approve"  # Accept as-is

@@ -24,7 +24,7 @@ def _get_api_key():
    # 1. Try CredentialStoreAdapter for Anthropic
    try:
        from aden_tools.credentials import CredentialStoreAdapter
-        creds = CredentialStoreAdapter.with_env_storage()
+        creds = CredentialStoreAdapter.default()
        if creds.is_available("anthropic"):
            return creds.get("anthropic")
    except (ImportError, KeyError):

@@ -57,7 +57,7 @@ def _get_api_key():
    """Get API key from CredentialStoreAdapter or environment."""
    try:
        from aden_tools.credentials import CredentialStoreAdapter
-        creds = CredentialStoreAdapter.with_env_storage()
+        creds = CredentialStoreAdapter.default()
        if creds.is_available("anthropic"):
            return creds.get("anthropic")
    except (ImportError, KeyError):

@@ -6,13 +6,13 @@ but require mandatory user approval before being stored.
"""

from datetime import datetime
-from enum import Enum
+from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field


-class ApprovalStatus(str, Enum):
+class ApprovalStatus(StrEnum):
    """Status of user approval for a generated test."""

    PENDING = "pending"  # Awaiting user review

@@ -21,7 +21,7 @@ class ApprovalStatus(str, Enum):
    REJECTED = "rejected"  # User declined (with reason)


-class TestType(str, Enum):
+class TestType(StrEnum):
    """Type of test based on what it validates."""

    __test__ = False  # Not a pytest test class

@@ -6,13 +6,13 @@ categorization for guiding iteration strategy.
"""

from datetime import datetime
-from enum import Enum
+from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field


-class ErrorCategory(str, Enum):
+class ErrorCategory(StrEnum):
    """
    Category of test failure for guiding iteration.

@@ -14,6 +14,7 @@ dependencies = [
    "pytest>=8.0",
    "pytest-asyncio>=0.23",
    "pytest-xdist>=3.0",
+    "tools",
]

# [project.optional-dependencies]

@@ -21,6 +22,9 @@ dependencies = [
[project.scripts]
hive = "framework.cli:main"

+[tool.uv.sources]
+tools = { workspace = true }
+
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

@@ -43,6 +47,7 @@ lint.select = [
    "W",  # pycodestyle warnings
]

+lint.per-file-ignores."demos/*" = ["E501"]
lint.isort.combine-as-imports = true
lint.isort.known-first-party = ["framework"]
lint.isort.section-order = [

@@ -1,10 +0,0 @@
# Development dependencies
-r requirements.txt

# Testing
pytest>=8.0
pytest-asyncio>=0.23

# Linting & type checking
ruff>=0.1.0
mypy>=1.0

@@ -1,14 +0,0 @@
# Core dependencies
pydantic>=2.0
anthropic>=0.40.0
httpx>=0.27.0
litellm>=1.81.0

# MCP server dependencies
mcp
fastmcp

# Testing (required for test framework)
pytest>=8.0
pytest-asyncio>=0.23
pytest-xdist>=3.0

@@ -0,0 +1,150 @@
"""
Tests for ClientIO gateway (WP-9).

Covers:
- ActiveNodeClientIO: emit_output → output_stream round-trip, request_input, timeout
- InertNodeClientIO: emit_output publishes NODE_INTERNAL_OUTPUT, request_input returns redirect
- ClientIOGateway: factory creates correct variant
"""

import asyncio

import pytest

from framework.graph.client_io import (
    ActiveNodeClientIO,
    ClientIOGateway,
    InertNodeClientIO,
    NodeClientIO,
)
from framework.runtime.event_bus import AgentEvent, EventType

_AGENT_EVENT_FIELDS = {"stream_id", "node_id", "execution_id", "correlation_id"}


class MockEventBus:
    """Lightweight stand-in for EventBus that records published events."""

    def __init__(self) -> None:
        self.events: list[AgentEvent] = []

    async def _record(self, event_type: EventType, **kwargs) -> None:
        agent_kwargs = {k: v for k, v in kwargs.items() if k in _AGENT_EVENT_FIELDS}
        data = {k: v for k, v in kwargs.items() if k not in _AGENT_EVENT_FIELDS}
        self.events.append(AgentEvent(type=event_type, **agent_kwargs, data=data))

    async def emit_client_output_delta(self, **kwargs) -> None:
        await self._record(EventType.CLIENT_OUTPUT_DELTA, **kwargs)

    async def emit_client_input_requested(self, **kwargs) -> None:
        await self._record(EventType.CLIENT_INPUT_REQUESTED, **kwargs)

    async def emit_node_internal_output(self, **kwargs) -> None:
        await self._record(EventType.NODE_INTERNAL_OUTPUT, **kwargs)

    async def emit_node_input_blocked(self, **kwargs) -> None:
        await self._record(EventType.NODE_INPUT_BLOCKED, **kwargs)


# --- ActiveNodeClientIO tests ---


@pytest.mark.asyncio
async def test_active_emit_and_consume():
    """emit_output → output_stream round-trip works correctly."""
    bus = MockEventBus()
    io = ActiveNodeClientIO(node_id="n1", event_bus=bus)

    await io.emit_output("Hello ")
    await io.emit_output("World", is_final=True)

    chunks = []
    async for chunk in io.output_stream():
        chunks.append(chunk)

    assert chunks == ["Hello ", "World"]
    assert len(bus.events) == 2
    assert all(e.type == EventType.CLIENT_OUTPUT_DELTA for e in bus.events)
    # Verify snapshot accumulates
    assert bus.events[0].data["snapshot"] == "Hello "
    assert bus.events[1].data["snapshot"] == "Hello World"


@pytest.mark.asyncio
async def test_active_request_input():
    """request_input blocks until provide_input is called."""
    bus = MockEventBus()
    io = ActiveNodeClientIO(node_id="n1", event_bus=bus)

    async def fulfill_later():
        await asyncio.sleep(0.01)
        await io.provide_input("user says hi")

    task = asyncio.create_task(fulfill_later())
    result = await io.request_input(prompt="What?")
    await task

    assert result == "user says hi"
    assert len(bus.events) == 1
    assert bus.events[0].type == EventType.CLIENT_INPUT_REQUESTED
    assert bus.events[0].data["prompt"] == "What?"


@pytest.mark.asyncio
async def test_active_request_input_timeout():
    """request_input raises TimeoutError when timeout expires."""
    io = ActiveNodeClientIO(node_id="n1")

    with pytest.raises(TimeoutError):
        await io.request_input(prompt="waiting", timeout=0.01)


# --- InertNodeClientIO tests ---


@pytest.mark.asyncio
async def test_inert_emit_publishes_internal():
    """InertNodeClientIO.emit_output publishes NODE_INTERNAL_OUTPUT."""
    bus = MockEventBus()
    io = InertNodeClientIO(node_id="n2", event_bus=bus)

    await io.emit_output("internal log")

    assert len(bus.events) == 1
    assert bus.events[0].type == EventType.NODE_INTERNAL_OUTPUT
    assert bus.events[0].data["content"] == "internal log"


@pytest.mark.asyncio
async def test_inert_request_input_returns_redirect():
    """request_input returns a redirect string and publishes NODE_INPUT_BLOCKED."""
    bus = MockEventBus()
    io = InertNodeClientIO(node_id="n2", event_bus=bus)

    result = await io.request_input(prompt="need data")

    assert "internal processing node" in result
    assert len(bus.events) == 1
    assert bus.events[0].type == EventType.NODE_INPUT_BLOCKED
    assert bus.events[0].data["prompt"] == "need data"


# --- ClientIOGateway tests ---


def test_gateway_creates_active_for_client_facing():
    """ClientIOGateway.create_io returns ActiveNodeClientIO when client_facing=True."""
    gateway = ClientIOGateway()
    io = gateway.create_io(node_id="n1", client_facing=True)

    assert isinstance(io, ActiveNodeClientIO)
    assert isinstance(io, NodeClientIO)


def test_gateway_creates_inert_for_internal():
    """ClientIOGateway.create_io returns InertNodeClientIO when client_facing=False."""
    gateway = ClientIOGateway()
    io = gateway.create_io(node_id="n2", client_facing=False)

    assert isinstance(io, InertNodeClientIO)
    assert isinstance(io, NodeClientIO)

@@ -0,0 +1,162 @@
"""Tests for ConcurrentStorage race condition and cache invalidation fixes."""

import asyncio
from pathlib import Path

import pytest

from framework.schemas.run import Run, RunMetrics, RunStatus
from framework.storage.concurrent import ConcurrentStorage


def create_test_run(
    run_id: str, goal_id: str = "test-goal", status: RunStatus = RunStatus.RUNNING
) -> Run:
    """Create a minimal test Run object."""
    return Run(
        id=run_id,
        goal_id=goal_id,
        status=status,
        narrative="Test run",
        metrics=RunMetrics(
            nodes_executed=[],
        ),
        decisions=[],
        problems=[],
    )


@pytest.mark.asyncio
async def test_cache_invalidation_on_save(tmp_path: Path):
    """Test that summary cache is invalidated when a run is saved.

    This tests the fix for the cache invalidation bug where load_summary()
    would return stale data after a run was updated.
    """
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    try:
        run_id = "test-run-1"

        # Create and save initial run
        run = create_test_run(run_id, status=RunStatus.RUNNING)
        await storage.save_run(run, immediate=True)

        # Load summary to populate the cache
        summary = await storage.load_summary(run_id)
        assert summary is not None
        assert summary.status == RunStatus.RUNNING

        # Update run with new status
        run.status = RunStatus.COMPLETED
        await storage.save_run(run, immediate=True)

        # Load summary again - should get fresh data, not cached stale data
        summary = await storage.load_summary(run_id)
        assert summary is not None
        assert summary.status == RunStatus.COMPLETED, (
            "Summary cache should be invalidated on save - got stale data"
        )
    finally:
        await storage.stop()


@pytest.mark.asyncio
async def test_batched_write_cache_consistency(tmp_path: Path):
    """Test that cache is only updated after successful batched write.

    This tests the fix for the race condition where cache was updated
    before the batched write completed.
    """
    storage = ConcurrentStorage(tmp_path, batch_interval=0.05)
    await storage.start()

    try:
        run_id = "test-run-2"

        # Save via batching (immediate=False)
        run = create_test_run(run_id, status=RunStatus.RUNNING)
        await storage.save_run(run, immediate=False)

        # Before batch flush, cache should NOT contain the run
        # (This is the fix - previously cache was updated immediately)
        cache_key = f"run:{run_id}"
        assert cache_key not in storage._cache, (
            "Cache should not be updated before batch is flushed"
        )

        # Wait for batch to flush
        await asyncio.sleep(0.1)

        # After batch flush, cache should contain the run
        assert cache_key in storage._cache, "Cache should be updated after batch flush"

        # Verify data on disk matches cache
        loaded_run = await storage.load_run(run_id, use_cache=False)
        assert loaded_run is not None
        assert loaded_run.id == run_id
        assert loaded_run.status == RunStatus.RUNNING
    finally:
        await storage.stop()


@pytest.mark.asyncio
async def test_immediate_write_updates_cache(tmp_path: Path):
    """Test that immediate writes still update cache correctly."""
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    try:
        run_id = "test-run-3"

        # Save with immediate=True
        run = create_test_run(run_id, status=RunStatus.COMPLETED)
        await storage.save_run(run, immediate=True)

        # Cache should be updated immediately for immediate writes
        cache_key = f"run:{run_id}"
        assert cache_key in storage._cache, "Cache should be updated after immediate write"

        # Verify cached value is correct
        cached_run = storage._cache[cache_key].value
        assert cached_run.id == run_id
        assert cached_run.status == RunStatus.COMPLETED
    finally:
        await storage.stop()

@pytest.mark.asyncio
async def test_summary_cache_invalidated_on_multiple_saves(tmp_path: Path):
    """Test that summary cache is invalidated on each save, not just the first."""
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    try:
        run_id = "test-run-4"

        # First save
        run = create_test_run(run_id, status=RunStatus.RUNNING)
        await storage.save_run(run, immediate=True)

        # Load summary to cache it
        summary1 = await storage.load_summary(run_id)
        assert summary1.status == RunStatus.RUNNING

        # Second save (status unchanged - exercises repeat invalidation)
        run.status = RunStatus.RUNNING
        await storage.save_run(run, immediate=True)

        # Load summary - should be fresh
        summary2 = await storage.load_summary(run_id)
        assert summary2.status == RunStatus.RUNNING

        # Third save with final status
        run.status = RunStatus.COMPLETED
        await storage.save_run(run, immediate=True)

        # Load summary - should be fresh again
        summary3 = await storage.load_summary(run_id)
        assert summary3.status == RunStatus.COMPLETED
    finally:
        await storage.stop()
|
||||
"""Tests for ContextHandoff and HandoffContext."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.context_handoff import ContextHandoff, HandoffContext
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
from framework.llm.provider import LLMProvider, LLMResponse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SpyLLMProvider(MockLLMProvider):
|
||||
"""MockLLMProvider that records whether complete() was called."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.complete_called = False
|
||||
self.complete_call_args: dict[str, Any] | None = None
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
|
||||
self.complete_called = True
|
||||
self.complete_call_args = {"messages": messages, **kwargs}
|
||||
return super().complete(messages, **kwargs)
|
||||
|
||||
|
||||
class FailingLLMProvider(LLMProvider):
|
||||
"""LLM provider that always raises."""
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
|
||||
raise RuntimeError("LLM unavailable")
|
||||
|
||||
def complete_with_tools(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list,
|
||||
tool_executor: Any,
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
raise RuntimeError("LLM unavailable")
|
||||
|
||||
|
||||
async def _build_conversation(*pairs: tuple[str, str]) -> NodeConversation:
|
||||
"""Build a NodeConversation from (user, assistant) message pairs."""
|
||||
conv = NodeConversation()
|
||||
for user_msg, assistant_msg in pairs:
|
||||
await conv.add_user_message(user_msg)
|
||||
await conv.add_assistant_message(assistant_msg)
|
||||
return conv
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestHandoffContext
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandoffContext:
|
||||
def test_instantiation(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="node_A",
|
||||
summary="Summary text",
|
||||
key_outputs={"result": "42"},
|
||||
turn_count=3,
|
||||
total_tokens_used=1200,
|
||||
)
|
||||
assert hc.source_node_id == "node_A"
|
||||
assert hc.summary == "Summary text"
|
||||
assert hc.key_outputs == {"result": "42"}
|
||||
assert hc.turn_count == 3
|
||||
assert hc.total_tokens_used == 1200
|
||||
|
||||
def test_field_access(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="n1",
|
||||
summary="s",
|
||||
key_outputs={},
|
||||
turn_count=0,
|
||||
total_tokens_used=0,
|
||||
)
|
||||
assert hc.key_outputs == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestExtractiveSummary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractiveSummary:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_summary_includes_first_last(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("hello", "First response here."),
|
||||
("continue", "Middle response."),
|
||||
("finish", "Final conclusion."),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="test_node")
|
||||
|
||||
assert "First response here." in hc.summary
|
||||
assert "Final conclusion." in hc.summary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_summary_metadata(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("hi", "hello"),
|
||||
("bye", "goodbye"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="node_42")
|
||||
|
||||
assert hc.source_node_id == "node_42"
|
||||
assert hc.turn_count == 2
|
||||
assert hc.total_tokens_used > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_with_output_keys_colon(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("what is the answer?", "answer: 42"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["answer"])
|
||||
|
||||
assert hc.key_outputs["answer"] == "42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_with_output_keys_equals(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("compute", "result = success"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["result"])
|
||||
|
||||
assert hc.key_outputs["result"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_json_output_keys(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("give me json", '{"score": 95, "grade": "A"}'),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"])
|
||||
|
||||
assert hc.key_outputs["score"] == "95"
|
||||
assert hc.key_outputs["grade"] == "A"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_empty_conversation(self) -> None:
|
||||
conv = NodeConversation()
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="empty")
|
||||
|
||||
assert hc.summary == "Empty conversation."
|
||||
assert hc.turn_count == 0
|
||||
assert hc.key_outputs == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_no_assistant_messages(self) -> None:
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("hello?")
|
||||
await conv.add_user_message("anyone there?")
|
||||
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="silent")
|
||||
|
||||
assert hc.summary == "No assistant responses."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_most_recent_wins(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("first", "status: old_value"),
|
||||
("second", "status: new_value"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["status"])
|
||||
|
||||
assert hc.key_outputs["status"] == "new_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_truncation(self) -> None:
|
||||
long_text = "x" * 1000
|
||||
conv = await _build_conversation(
|
||||
("go", long_text),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n")
|
||||
|
||||
# Summary should be truncated to ~500 chars
|
||||
assert len(hc.summary) <= 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestLLMSummary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMSummary:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_summary_calls_provider(self) -> None:
|
||||
llm = SpyLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("hi", "hello back"),
|
||||
("what now?", "we are done"),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
hc = ch.summarize_conversation(conv, node_id="llm_node")
|
||||
|
||||
assert llm.complete_called, "LLM complete() was never invoked"
|
||||
assert hc.summary == "This is a mock response for testing purposes."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_summary_includes_output_key_hint(self) -> None:
|
||||
llm = SpyLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("compute", '{"score": 95}'),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"])
|
||||
|
||||
assert llm.complete_call_args is not None
|
||||
system = llm.complete_call_args.get("system", "")
|
||||
assert "score" in system
|
||||
assert "grade" in system
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_fallback_on_error(self) -> None:
|
||||
llm = FailingLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("start", "First assistant message."),
|
||||
("end", "Last assistant message."),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
hc = ch.summarize_conversation(conv, node_id="fallback_node")
|
||||
|
||||
# Should fall back to extractive (first + last assistant messages)
|
||||
assert "First assistant message." in hc.summary
|
||||
assert "Last assistant message." in hc.summary
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestFormatAsInput
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatAsInput:
|
||||
def test_format_structure(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="analyzer",
|
||||
summary="Analysis complete.",
|
||||
key_outputs={"score": "95"},
|
||||
turn_count=5,
|
||||
total_tokens_used=2000,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "--- CONTEXT FROM: analyzer" in output
|
||||
assert "KEY OUTPUTS:" in output
|
||||
assert "SUMMARY:" in output
|
||||
assert "--- END CONTEXT ---" in output
|
||||
|
||||
def test_format_no_key_outputs(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="simple",
|
||||
summary="Done.",
|
||||
key_outputs={},
|
||||
turn_count=1,
|
||||
total_tokens_used=100,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "KEY OUTPUTS:" not in output
|
||||
assert "SUMMARY:" in output
|
||||
|
||||
def test_format_content_values(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="node_X",
|
||||
summary="Found 3 bugs.",
|
||||
key_outputs={"bugs": "3", "severity": "high"},
|
||||
turn_count=7,
|
||||
total_tokens_used=5000,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "node_X" in output
|
||||
assert "7 turns" in output
|
||||
assert "~5000 tokens" in output
|
||||
assert "- bugs: 3" in output
|
||||
assert "- severity: high" in output
|
||||
assert "Found 3 bugs." in output
|
||||
|
||||
def test_format_empty_summary(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="n",
|
||||
summary="",
|
||||
key_outputs={},
|
||||
turn_count=0,
|
||||
total_tokens_used=0,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "No summary available." in output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_as_input_usable_as_message(self) -> None:
|
||||
"""Formatted output can be fed into a NodeConversation as a user message."""
|
||||
hc = HandoffContext(
|
||||
source_node_id="prev_node",
|
||||
summary="Completed analysis.",
|
||||
key_outputs={"result": "42"},
|
||||
turn_count=3,
|
||||
total_tokens_used=900,
|
||||
)
|
||||
text = ContextHandoff.format_as_input(hc)
|
||||
|
||||
conv = NodeConversation()
|
||||
msg = await conv.add_user_message(text)
|
||||
|
||||
assert msg.role == "user"
|
||||
assert "CONTEXT FROM: prev_node" in msg.content
|
||||
assert conv.turn_count == 1
|

File diff suppressed because it is too large
@@ -0,0 +1,906 @@
|
||||
"""WP-8: Tests for EventLoopNode, OutputAccumulator, LoopConfig, JudgeProtocol.
|
||||
|
||||
Uses real FileConversationStore (no mocks for storage) and a MockStreamingLLM
|
||||
that yields pre-programmed StreamEvents to control the loop deterministically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.event_loop_node import (
|
||||
EventLoopNode,
|
||||
JudgeProtocol,
|
||||
JudgeVerdict,
|
||||
LoopConfig,
|
||||
OutputAccumulator,
|
||||
)
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeSpec, SharedMemory
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.runtime.event_bus import EventBus, EventType
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock LLM that yields pre-programmed stream events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockStreamingLLM(LLMProvider):
|
||||
"""Mock LLM that yields pre-programmed StreamEvent sequences.
|
||||
|
||||
Each call to stream() consumes the next scenario from the list.
|
||||
Cycles back to the beginning if more calls are made than scenarios.
|
||||
"""
|
||||
|
||||
def __init__(self, scenarios: list[list] | None = None):
|
||||
self.scenarios = scenarios or []
|
||||
self._call_index = 0
|
||||
self.stream_calls: list[dict] = []
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator:
|
||||
self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
|
||||
if not self.scenarios:
|
||||
return
|
||||
events = self.scenarios[self._call_index % len(self.scenarios)]
|
||||
self._call_index += 1
|
||||
for event in events:
|
||||
yield event
|
||||
|
||||
def complete(self, messages, system="", **kwargs) -> LLMResponse:
|
||||
return LLMResponse(content="Summary of conversation.", model="mock", stop_reason="stop")
|
||||
|
||||
def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
|
||||
return LLMResponse(content="", model="mock", stop_reason="stop")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: build a simple text-only scenario
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def text_scenario(text: str, input_tokens: int = 10, output_tokens: int = 5) -> list:
|
||||
"""Build a stream scenario that produces text and finishes."""
|
||||
return [
|
||||
TextDeltaEvent(content=text, snapshot=text),
|
||||
FinishEvent(
|
||||
stop_reason="stop", input_tokens=input_tokens, output_tokens=output_tokens, model="mock"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def tool_call_scenario(
|
||||
tool_name: str,
|
||||
tool_input: dict,
|
||||
tool_use_id: str = "call_1",
|
||||
text: str = "",
|
||||
) -> list:
|
||||
"""Build a stream scenario that produces a tool call."""
|
||||
events = []
|
||||
if text:
|
||||
events.append(TextDeltaEvent(content=text, snapshot=text))
|
||||
events.append(
|
||||
ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input)
|
||||
)
|
||||
events.append(
|
||||
FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock")
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runtime():
|
||||
rt = MagicMock(spec=Runtime)
|
||||
rt.start_run = MagicMock(return_value="run_1")
|
||||
rt.decide = MagicMock(return_value="dec_1")
|
||||
rt.record_outcome = MagicMock()
|
||||
rt.end_run = MagicMock()
|
||||
rt.report_problem = MagicMock()
|
||||
rt.set_node = MagicMock()
|
||||
return rt
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node_spec():
|
||||
return NodeSpec(
|
||||
id="test_loop",
|
||||
name="Test Loop",
|
||||
description="A test event loop node",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
system_prompt="You are a test assistant.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory():
|
||||
return SharedMemory()
|
||||
|
||||
|
||||
def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal_context=""):
|
||||
"""Build a NodeContext for testing."""
|
||||
return NodeContext(
|
||||
runtime=runtime,
|
||||
node_id=node_spec.id,
|
||||
node_spec=node_spec,
|
||||
memory=memory,
|
||||
input_data=input_data or {},
|
||||
llm=llm,
|
||||
available_tools=tools or [],
|
||||
goal_context=goal_context,
|
||||
)


# ===========================================================================
# NodeProtocol conformance
# ===========================================================================


class TestNodeProtocolConformance:
    def test_subclasses_node_protocol(self):
        """EventLoopNode must be a subclass of NodeProtocol."""
        assert issubclass(EventLoopNode, NodeProtocol)

    def test_has_execute_method(self):
        node = EventLoopNode()
        assert hasattr(node, "execute")
        assert asyncio.iscoroutinefunction(node.execute)

    def test_has_validate_input(self):
        node = EventLoopNode()
        assert hasattr(node, "validate_input")


# ===========================================================================
# Basic loop execution
# ===========================================================================


class TestBasicLoop:
    @pytest.mark.asyncio
    async def test_basic_text_only_implicit_accept(self, runtime, node_spec, memory):
        """No tools, no judge. LLM produces text, implicit accept on stop."""
        # Override to no output_keys so implicit judge accepts immediately
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello world")])
        ctx = build_ctx(runtime, node_spec, memory, llm)

        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.tokens_used > 0

    @pytest.mark.asyncio
    async def test_no_llm_returns_failure(self, runtime, node_spec, memory):
        """ctx.llm=None should return failure immediately."""
        ctx = build_ctx(runtime, node_spec, memory, llm=None)

        node = EventLoopNode()
        result = await node.execute(ctx)

        assert result.success is False
        assert "LLM" in result.error

    @pytest.mark.asyncio
    async def test_max_iterations_failure(self, runtime, node_spec, memory):
        """When max_iterations is reached without acceptance, should fail."""
        # LLM always produces text but never calls set_output, so implicit
        # judge retries asking for missing keys
        llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
        ctx = build_ctx(runtime, node_spec, memory, llm)

        node = EventLoopNode(config=LoopConfig(max_iterations=2))
        result = await node.execute(ctx)

        assert result.success is False
        assert "Max iterations" in result.error


# ===========================================================================
# Judge integration
# ===========================================================================


class TestJudgeIntegration:
    @pytest.mark.asyncio
    async def test_judge_accept(self, runtime, node_spec, memory):
        """Mock judge ACCEPT -> success."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Done!")])

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        judge.evaluate.assert_called_once()

    @pytest.mark.asyncio
    async def test_judge_escalate(self, runtime, node_spec, memory):
        """Mock judge ESCALATE -> failure."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Attempt")])

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(
            return_value=JudgeVerdict(action="ESCALATE", feedback="Tone violation")
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is False
        assert "escalated" in result.error.lower()
        assert "Tone violation" in result.error

    @pytest.mark.asyncio
    async def test_judge_retry_then_accept(self, runtime, node_spec, memory):
        """RETRY twice, then ACCEPT. Should run 3 iterations."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("attempt 1"),
                text_scenario("attempt 2"),
                text_scenario("attempt 3"),
            ]
        )

        call_count = 0

        async def evaluate_fn(context):
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                return JudgeVerdict(action="RETRY", feedback="Try harder")
            return JudgeVerdict(action="ACCEPT")

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(side_effect=evaluate_fn)

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=10))
        result = await node.execute(ctx)

        assert result.success is True
        assert call_count == 3
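
# Illustrative sketch (an assumption inferred from the mocks above, not the
# framework's canonical API): a judge is anything with an async evaluate()
# returning a JudgeVerdict whose action is "ACCEPT", "RETRY", or "ESCALATE".
#
#     class ToneJudge:
#         async def evaluate(self, context) -> JudgeVerdict:
#             if "thanks" not in str(context).lower():
#                 return JudgeVerdict(action="RETRY", feedback="Be polite")
#             return JudgeVerdict(action="ACCEPT")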


# ===========================================================================
# set_output tool
# ===========================================================================


class TestSetOutput:
    @pytest.mark.asyncio
    async def test_set_output_accumulates(self, runtime, node_spec, memory):
        """LLM calls set_output -> values appear in NodeResult.output."""
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call set_output
                tool_call_scenario("set_output", {"key": "result", "value": "42"}),
                # Turn 2: text response (triggers implicit judge)
                text_scenario("Done, result is 42"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "42"

    @pytest.mark.asyncio
    async def test_set_output_rejects_invalid_key(self, runtime, node_spec, memory):
        """set_output with key not in output_keys -> is_error=True."""
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call set_output with bad key
                tool_call_scenario("set_output", {"key": "bad_key", "value": "x"}),
                # Turn 2: call set_output with good key
                tool_call_scenario("set_output", {"key": "result", "value": "ok"}),
                # Turn 3: text done
                text_scenario("Done"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "ok"
        assert "bad_key" not in result.output

    @pytest.mark.asyncio
    async def test_missing_keys_triggers_retry(self, runtime, node_spec, memory):
        """Judge accepts but output keys are missing -> retry with hint."""
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))

        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: text without set_output -> judge accepts but keys missing -> retry
                text_scenario("I'll get to it"),
                # Turn 2: set_output
                tool_call_scenario("set_output", {"key": "result", "value": "done"}),
                # Turn 3: text -> judge accepts, keys present -> success
                text_scenario("All done"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "done"


# ===========================================================================
# Stall detection
# ===========================================================================


class TestStallDetection:
    @pytest.mark.asyncio
    async def test_stall_detection(self, runtime, node_spec, memory):
        """3 identical responses should trigger stall detection."""
        node_spec.output_keys = []  # so implicit judge would accept
        # But we need the judge to RETRY so we actually get 3 identical responses
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY"))

        llm = MockStreamingLLM(scenarios=[text_scenario("same answer")])

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            judge=judge,
            config=LoopConfig(max_iterations=10, stall_detection_threshold=3),
        )
        result = await node.execute(ctx)

        assert result.success is False
        assert "stalled" in result.error.lower()


# ===========================================================================
# EventBus lifecycle events
# ===========================================================================


class TestEventBusLifecycle:
    @pytest.mark.asyncio
    async def test_lifecycle_events_published(self, runtime, node_spec, memory):
        """NODE_LOOP_STARTED, NODE_LOOP_ITERATION, NODE_LOOP_COMPLETED should be published."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("ok")])
        bus = EventBus()

        received_events = []
        bus.subscribe(
            event_types=[
                EventType.NODE_LOOP_STARTED,
                EventType.NODE_LOOP_ITERATION,
                EventType.NODE_LOOP_COMPLETED,
            ],
            handler=lambda e: received_events.append(e.type),
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert EventType.NODE_LOOP_STARTED in received_events
        assert EventType.NODE_LOOP_ITERATION in received_events
        assert EventType.NODE_LOOP_COMPLETED in received_events

    @pytest.mark.asyncio
    async def test_client_facing_uses_client_output_delta(self, runtime, memory):
        """client_facing=True should emit CLIENT_OUTPUT_DELTA instead of LLM_TEXT_DELTA."""
        spec = NodeSpec(
            id="ui_node",
            name="UI Node",
            description="Streams to user",
            node_type="event_loop",
            output_keys=[],
            client_facing=True,
        )
        llm = MockStreamingLLM(scenarios=[text_scenario("visible to user")])
        bus = EventBus()

        received_types = []
        bus.subscribe(
            event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.LLM_TEXT_DELTA],
            handler=lambda e: received_types.append(e.type),
        )

        ctx = build_ctx(runtime, spec, memory, llm)
        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))

        # client_facing + text-only blocks for user input; use shutdown to unblock
        async def auto_shutdown():
            await asyncio.sleep(0.05)
            node.signal_shutdown()

        task = asyncio.create_task(auto_shutdown())
        await node.execute(ctx)
        await task

        assert EventType.CLIENT_OUTPUT_DELTA in received_types
        assert EventType.LLM_TEXT_DELTA not in received_types


# ===========================================================================
# Client-facing blocking
# ===========================================================================


class TestClientFacingBlocking:
    """Tests for native client_facing input blocking in EventLoopNode."""

    @pytest.fixture
    def client_spec(self):
        return NodeSpec(
            id="chat",
            name="Chat",
            description="chat node",
            node_type="event_loop",
            output_keys=[],
            client_facing=True,
        )

    @pytest.mark.asyncio
    async def test_client_facing_blocks_on_text(self, runtime, memory, client_spec):
        """client_facing + text-only response blocks until inject_event."""
        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("Hello!"),
                text_scenario("Got your message."),
            ]
        )
        bus = EventBus()
        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        ctx = build_ctx(runtime, client_spec, memory, llm)

        async def user_responds():
            await asyncio.sleep(0.05)
            await node.inject_event("I need help")
            await asyncio.sleep(0.05)
            node.signal_shutdown()

        user_task = asyncio.create_task(user_responds())
        result = await node.execute(ctx)
        await user_task

        assert result.success is True
        # LLM should have been called at least twice (first response + after inject)
        assert llm._call_index >= 2

    @pytest.mark.asyncio
    async def test_client_facing_does_not_block_on_tools(self, runtime, memory):
        """client_facing + tool calls should NOT block — judge evaluates normally."""
        spec = NodeSpec(
            id="chat",
            name="Chat",
            description="chat node",
            node_type="event_loop",
            output_keys=["result"],
            client_facing=True,
        )
        # Scenario 1: LLM calls set_output (tool call present, so no blocking; judge RETRYs).
        # Scenario 2: LLM produces text, but a text-only turn on a client_facing
        # node would block, so a shutdown task handles that case below.
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("set_output", {"key": "result", "value": "done"}),
                text_scenario("All set!"),
            ]
        )
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        ctx = build_ctx(runtime, spec, memory, llm)

        # After set_output, the implicit judge RETRYs (tool calls present). The
        # next turn is text-only, and the client_facing check runs BEFORE the
        # judge, so it blocks even though the output key is already set.
        # Use shutdown as a safety net.
        async def auto_shutdown():
            await asyncio.sleep(0.1)
            node.signal_shutdown()

        task = asyncio.create_task(auto_shutdown())
        result = await node.execute(ctx)
        await task

        assert result.success is True
        assert result.output["result"] == "done"

    @pytest.mark.asyncio
    async def test_non_client_facing_unchanged(self, runtime, memory):
        """client_facing=False should not block — existing behavior."""
        spec = NodeSpec(
            id="internal",
            name="Internal",
            description="internal node",
            node_type="event_loop",
            output_keys=[],
        )
        llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
        node = EventLoopNode(config=LoopConfig(max_iterations=2))
        ctx = build_ctx(runtime, spec, memory, llm)

        # Should complete without blocking (implicit judge ACCEPTs on no tools + no keys)
        result = await node.execute(ctx)
        assert result is not None

    @pytest.mark.asyncio
    async def test_signal_shutdown_unblocks(self, runtime, memory, client_spec):
        """signal_shutdown should unblock a waiting client_facing node."""
        llm = MockStreamingLLM(scenarios=[text_scenario("Waiting...")])
        bus = EventBus()
        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=10))
        ctx = build_ctx(runtime, client_spec, memory, llm)

        async def shutdown_after_delay():
            await asyncio.sleep(0.05)
            node.signal_shutdown()

        task = asyncio.create_task(shutdown_after_delay())
        result = await node.execute(ctx)
        await task

        assert result.success is True

    @pytest.mark.asyncio
    async def test_client_input_requested_event_published(self, runtime, memory, client_spec):
        """CLIENT_INPUT_REQUESTED should be published when blocking."""
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello!")])
        bus = EventBus()
        received = []

        async def capture(e):
            received.append(e)

        bus.subscribe(
            event_types=[EventType.CLIENT_INPUT_REQUESTED],
            handler=capture,
        )

        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        ctx = build_ctx(runtime, client_spec, memory, llm)

        async def shutdown():
            await asyncio.sleep(0.05)
            node.signal_shutdown()

        task = asyncio.create_task(shutdown())
        await node.execute(ctx)
        await task

        assert len(received) >= 1
        assert received[0].type == EventType.CLIENT_INPUT_REQUESTED
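
# Summary sketch of the blocking contract exercised above (inferred from these
# tests, not a verbatim excerpt of EventLoopNode; names are hypothetical):
# on a client_facing node, a text-only turn parks the loop before the judge
# ever runs, and inject_event() or signal_shutdown() is what unparks it.
#
#     if spec.client_facing and not turn_tool_calls:   # hypothetical names
#         await wait_for_injected_event_or_shutdown()  # blocks here
#     verdict = await judge.evaluate(...)              # only reached after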


# ===========================================================================
# Tool execution
# ===========================================================================


class TestToolExecution:
    @pytest.mark.asyncio
    async def test_tool_execution_feedback(self, runtime, node_spec, memory):
        """Tool call -> result fed back to conversation via stream loop."""
        node_spec.output_keys = []

        def my_tool_executor(tool_use: ToolUse) -> ToolResult:
            return ToolResult(
                tool_use_id=tool_use.id,
                content=f"Result for {tool_use.name}",
                is_error=False,
            )

        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call a tool
                tool_call_scenario("search", {"query": "test"}, tool_use_id="call_search"),
                # Turn 2: text response after seeing tool result
                text_scenario("Found the answer"),
            ]
        )

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            tools=[Tool(name="search", description="Search", parameters={})],
        )
        node = EventLoopNode(
            tool_executor=my_tool_executor,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True
        # stream() should have been called twice (tool call turn + final text turn)
        assert llm._call_index >= 2


# ===========================================================================
# Write-through persistence with real FileConversationStore
# ===========================================================================


class TestWriteThroughPersistence:
    @pytest.mark.asyncio
    async def test_messages_written_to_store(self, tmp_path, runtime, node_spec, memory):
        """Messages should be persisted immediately via write-through."""
        store = FileConversationStore(tmp_path / "conv")
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello")])

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True

        # Verify parts were written to disk
        parts = await store.read_parts()
        assert len(parts) >= 2  # at least initial user msg + assistant msg

    @pytest.mark.asyncio
    async def test_output_accumulator_write_through(self, tmp_path, runtime, node_spec, memory):
        """set_output values should be persisted in cursor immediately."""
        store = FileConversationStore(tmp_path / "conv")
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("set_output", {"key": "result", "value": "persisted_value"}),
                text_scenario("Done"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "persisted_value"

        # Verify output was written to cursor on disk
        cursor = await store.read_cursor()
        assert cursor is not None
        assert cursor["outputs"]["result"] == "persisted_value"


# ===========================================================================
# Crash recovery (restore from real FileConversationStore)
# ===========================================================================


class TestCrashRecovery:
    @pytest.mark.asyncio
    async def test_restore_from_checkpoint(self, tmp_path, runtime, node_spec, memory):
        """Populate a store with state, then verify EventLoopNode restores from it."""
        store = FileConversationStore(tmp_path / "conv")

        # Simulate a previous run that wrote conversation + cursor
        conv = NodeConversation(
            system_prompt="You are a test assistant.",
            output_keys=["result"],
            store=store,
        )
        await conv.add_user_message("Initial input")
        await conv.add_assistant_message("Working on it...")

        # Write cursor with iteration and outputs
        await store.write_cursor(
            {
                "iteration": 1,
                "next_seq": conv.next_seq,
                "outputs": {"result": "partial_value"},
            }
        )

        # Now create a new EventLoopNode and execute -- it should restore
        node_spec.output_keys = []  # no required keys so implicit accept works
        llm = MockStreamingLLM(scenarios=[text_scenario("Continuing...")])

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True
        # Should have the restored output
        assert result.output.get("result") == "partial_value"


# ===========================================================================
# External event injection
# ===========================================================================


class TestEventInjection:
    @pytest.mark.asyncio
    async def test_inject_event(self, runtime, node_spec, memory):
        """inject_event() content should appear as user message in next iteration."""
        node_spec.output_keys = []

        judge_calls = []

        async def evaluate_fn(context):
            judge_calls.append(context)
            if len(judge_calls) >= 2:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY")

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(side_effect=evaluate_fn)

        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("iteration 1"),
                text_scenario("iteration 2"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            judge=judge,
            config=LoopConfig(max_iterations=5),
        )

        # Pre-inject an event before execute runs
        await node.inject_event("Priority: CEO wants meeting rescheduled")

        result = await node.execute(ctx)
        assert result.success is True

        # Verify the injected content made it into the LLM messages
        all_messages = []
        for call in llm.stream_calls:
            all_messages.extend(call["messages"])
        injected_found = any("[External event]" in str(m.get("content", "")) for m in all_messages)
        assert injected_found


# ===========================================================================
# Pause/resume
# ===========================================================================


class TestPauseResume:
    @pytest.mark.asyncio
    async def test_pause_returns_early(self, runtime, node_spec, memory):
        """pause_requested in input_data should trigger early return."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("should not run")])

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            input_data={"pause_requested": True},
        )
        node = EventLoopNode(config=LoopConfig(max_iterations=10))
        result = await node.execute(ctx)

        # Should return success (paused, not failed)
        assert result.success is True
        # LLM should not have been called (paused before first turn)
        assert llm._call_index == 0


# ===========================================================================
# Stream errors
# ===========================================================================


class TestStreamErrors:
    @pytest.mark.asyncio
    async def test_non_recoverable_stream_error_raises(self, runtime, node_spec, memory):
        """Non-recoverable StreamErrorEvent should raise RuntimeError."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(
            scenarios=[
                [StreamErrorEvent(error="Connection lost", recoverable=False)],
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))

        with pytest.raises(RuntimeError, match="Stream error"):
            await node.execute(ctx)


# ===========================================================================
# OutputAccumulator unit tests
# ===========================================================================


class TestOutputAccumulator:
    @pytest.mark.asyncio
    async def test_set_and_get(self):
        acc = OutputAccumulator()
        await acc.set("key1", "value1")
        assert acc.get("key1") == "value1"
        assert acc.get("nonexistent") is None

    @pytest.mark.asyncio
    async def test_to_dict(self):
        acc = OutputAccumulator()
        await acc.set("a", 1)
        await acc.set("b", 2)
        assert acc.to_dict() == {"a": 1, "b": 2}

    @pytest.mark.asyncio
    async def test_has_all_keys(self):
        acc = OutputAccumulator()
        assert acc.has_all_keys([]) is True
        assert acc.has_all_keys(["x"]) is False
        await acc.set("x", "val")
        assert acc.has_all_keys(["x"]) is True

    @pytest.mark.asyncio
    async def test_write_through_to_real_store(self, tmp_path):
        """OutputAccumulator should write through to FileConversationStore cursor."""
        store = FileConversationStore(tmp_path / "acc_test")
        acc = OutputAccumulator(store=store)

        await acc.set("result", "hello")

        cursor = await store.read_cursor()
        assert cursor["outputs"]["result"] == "hello"

    @pytest.mark.asyncio
    async def test_restore_from_real_store(self, tmp_path):
        """OutputAccumulator.restore() should rebuild from FileConversationStore."""
        store = FileConversationStore(tmp_path / "acc_restore")
        await store.write_cursor({"outputs": {"key1": "val1", "key2": "val2"}})

        acc = await OutputAccumulator.restore(store)
        assert acc.get("key1") == "val1"
        assert acc.get("key2") == "val2"
        assert acc.has_all_keys(["key1", "key2"]) is True
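
# Illustrative sketch of the write-through pattern these tests pin down
# (assumption: set() persists to the store cursor whenever one is attached):
#
#     acc = OutputAccumulator(store=FileConversationStore(path))
#     await acc.set("result", "v")                   # cursor written immediately
#     acc2 = await OutputAccumulator.restore(store)  # state survives a restart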
@@ -0,0 +1,265 @@
"""
Tests for event_loop node type wiring (Issue #2513).

Covers:
- NodeSpec.client_facing field
- event_loop in VALID_NODE_TYPES
- _get_node_implementation() event_loop branch
- no-retry enforcement in serial execution path
"""

from unittest.mock import AsyncMock, MagicMock

import pytest

from framework.graph.edge import GraphSpec
from framework.graph.executor import GraphExecutor
from framework.graph.goal import Goal
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
from framework.runtime.core import Runtime


class AlwaysFailsNode(NodeProtocol):
    """A test node that always fails."""

    def __init__(self):
        self.attempt_count = 0

    async def execute(self, ctx: NodeContext) -> NodeResult:
        self.attempt_count += 1
        return NodeResult(success=False, error=f"Permanent error (attempt {self.attempt_count})")


class SucceedsOnceNode(NodeProtocol):
    """A test node that always succeeds."""

    async def execute(self, ctx: NodeContext) -> NodeResult:
        return NodeResult(success=True, output={"result": "ok"})


@pytest.fixture(autouse=True)
def fast_sleep(monkeypatch):
    """Mock asyncio.sleep to avoid real delays from exponential backoff."""
    monkeypatch.setattr("asyncio.sleep", AsyncMock())


@pytest.fixture
def runtime():
    """Create a mock Runtime for testing."""
    runtime = MagicMock(spec=Runtime)
    runtime.start_run = MagicMock(return_value="test_run_id")
    runtime.decide = MagicMock(return_value="test_decision_id")
    runtime.record_outcome = MagicMock()
    runtime.end_run = MagicMock()
    runtime.report_problem = MagicMock()
    runtime.set_node = MagicMock()
    return runtime


# --- NodeSpec.client_facing tests ---


def test_client_facing_defaults_false():
    """NodeSpec without client_facing should default to False."""
    spec = NodeSpec(
        id="n1",
        name="Node 1",
        description="test",
        node_type="llm_generate",
    )
    assert spec.client_facing is False


def test_client_facing_explicit_true():
    """NodeSpec with client_facing=True should retain the value."""
    spec = NodeSpec(
        id="n1",
        name="Node 1",
        description="test",
        node_type="event_loop",
        client_facing=True,
    )
    assert spec.client_facing is True


# --- VALID_NODE_TYPES tests ---


def test_event_loop_in_valid_node_types():
    """'event_loop' must be in GraphExecutor.VALID_NODE_TYPES."""
    assert "event_loop" in GraphExecutor.VALID_NODE_TYPES


def test_event_loop_node_spec_accepted():
    """Creating a NodeSpec with node_type='event_loop' should not raise."""
    spec = NodeSpec(
        id="el1",
        name="Event Loop",
        description="test",
        node_type="event_loop",
    )
    assert spec.node_type == "event_loop"


# --- _get_node_implementation() tests ---


def test_unregistered_event_loop_raises(runtime):
    """An event_loop node not in the registry should raise RuntimeError."""
    spec = NodeSpec(
        id="el1",
        name="Event Loop",
        description="test",
        node_type="event_loop",
    )
    executor = GraphExecutor(runtime=runtime)

    with pytest.raises(RuntimeError, match="not found in registry"):
        executor._get_node_implementation(spec)


def test_registered_event_loop_returns_impl(runtime):
    """A registered event_loop node should be returned from the registry."""
    spec = NodeSpec(
        id="el1",
        name="Event Loop",
        description="test",
        node_type="event_loop",
    )
    impl = SucceedsOnceNode()
    executor = GraphExecutor(runtime=runtime)
    executor.register_node("el1", impl)

    result = executor._get_node_implementation(spec)
    assert result is impl


# --- No-retry enforcement (serial path) ---


@pytest.mark.asyncio
async def test_event_loop_max_retries_forced_zero(runtime):
    """An event_loop node with max_retries=3 should only execute once (no retry)."""
    node_spec = NodeSpec(
        id="el_fail",
        name="Failing Event Loop",
        description="event loop that fails",
        node_type="event_loop",
        max_retries=3,
        output_keys=["result"],
    )

    graph = GraphSpec(
        id="test_graph",
        goal_id="test_goal",
        name="Test Graph",
        entry_node="el_fail",
        nodes=[node_spec],
        edges=[],
        terminal_nodes=["el_fail"],
    )

    goal = Goal(id="test_goal", name="Test", description="test")

    executor = GraphExecutor(runtime=runtime)
    failing_node = AlwaysFailsNode()
    executor.register_node("el_fail", failing_node)

    result = await executor.execute(graph, goal, {})

    # Event loop nodes get max_retries overridden to 0, meaning execute once then fail
    assert not result.success
    assert failing_node.attempt_count == 1


@pytest.mark.asyncio
async def test_event_loop_max_retries_zero_no_warning(runtime, caplog):
    """An event_loop node with max_retries=0 should not log a warning."""
    node_spec = NodeSpec(
        id="el_zero",
        name="Zero Retry Event Loop",
        description="event loop with 0 retries",
        node_type="event_loop",
        max_retries=0,
        output_keys=["result"],
    )

    graph = GraphSpec(
        id="test_graph",
        goal_id="test_goal",
        name="Test Graph",
        entry_node="el_zero",
        nodes=[node_spec],
        edges=[],
        terminal_nodes=["el_zero"],
    )

    goal = Goal(id="test_goal", name="Test", description="test")

    executor = GraphExecutor(runtime=runtime)
    failing_node = AlwaysFailsNode()
    executor.register_node("el_zero", failing_node)

    import logging

    with caplog.at_level(logging.WARNING):
        await executor.execute(graph, goal, {})

    # max_retries=0 should not trigger the override warning
    assert "Overriding to 0" not in caplog.text


@pytest.mark.asyncio
async def test_event_loop_max_retries_positive_logs_warning(runtime, caplog):
    """An event_loop node with max_retries=3 should log a warning about override."""
    node_spec = NodeSpec(
        id="el_warn",
        name="Warning Event Loop",
        description="event loop with retries",
        node_type="event_loop",
        max_retries=3,
        output_keys=["result"],
    )

    graph = GraphSpec(
        id="test_graph",
        goal_id="test_goal",
        name="Test Graph",
        entry_node="el_warn",
        nodes=[node_spec],
        edges=[],
        terminal_nodes=["el_warn"],
    )

    goal = Goal(id="test_goal", name="Test", description="test")

    executor = GraphExecutor(runtime=runtime)
    failing_node = AlwaysFailsNode()
    executor.register_node("el_warn", failing_node)

    import logging

    with caplog.at_level(logging.WARNING):
        await executor.execute(graph, goal, {})

    assert "Overriding to 0" in caplog.text
    assert "el_warn" in caplog.text


# --- Existing node types unaffected ---


def test_existing_node_types_unchanged():
    """All pre-existing node types must still be in VALID_NODE_TYPES with defaults preserved."""
    expected = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
    assert expected.issubset(GraphExecutor.VALID_NODE_TYPES)

    # Default node_type is still llm_tool_use
    spec = NodeSpec(id="x", name="X", description="x")
    assert spec.node_type == "llm_tool_use"

    # Default max_retries is still 3
    assert spec.max_retries == 3

    # Default client_facing is False
    assert spec.client_facing is False
@@ -0,0 +1,978 @@
"""Tests for extending the stream event type system.

Validates that the StreamEvent discriminated union pattern supports:
- Type-based dispatch (matching on event.type)
- Pattern matching / isinstance branching
- Custom event subclasses following the same frozen-dataclass convention
- Serialization of mixed event sequences

WP-2 tests validate EventType enum extension and node-level event routing:
- All 12 new EventType enum members with correct string values
- node_id routing on AgentEvent
- filter_node on Subscription
- Backward compatibility with existing enum members
"""

import asyncio
from dataclasses import FrozenInstanceError, asdict, dataclass, field
from typing import Any, Literal

import pytest

from framework.llm.stream_events import (
    FinishEvent,
    ReasoningDeltaEvent,
    ReasoningStartEvent,
    StreamErrorEvent,
    TextDeltaEvent,
    TextEndEvent,
    ToolCallEvent,
    ToolResultEvent,
)
from framework.runtime.event_bus import AgentEvent, EventBus, EventType, Subscription


# ---------------------------------------------------------------------------
# Helpers: type-based dispatch
# ---------------------------------------------------------------------------
def dispatch_event(event) -> str:
    """Dispatch an event by its type field, returning a label."""
    handlers = {
        "text_delta": lambda e: f"text:{e.content}",
        "text_end": lambda e: f"end:{len(e.full_text)}chars",
        "tool_call": lambda e: f"call:{e.tool_name}",
        "tool_result": lambda e: f"result:{e.tool_use_id}",
        "reasoning_start": lambda _: "reasoning:start",
        "reasoning_delta": lambda e: f"reasoning:{e.content[:20]}",
        "finish": lambda e: f"finish:{e.stop_reason}",
        "error": lambda e: f"error:{e.error}",
    }
    handler = handlers.get(event.type)
    if handler is None:
        return f"unknown:{event.type}"
    return handler(event)
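

# Example outputs of the dispatcher above (labels follow directly from the
# handler table; unrecognized types fall through to the "unknown:" branch):
#
#     dispatch_event(TextDeltaEvent(content="hi"))         # -> "text:hi"
#     dispatch_event(FinishEvent(stop_reason="end_turn"))  # -> "finish:end_turn"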


def collect_text(events: list) -> str:
    """Return the final accumulated text from a stream of events."""
    for event in reversed(events):
        if isinstance(event, TextEndEvent):
            return event.full_text
        if isinstance(event, TextDeltaEvent):
            return event.snapshot
    return ""


def extract_tool_calls(events: list) -> list[dict[str, Any]]:
    """Extract tool call info from a stream of events."""
    return [
        {"id": e.tool_use_id, "name": e.tool_name, "input": e.tool_input}
        for e in events
        if isinstance(e, ToolCallEvent)
    ]


# ---------------------------------------------------------------------------
# Type-based dispatch tests
# ---------------------------------------------------------------------------
class TestTypeDispatch:
    """Dispatch on event.type string for handler routing."""

    def test_dispatch_text_delta(self):
        e = TextDeltaEvent(content="hello")
        assert dispatch_event(e) == "text:hello"

    def test_dispatch_text_end(self):
        e = TextEndEvent(full_text="hello world")
        assert dispatch_event(e) == "end:11chars"

    def test_dispatch_tool_call(self):
        e = ToolCallEvent(tool_name="web_search")
        assert dispatch_event(e) == "call:web_search"

    def test_dispatch_tool_result(self):
        e = ToolResultEvent(tool_use_id="abc")
        assert dispatch_event(e) == "result:abc"

    def test_dispatch_reasoning_start(self):
        e = ReasoningStartEvent()
        assert dispatch_event(e) == "reasoning:start"

    def test_dispatch_reasoning_delta(self):
        e = ReasoningDeltaEvent(content="Let me think step by step")
        assert dispatch_event(e) == "reasoning:Let me think step by"

    def test_dispatch_finish(self):
        e = FinishEvent(stop_reason="end_turn")
        assert dispatch_event(e) == "finish:end_turn"

    def test_dispatch_error(self):
        e = StreamErrorEvent(error="timeout")
        assert dispatch_event(e) == "error:timeout"


# ---------------------------------------------------------------------------
# isinstance-based filtering
# ---------------------------------------------------------------------------
class TestInstanceFiltering:
    """Filter event streams using isinstance for each event type."""

    @pytest.fixture
    def text_stream(self) -> list:
        """Simulate a text-only stream."""
        return [
            TextDeltaEvent(content="Hello", snapshot="Hello"),
            TextDeltaEvent(content=" world", snapshot="Hello world"),
            TextDeltaEvent(content="!", snapshot="Hello world!"),
            TextEndEvent(full_text="Hello world!"),
            FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=3, model="test"),
        ]

    @pytest.fixture
    def tool_stream(self) -> list:
        """Simulate a tool call stream."""
        return [
            ToolCallEvent(
                tool_use_id="call_1",
                tool_name="get_weather",
                tool_input={"city": "London"},
            ),
            ToolCallEvent(
                tool_use_id="call_2",
                tool_name="calculator",
                tool_input={"expression": "2+2"},
            ),
            FinishEvent(stop_reason="tool_calls"),
        ]

    @pytest.fixture
    def reasoning_stream(self) -> list:
        """Simulate a stream with reasoning blocks."""
        return [
            ReasoningStartEvent(),
            ReasoningDeltaEvent(content="Let me analyze this..."),
            ReasoningDeltaEvent(content="The answer is 42."),
            TextDeltaEvent(content="The answer is 42.", snapshot="The answer is 42."),
            TextEndEvent(full_text="The answer is 42."),
            FinishEvent(stop_reason="end_turn"),
        ]

    def test_collect_text(self, text_stream):
        assert collect_text(text_stream) == "Hello world!"

    def test_collect_text_from_tool_stream(self, tool_stream):
        assert collect_text(tool_stream) == ""

    def test_extract_tool_calls(self, tool_stream):
        calls = extract_tool_calls(tool_stream)
        assert len(calls) == 2
        assert calls[0]["name"] == "get_weather"
        assert calls[1]["name"] == "calculator"

    def test_extract_tool_calls_from_text_stream(self, text_stream):
        assert extract_tool_calls(text_stream) == []

    def test_filter_text_deltas(self, text_stream):
        deltas = [e for e in text_stream if isinstance(e, TextDeltaEvent)]
        assert len(deltas) == 3

    def test_filter_finish(self, text_stream):
        finishes = [e for e in text_stream if isinstance(e, FinishEvent)]
        assert len(finishes) == 1
        assert finishes[0].stop_reason == "stop"

    def test_reasoning_then_text(self, reasoning_stream):
        reasoning = [e for e in reasoning_stream if isinstance(e, ReasoningDeltaEvent)]
        text = collect_text(reasoning_stream)
        assert len(reasoning) == 2
        assert text == "The answer is 42."

    def test_mixed_stream_type_counts(self, reasoning_stream):
        type_counts = {}
        for e in reasoning_stream:
            type_counts[e.type] = type_counts.get(e.type, 0) + 1
        assert type_counts == {
            "reasoning_start": 1,
            "reasoning_delta": 2,
            "text_delta": 1,
            "text_end": 1,
            "finish": 1,
        }


# ---------------------------------------------------------------------------
# Custom event extension pattern
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class CustomMetricsEvent:
    """Example custom event following the same pattern."""

    type: Literal["custom_metrics"] = "custom_metrics"
    latency_ms: float = 0.0
    tokens_per_second: float = 0.0
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True)
class CustomCitationEvent:
    """Example citation event extending the pattern."""

    type: Literal["citation"] = "citation"
    source_url: str = ""
    quote: str = ""
    confidence: float = 0.0


class TestCustomEventExtension:
    """Custom events should follow the same frozen-dataclass convention."""

    def test_custom_event_construction(self):
        e = CustomMetricsEvent(latency_ms=150.5, tokens_per_second=42.3)
        assert e.type == "custom_metrics"
        assert e.latency_ms == 150.5

    def test_custom_event_frozen(self):
        e = CustomMetricsEvent()
        with pytest.raises(FrozenInstanceError):
            e.type = "modified"

    def test_custom_event_serialization(self):
        e = CustomMetricsEvent(
            latency_ms=100.0,
            tokens_per_second=50.0,
            metadata={"provider": "anthropic"},
        )
        d = asdict(e)
        assert d["type"] == "custom_metrics"
        assert d["metadata"] == {"provider": "anthropic"}

    def test_custom_event_dispatch(self):
        """Custom events can extend the dispatch map."""
        e = CustomMetricsEvent(latency_ms=200.0)
        # Falls through to "unknown" in our dispatch_event
        assert dispatch_event(e) == "unknown:custom_metrics"

    def test_custom_event_in_mixed_stream(self):
        """Custom events can coexist with standard events in a list."""
        stream = [
            TextDeltaEvent(content="hi", snapshot="hi"),
            CustomMetricsEvent(latency_ms=50.0),
            TextEndEvent(full_text="hi"),
            CustomCitationEvent(source_url="https://example.com", quote="hi"),
            FinishEvent(stop_reason="stop"),
        ]
        standard_types = {
            "text_delta",
            "text_end",
            "tool_call",
            "tool_result",
            "reasoning_start",
            "reasoning_delta",
            "finish",
            "error",
        }
        standard = [e for e in stream if hasattr(e, "type") and e.type in standard_types]
        custom = [e for e in stream if e.type not in standard_types]
        assert len(standard) == 3
        assert len(custom) == 2


# ---------------------------------------------------------------------------
# Serialization of full event sequences
# ---------------------------------------------------------------------------
class TestSequenceSerialization:
    """Serialize entire event sequences, as done by the dump tests."""

    def test_serialize_text_sequence(self):
        events = [
            TextDeltaEvent(content="Hello", snapshot="Hello"),
            TextDeltaEvent(content=" world", snapshot="Hello world"),
            TextEndEvent(full_text="Hello world"),
            FinishEvent(stop_reason="stop", model="test-model"),
        ]
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert len(serialized) == 4
        assert serialized[0]["index"] == 0
        assert serialized[0]["type"] == "text_delta"
        assert serialized[-1]["type"] == "finish"
        assert serialized[-1]["model"] == "test-model"

    def test_serialize_tool_sequence(self):
        events = [
            ToolCallEvent(
                tool_use_id="call_1",
                tool_name="search",
                tool_input={"query": "test"},
            ),
            FinishEvent(stop_reason="tool_calls"),
        ]
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert serialized[0]["tool_input"] == {"query": "test"}
        assert serialized[1]["stop_reason"] == "tool_calls"

    def test_serialize_error_sequence(self):
        events = [
            TextDeltaEvent(content="partial"),
            StreamErrorEvent(error="connection reset", recoverable=True),
            FinishEvent(stop_reason="error"),
        ]
        serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
        assert serialized[1]["type"] == "error"
        assert serialized[1]["recoverable"] is True

    def test_roundtrip_snapshot_accumulation(self):
        """Verify snapshot grows monotonically through serialization."""
        chunks = ["Hello", " beautiful", " world", "!"]
        events = []
        snapshot = ""
        for chunk in chunks:
            snapshot += chunk
            events.append(TextDeltaEvent(content=chunk, snapshot=snapshot))

        serialized = [asdict(e) for e in events]
        for i in range(1, len(serialized)):
            assert len(serialized[i]["snapshot"]) > len(serialized[i - 1]["snapshot"])
        assert serialized[-1]["snapshot"] == "Hello beautiful world!"


# ===========================================================================
# WP-2: EventType Enum Extension + Node-Level Event Routing
# ===========================================================================

# The 12 new EventType members added by WP-2
WP2_EVENT_TYPES = {
    # Node event-loop lifecycle
    EventType.NODE_LOOP_STARTED: "node_loop_started",
    EventType.NODE_LOOP_ITERATION: "node_loop_iteration",
    EventType.NODE_LOOP_COMPLETED: "node_loop_completed",
    # LLM streaming observability
    EventType.LLM_TEXT_DELTA: "llm_text_delta",
    EventType.LLM_REASONING_DELTA: "llm_reasoning_delta",
    # Tool lifecycle
    EventType.TOOL_CALL_STARTED: "tool_call_started",
    EventType.TOOL_CALL_COMPLETED: "tool_call_completed",
    # Client I/O
    EventType.CLIENT_OUTPUT_DELTA: "client_output_delta",
    EventType.CLIENT_INPUT_REQUESTED: "client_input_requested",
    # Internal node observability
    EventType.NODE_INTERNAL_OUTPUT: "node_internal_output",
    EventType.NODE_INPUT_BLOCKED: "node_input_blocked",
    EventType.NODE_STALLED: "node_stalled",
}

# Pre-existing enum members that must remain unchanged
ORIGINAL_EVENT_TYPES = {
    EventType.EXECUTION_STARTED: "execution_started",
    EventType.EXECUTION_COMPLETED: "execution_completed",
    EventType.EXECUTION_FAILED: "execution_failed",
    EventType.EXECUTION_PAUSED: "execution_paused",
    EventType.EXECUTION_RESUMED: "execution_resumed",
    EventType.STATE_CHANGED: "state_changed",
    EventType.STATE_CONFLICT: "state_conflict",
    EventType.GOAL_PROGRESS: "goal_progress",
    EventType.GOAL_ACHIEVED: "goal_achieved",
    EventType.CONSTRAINT_VIOLATION: "constraint_violation",
    EventType.STREAM_STARTED: "stream_started",
    EventType.STREAM_STOPPED: "stream_stopped",
    EventType.CUSTOM: "custom",
}
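
# Illustrative sketch of node-level routing (mirrors the integration tests
# further below): a subscriber with filter_node only sees events whose
# node_id matches; subscribers without it see events from every node.
#
#     bus = EventBus()
#     bus.subscribe(
#         event_types=[EventType.LLM_TEXT_DELTA],
#         handler=handler,
#         filter_node="node-A",
#     )
#     await bus.publish(AgentEvent(type=EventType.LLM_TEXT_DELTA,
#                                  stream_id="s", node_id="node-B"))  # filtered out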


# ---------------------------------------------------------------------------
# WP-2 Part A: EventType enum members
# ---------------------------------------------------------------------------
class TestWP2EventTypeEnumMembers:
    """All 12 new EventType members exist with correct string values."""

    @pytest.mark.parametrize(
        "member,expected_value",
        WP2_EVENT_TYPES.items(),
        ids=lambda x: x.name if isinstance(x, EventType) else x,
    )
    def test_new_member_value(self, member, expected_value):
        assert member.value == expected_value

    def test_all_12_new_members_exist(self):
        assert len(WP2_EVENT_TYPES) == 12

    def test_new_member_string_values_are_unique(self):
        values = list(WP2_EVENT_TYPES.values())
        assert len(values) == len(set(values))

    def test_no_collision_with_original_members(self):
        new_values = set(WP2_EVENT_TYPES.values())
        old_values = set(ORIGINAL_EVENT_TYPES.values())
        overlap = new_values & old_values
        assert overlap == set(), f"Colliding values: {overlap}"

    @pytest.mark.parametrize(
        "member,expected_value",
        ORIGINAL_EVENT_TYPES.items(),
        ids=lambda x: x.name if isinstance(x, EventType) else x,
    )
    def test_original_members_unchanged(self, member, expected_value):
        assert member.value == expected_value

    def test_event_type_is_str_enum(self):
        """EventType members compare equal to their string values."""
        assert EventType.NODE_LOOP_STARTED == "node_loop_started"
        assert EventType.LLM_TEXT_DELTA == "llm_text_delta"
        assert EventType.LLM_TEXT_DELTA.value == "llm_text_delta"

    def test_event_type_accessible_by_name(self):
        assert EventType["NODE_LOOP_STARTED"] is EventType.NODE_LOOP_STARTED
        assert EventType["TOOL_CALL_COMPLETED"] is EventType.TOOL_CALL_COMPLETED

    def test_event_type_accessible_by_value(self):
        assert EventType("node_loop_started") is EventType.NODE_LOOP_STARTED
        assert EventType("tool_call_completed") is EventType.TOOL_CALL_COMPLETED


# ---------------------------------------------------------------------------
# WP-2 Part B: AgentEvent.node_id and Subscription.filter_node
# ---------------------------------------------------------------------------
class TestWP2AgentEventNodeId:
    """AgentEvent supports node_id as a first-class field."""

    def test_node_id_defaults_to_none(self):
        event = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="stream-1",
        )
        assert event.node_id is None

    def test_node_id_can_be_set(self):
        event = AgentEvent(
            type=EventType.LLM_TEXT_DELTA,
            stream_id="stream-1",
            node_id="email_composer",
        )
        assert event.node_id == "email_composer"

    def test_node_id_in_to_dict(self):
        event = AgentEvent(
            type=EventType.TOOL_CALL_STARTED,
            stream_id="stream-1",
            node_id="search_node",
        )
        d = event.to_dict()
        assert d["node_id"] == "search_node"

    def test_node_id_none_in_to_dict(self):
        event = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="stream-1",
        )
        d = event.to_dict()
        assert "node_id" in d
        assert d["node_id"] is None


class TestWP2SubscriptionFilterNode:
    """Subscription supports filter_node for node-level routing."""

    @staticmethod
    async def _noop_handler(event: AgentEvent) -> None:
        pass

    def test_filter_node_defaults_to_none(self):
        sub = Subscription(
            id="sub_1",
            event_types={EventType.LLM_TEXT_DELTA},
            handler=self._noop_handler,
        )
        assert sub.filter_node is None

    def test_filter_node_can_be_set(self):
        sub = Subscription(
            id="sub_1",
            event_types={EventType.LLM_TEXT_DELTA},
            handler=self._noop_handler,
            filter_node="email_composer",
        )
        assert sub.filter_node == "email_composer"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# WP-2 Part B: Node-level event routing integration tests
# ---------------------------------------------------------------------------
class TestWP2NodeLevelRouting:
    """EventBus routes events by node_id using filter_node."""

    @pytest.fixture
    def bus(self):
        return EventBus()

    @pytest.mark.asyncio
    async def test_filter_node_receives_matching_events(self, bus):
        """Subscriber with filter_node='node-A' receives events from node-A."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
            filter_node="node-A",
        )

        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
                data={"content": "hello"},
            )
        )

        assert len(received) == 1
        assert received[0].node_id == "node-A"
        assert received[0].data["content"] == "hello"

    @pytest.mark.asyncio
    async def test_filter_node_rejects_non_matching_events(self, bus):
        """Subscriber with filter_node='node-B' does NOT receive node-A events."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
            filter_node="node-B",
        )

        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
                data={"content": "hello"},
            )
        )

        assert len(received) == 0

    @pytest.mark.asyncio
    async def test_no_filter_node_receives_all_events(self, bus):
        """Subscriber with no filter_node receives events from all nodes."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler,
        )

        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-A",
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id="node-B",
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="stream-1",
                node_id=None,
            )
        )

        assert len(received) == 3

    @pytest.mark.asyncio
    async def test_interleaved_nodes_separated_by_filter(self, bus):
        """Two subscribers on different nodes get only their node's events."""
        node_a_events = []
        node_b_events = []

        async def handler_a(event):
            node_a_events.append(event)

        async def handler_b(event):
            node_b_events.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler_a,
            filter_node="email_sender",
        )
        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA],
            handler=handler_b,
            filter_node="inbox_scanner",
        )

        # Interleaved events from both nodes
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="email_sender",
                data={"content": "Dear Jo"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="inbox_scanner",
                data={"content": "RE: Meeting conf"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="email_sender",
                data={"content": "hn, Thank you for"},
            )
        )
        await bus.publish(
            AgentEvent(
                type=EventType.LLM_TEXT_DELTA,
                stream_id="webhook",
                node_id="inbox_scanner",
                data={"content": "irmed for Thursday"},
            )
        )

        assert len(node_a_events) == 2
        assert len(node_b_events) == 2
        assert node_a_events[0].data["content"] == "Dear Jo"
        assert node_a_events[1].data["content"] == "hn, Thank you for"
        assert node_b_events[0].data["content"] == "RE: Meeting conf"
        assert node_b_events[1].data["content"] == "irmed for Thursday"

    @pytest.mark.asyncio
    async def test_filter_node_combined_with_filter_stream(self, bus):
        """filter_node and filter_stream work together."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.TOOL_CALL_STARTED],
            handler=handler,
            filter_stream="webhook",
            filter_node="search_node",
        )

        # Matching both filters
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="webhook",
                node_id="search_node",
            )
        )
        # Wrong stream
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="api",
                node_id="search_node",
            )
        )
        # Wrong node
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_STARTED,
                stream_id="webhook",
                node_id="other_node",
            )
        )

        assert len(received) == 1
        assert received[0].stream_id == "webhook"
        assert received[0].node_id == "search_node"

    @pytest.mark.asyncio
    async def test_wait_for_with_node_id(self, bus):
        """wait_for() accepts node_id parameter for filtering."""

        async def publish_later():
            await asyncio.sleep(0.01)
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="target_node",
                    data={"iterations": 3},
                )
            )

        task = asyncio.create_task(publish_later())
        event = await bus.wait_for(
            event_type=EventType.NODE_LOOP_COMPLETED,
            node_id="target_node",
            timeout=2.0,
        )
        await task

        assert event is not None
        assert event.node_id == "target_node"
        assert event.data["iterations"] == 3

    @pytest.mark.asyncio
    async def test_wait_for_ignores_wrong_node(self, bus):
        """wait_for() with node_id ignores events from other nodes."""

        async def publish_wrong_then_right():
            await asyncio.sleep(0.01)
            # Wrong node — should be ignored
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="wrong_node",
                )
            )
            await asyncio.sleep(0.01)
            # Right node
            await bus.publish(
                AgentEvent(
                    type=EventType.NODE_LOOP_COMPLETED,
                    stream_id="stream-1",
                    node_id="target_node",
                    data={"iterations": 5},
                )
            )

        task = asyncio.create_task(publish_wrong_then_right())
        event = await bus.wait_for(
            event_type=EventType.NODE_LOOP_COMPLETED,
            node_id="target_node",
            timeout=2.0,
        )
        await task

        assert event is not None
        assert event.node_id == "target_node"
        assert event.data["iterations"] == 5


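# ---------------------------------------------------------------------------
# Hedged sketch (editorial, not framework code): the wait_for() contract
# exercised above can be modeled as a one-shot subscription that resolves
# an asyncio future. `_wait_for_sketch` is an illustrative name; the real
# EventBus would presumably also unsubscribe the temporary handler.
# ---------------------------------------------------------------------------
async def _wait_for_sketch(bus, event_type, node_id=None, timeout=None):
    fut = asyncio.get_running_loop().create_future()

    async def _once(event):
        # Drop events from other nodes when a node_id filter is given
        if node_id is not None and event.node_id != node_id:
            return
        if not fut.done():
            fut.set_result(event)

    bus.subscribe(event_types=[event_type], handler=_once)
    return await asyncio.wait_for(fut, timeout)

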
# ---------------------------------------------------------------------------
# WP-2: Convenience publisher methods
# ---------------------------------------------------------------------------
class TestWP2ConveniencePublishers:
    """EventBus convenience methods for new WP-2 event types."""

    @pytest.fixture
    def bus(self):
        return EventBus()

    @pytest.mark.asyncio
    async def test_emit_node_loop_started(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_STARTED], handler=handler)
        await bus.emit_node_loop_started(
            stream_id="s1",
            node_id="n1",
            max_iterations=10,
        )

        assert len(received) == 1
        assert received[0].node_id == "n1"
        assert received[0].data["max_iterations"] == 10

    @pytest.mark.asyncio
    async def test_emit_node_loop_iteration(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_ITERATION], handler=handler)
        await bus.emit_node_loop_iteration(
            stream_id="s1",
            node_id="n1",
            iteration=3,
        )

        assert len(received) == 1
        assert received[0].data["iteration"] == 3

    @pytest.mark.asyncio
    async def test_emit_node_loop_completed(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_LOOP_COMPLETED], handler=handler)
        await bus.emit_node_loop_completed(
            stream_id="s1",
            node_id="n1",
            iterations=5,
        )

        assert len(received) == 1
        assert received[0].data["iterations"] == 5

    @pytest.mark.asyncio
    async def test_emit_llm_text_delta(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.LLM_TEXT_DELTA], handler=handler)
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="n1",
            content="hello",
            snapshot="hello world",
        )

        assert len(received) == 1
        assert received[0].data["content"] == "hello"
        assert received[0].data["snapshot"] == "hello world"

    @pytest.mark.asyncio
    async def test_emit_tool_call_started(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.TOOL_CALL_STARTED], handler=handler)
        await bus.emit_tool_call_started(
            stream_id="s1",
            node_id="n1",
            tool_use_id="call_1",
            tool_name="web_search",
            tool_input={"query": "test"},
        )

        assert len(received) == 1
        assert received[0].data["tool_name"] == "web_search"
        assert received[0].data["tool_input"] == {"query": "test"}

    @pytest.mark.asyncio
    async def test_emit_tool_call_completed(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.TOOL_CALL_COMPLETED], handler=handler)
        await bus.emit_tool_call_completed(
            stream_id="s1",
            node_id="n1",
            tool_use_id="call_1",
            tool_name="web_search",
            result="3 results found",
        )

        assert len(received) == 1
        assert received[0].data["result"] == "3 results found"
        assert received[0].data["is_error"] is False

    @pytest.mark.asyncio
    async def test_emit_client_output_delta(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.CLIENT_OUTPUT_DELTA], handler=handler)
        await bus.emit_client_output_delta(
            stream_id="s1",
            node_id="n1",
            content="chunk",
            snapshot="full chunk",
        )

        assert len(received) == 1
        assert received[0].data["content"] == "chunk"

    @pytest.mark.asyncio
    async def test_emit_node_stalled(self, bus):
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_STALLED], handler=handler)
        await bus.emit_node_stalled(
            stream_id="s1",
            node_id="n1",
            reason="no progress after 10 iterations",
        )

        assert len(received) == 1
        assert received[0].data["reason"] == "no progress after 10 iterations"

    @pytest.mark.asyncio
    async def test_convenience_publishers_set_node_id(self, bus):
        """All WP-2 convenience publishers set node_id on the emitted event."""
        received = []

        async def handler(event):
            received.append(event)

        bus.subscribe(
            event_types=[EventType.LLM_TEXT_DELTA, EventType.TOOL_CALL_STARTED],
            handler=handler,
            filter_node="my_node",
        )

        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="my_node",
            content="hi",
            snapshot="hi",
        )
        await bus.emit_tool_call_started(
            stream_id="s1",
            node_id="my_node",
            tool_use_id="c1",
            tool_name="calc",
        )
        # Wrong node — should not be received
        await bus.emit_llm_text_delta(
            stream_id="s1",
            node_id="other_node",
            content="bye",
            snapshot="bye",
        )

        assert len(received) == 2
        assert all(e.node_id == "my_node" for e in received)
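
Taken together, the routing tests pin down a simple delivery rule: the event's type must be in the subscription's `event_types`, and `filter_stream` / `filter_node` only constrain delivery when they are set. A minimal sketch of that predicate, assuming the `Subscription` fields used in the tests above (`_matches` is an illustrative name, not necessarily the framework's internals):

def _matches(sub, event) -> bool:
    """Sketch of the delivery rule the routing tests exercise."""
    if event.type not in sub.event_types:
        return False
    if sub.filter_stream is not None and event.stream_id != sub.filter_stream:
        return False
    if sub.filter_node is not None and event.node_id != sub.filter_node:
        return False
    return True
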
@@ -0,0 +1,577 @@
"""
Tests for feedback/callback edges and max_node_visits in GraphExecutor.

Covers:
- NodeSpec.max_node_visits default value
- Visit limit enforcement (skip on exceed)
- Multiple visits allowed when max_node_visits > 1
- Unlimited visits with max_node_visits=0
- Conditional feedback edges (backward traversal)
- Conditional edge NOT firing (forward path taken)
- node_visit_counts populated in ExecutionResult
"""

from unittest.mock import MagicMock

import pytest

from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.executor import GraphExecutor
from framework.graph.goal import Goal
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec

# ---------------------------------------------------------------------------
# Mock node implementations
# ---------------------------------------------------------------------------


class SuccessNode(NodeProtocol):
    """Always succeeds with configurable output."""

    def __init__(self, output: dict | None = None):
        self._output = output or {"result": "ok"}
        self.execute_count = 0

    async def execute(self, ctx: NodeContext) -> NodeResult:
        self.execute_count += 1
        return NodeResult(success=True, output=self._output, tokens_used=10, latency_ms=5)


class StatefulNode(NodeProtocol):
    """Returns different outputs on successive executions."""

    def __init__(self, outputs: list[dict]):
        self._outputs = outputs
        self.execute_count = 0

    async def execute(self, ctx: NodeContext) -> NodeResult:
        output = self._outputs[min(self.execute_count, len(self._outputs) - 1)]
        self.execute_count += 1
        return NodeResult(success=True, output=output, tokens_used=10, latency_ms=5)


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def runtime():
    from framework.runtime.core import Runtime

    rt = MagicMock(spec=Runtime)
    rt.start_run = MagicMock(return_value="run_id")
    rt.decide = MagicMock(return_value="decision_id")
    rt.record_outcome = MagicMock()
    rt.end_run = MagicMock()
    rt.report_problem = MagicMock()
    rt.set_node = MagicMock()
    return rt


@pytest.fixture
def goal():
    return Goal(id="g1", name="Test", description="Feedback edge tests")


# ---------------------------------------------------------------------------
# 1. NodeSpec default
# ---------------------------------------------------------------------------


def test_max_node_visits_default():
    """NodeSpec.max_node_visits should default to 1."""
    spec = NodeSpec(id="n", name="N", description="test", node_type="function", output_keys=["out"])
    assert spec.max_node_visits == 1


# ---------------------------------------------------------------------------
# 2. Visit limit skips node
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_visit_limit_skips_node(runtime, goal):
    """A→B→A cycle with A.max_visits=1: second visit to A should be skipped.

    Neither node is terminal — max_steps is the guard. After A is skipped,
    the skip-redirect loop (A skip→B→A skip→B...) burns through max_steps.
    """
    node_a = NodeSpec(
        id="a",
        name="A",
        description="entry with visit limit",
        node_type="function",
        output_keys=["a_out"],
        max_node_visits=1,
    )
    node_b = NodeSpec(
        id="b",
        name="B",
        description="middle node",
        node_type="function",
        output_keys=["b_out"],
        max_node_visits=0,  # unlimited — let max_steps guard
    )

    graph = GraphSpec(
        id="cycle_graph",
        goal_id="g1",
        name="Cycle Graph",
        entry_node="a",
        nodes=[node_a, node_b],
        edges=[
            EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
            EdgeSpec(id="b_to_a", source="b", target="a", condition=EdgeCondition.ON_SUCCESS),
        ],
        terminal_nodes=[],  # No terminal — max_steps is the guard
        max_steps=10,
    )

    a_impl = SuccessNode({"a_out": "from_a"})
    b_impl = SuccessNode({"b_out": "from_b"})

    executor = GraphExecutor(runtime=runtime)
    executor.register_node("a", a_impl)
    executor.register_node("b", b_impl)

    result = await executor.execute(graph, goal, {})

    # A should only execute once (all subsequent visits are skipped)
    assert a_impl.execute_count == 1
    # Path should contain "a" exactly once (skipped visits aren't appended)
    assert result.path.count("a") == 1
    # Visit count tracks ALL visits (including skipped ones)
    assert result.node_visit_counts["a"] >= 2
    # B executes multiple times (no visit limit)
    assert b_impl.execute_count >= 2


# ---------------------------------------------------------------------------
# 3. Visit limit allows multiple
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_visit_limit_allows_multiple(runtime, goal):
    """A→B→A cycle with A.max_visits=2: A executes twice before skip."""
    node_a = NodeSpec(
        id="a",
        name="A",
        description="entry allows two visits",
        node_type="function",
        output_keys=["a_out"],
        max_node_visits=2,
    )
    node_b = NodeSpec(
        id="b",
        name="B",
        description="middle node",
        node_type="function",
        output_keys=["b_out"],
        max_node_visits=0,  # unlimited
    )

    graph = GraphSpec(
        id="cycle_graph",
        goal_id="g1",
        name="Cycle Graph",
        entry_node="a",
        nodes=[node_a, node_b],
        edges=[
            EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
            EdgeSpec(id="b_to_a", source="b", target="a", condition=EdgeCondition.ON_SUCCESS),
        ],
        terminal_nodes=[],
        max_steps=10,
    )

    a_impl = SuccessNode({"a_out": "from_a"})
    b_impl = SuccessNode({"b_out": "from_b"})

    executor = GraphExecutor(runtime=runtime)
    executor.register_node("a", a_impl)
    executor.register_node("b", b_impl)

    result = await executor.execute(graph, goal, {})

    # A should execute exactly twice
    assert a_impl.execute_count == 2
    # Path should contain "a" exactly twice
    assert result.path.count("a") == 2
    # Visit count includes skipped visits too
    assert result.node_visit_counts["a"] >= 3


# ---------------------------------------------------------------------------
# 4. Visit limit zero = unlimited
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_visit_limit_zero_unlimited(runtime, goal):
    """max_node_visits=0 means unlimited; max_steps is the only guard."""
    node_a = NodeSpec(
        id="a",
        name="A",
        description="unlimited visits",
        node_type="function",
        output_keys=["a_out"],
        max_node_visits=0,
    )
    node_b = NodeSpec(
        id="b",
        name="B",
        description="middle node",
        node_type="function",
        output_keys=["b_out"],
        max_node_visits=0,
    )

    graph = GraphSpec(
        id="cycle_graph",
        goal_id="g1",
        name="Cycle Graph",
        entry_node="a",
        nodes=[node_a, node_b],
        edges=[
            EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
            EdgeSpec(id="b_to_a", source="b", target="a", condition=EdgeCondition.ON_SUCCESS),
        ],
        terminal_nodes=[],
        max_steps=6,  # A,B,A,B,A,B
    )

    a_impl = SuccessNode({"a_out": "from_a"})
    b_impl = SuccessNode({"b_out": "from_b"})

    executor = GraphExecutor(runtime=runtime)
    executor.register_node("a", a_impl)
    executor.register_node("b", b_impl)

    result = await executor.execute(graph, goal, {})

    # With max_steps=6: A,B,A,B,A,B → each executes 3 times
    assert a_impl.execute_count == 3
    assert b_impl.execute_count == 3
    assert result.steps_executed == 6


# ---------------------------------------------------------------------------
# 5. Conditional feedback edge fires
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_conditional_feedback_edge(runtime, goal):
    """Writer→Director backward edge fires when needs_revision==True in output.

    Edge conditions evaluate `output` (current node result) and `memory`
    (accumulated shared state). The writer's output hasn't been written to
    memory yet when edges are evaluated, so we use `output.get(...)`.
    """
    director = NodeSpec(
        id="director",
        name="Director",
        description="plans work",
        node_type="function",
        output_keys=["plan"],
        max_node_visits=2,
    )
    writer = NodeSpec(
        id="writer",
        name="Writer",
        description="writes draft",
        node_type="function",
        output_keys=["draft", "needs_revision"],
        max_node_visits=2,
    )
    output_node = NodeSpec(
        id="output",
        name="Output",
        description="final output",
        node_type="function",
        output_keys=["final"],
    )

    graph = GraphSpec(
        id="feedback_graph",
        goal_id="g1",
        name="Feedback Graph",
        entry_node="director",
        nodes=[director, writer, output_node],
        edges=[
            EdgeSpec(
                id="director_to_writer",
                source="director",
                target="writer",
                condition=EdgeCondition.ON_SUCCESS,
            ),
            # Forward path: writer → output (when NOT needs_revision)
            EdgeSpec(
                id="writer_to_output",
                source="writer",
                target="output",
                condition=EdgeCondition.CONDITIONAL,
                condition_expr="output.get('needs_revision') != True",
                priority=0,
            ),
            # Feedback path: writer → director (when needs_revision)
            EdgeSpec(
                id="writer_feedback",
                source="writer",
                target="director",
                condition=EdgeCondition.CONDITIONAL,
                condition_expr="output.get('needs_revision') == True",
                priority=-1,
            ),
        ],
        terminal_nodes=["output"],
        max_steps=10,
    )

    director_impl = SuccessNode({"plan": "research AI"})
    # Writer: first call sets needs_revision=True, second sets False
    writer_impl = StatefulNode(
        [
            {"draft": "draft_v1", "needs_revision": True},
            {"draft": "draft_v2", "needs_revision": False},
        ]
    )
    output_impl = SuccessNode({"final": "done"})

    executor = GraphExecutor(runtime=runtime)
    executor.register_node("director", director_impl)
    executor.register_node("writer", writer_impl)
    executor.register_node("output", output_impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    # Director executed twice (initial + feedback)
    assert director_impl.execute_count == 2
    # Writer executed twice (first draft rejected, second accepted)
    assert writer_impl.execute_count == 2
    # Output executed once
    assert output_impl.execute_count == 1
    # Full path: director → writer → director → writer → output
    assert result.path == ["director", "writer", "director", "writer", "output"]


# ---------------------------------------------------------------------------
# 6. Conditional feedback edge does NOT fire
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_conditional_feedback_false(runtime, goal):
    """Writer→Director backward edge does NOT fire when needs_revision is False."""
    director = NodeSpec(
        id="director",
        name="Director",
        description="plans work",
        node_type="function",
        output_keys=["plan"],
        max_node_visits=2,
    )
    writer = NodeSpec(
        id="writer",
        name="Writer",
        description="writes draft",
        node_type="function",
        output_keys=["draft", "needs_revision"],
    )
    output_node = NodeSpec(
        id="output",
        name="Output",
        description="final output",
        node_type="function",
        output_keys=["final"],
    )

    graph = GraphSpec(
        id="feedback_graph",
        goal_id="g1",
        name="Feedback Graph",
        entry_node="director",
        nodes=[director, writer, output_node],
        edges=[
            EdgeSpec(
                id="director_to_writer",
                source="director",
                target="writer",
                condition=EdgeCondition.ON_SUCCESS,
            ),
            EdgeSpec(
                id="writer_to_output",
                source="writer",
                target="output",
                condition=EdgeCondition.CONDITIONAL,
                condition_expr="output.get('needs_revision') != True",
                priority=0,
            ),
            EdgeSpec(
                id="writer_feedback",
                source="writer",
                target="director",
                condition=EdgeCondition.CONDITIONAL,
                condition_expr="output.get('needs_revision') == True",
                priority=-1,
            ),
        ],
        terminal_nodes=["output"],
        max_steps=10,
    )

    director_impl = SuccessNode({"plan": "research AI"})
    # Writer always outputs good draft (no revision needed)
    writer_impl = SuccessNode({"draft": "perfect_draft", "needs_revision": False})
    output_impl = SuccessNode({"final": "done"})

    executor = GraphExecutor(runtime=runtime)
    executor.register_node("director", director_impl)
    executor.register_node("writer", writer_impl)
    executor.register_node("output", output_impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    # Director only executed once (no feedback loop)
    assert director_impl.execute_count == 1
    # Writer only executed once
    assert writer_impl.execute_count == 1
    # Output executed
    assert output_impl.execute_count == 1
    # Straight-through path
    assert result.path == ["director", "writer", "output"]


# ---------------------------------------------------------------------------
# 7. Visit counts in ExecutionResult
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_visit_counts_in_result(runtime, goal):
    """ExecutionResult.node_visit_counts is populated with actual visit counts."""
    node_a = NodeSpec(
        id="a",
        name="A",
        description="entry",
        node_type="function",
        output_keys=["a_out"],
    )
    node_b = NodeSpec(
        id="b",
        name="B",
        description="terminal",
        node_type="function",
        input_keys=["a_out"],
        output_keys=["b_out"],
    )

    graph = GraphSpec(
        id="linear_graph",
        goal_id="g1",
        name="Linear Graph",
        entry_node="a",
        nodes=[node_a, node_b],
        edges=[
            EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
        ],
        terminal_nodes=["b"],
    )

    executor = GraphExecutor(runtime=runtime)
    executor.register_node("a", SuccessNode({"a_out": "x"}))
    executor.register_node("b", SuccessNode({"b_out": "y"}))

    result = await executor.execute(graph, goal, {})

    assert result.success
    assert result.node_visit_counts == {"a": 1, "b": 1}


# ---------------------------------------------------------------------------
# 8. Conditional priority prevents fan-out
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_conditional_priority_prevents_fanout(runtime, goal):
    """When multiple CONDITIONAL edges match, only highest-priority fires.

    Simulates: writer produces output where both forward and feedback
    conditions could match. The higher-priority forward edge should win;
    the executor must NOT treat this as fan-out.
    """
    writer = NodeSpec(
        id="writer",
        name="Writer",
        description="produces output",
        node_type="function",
        output_keys=["draft", "needs_revision"],
    )
    output_node = NodeSpec(
        id="output",
        name="Output",
        description="forward target",
        node_type="function",
        output_keys=["final"],
    )
    director = NodeSpec(
        id="director",
        name="Director",
        description="feedback target",
        node_type="function",
        output_keys=["plan"],
        max_node_visits=2,
    )

    graph = GraphSpec(
        id="priority_graph",
        goal_id="g1",
        name="Priority Graph",
        entry_node="writer",
        nodes=[writer, output_node, director],
        edges=[
            # Forward: higher priority (1)
            EdgeSpec(
                id="writer_to_output",
                source="writer",
                target="output",
                condition=EdgeCondition.CONDITIONAL,
                condition_expr="output.get('draft') is not None",
                priority=1,
            ),
            # Feedback: lower priority (-1)
            EdgeSpec(
                id="writer_to_director",
                source="writer",
                target="director",
                condition=EdgeCondition.CONDITIONAL,
                condition_expr="output.get('needs_revision') == True",
                priority=-1,
            ),
        ],
        terminal_nodes=["output"],
        max_steps=10,
    )

    # Writer sets BOTH output keys — both conditions are true
    writer_impl = SuccessNode({"draft": "my draft", "needs_revision": True})
    output_impl = SuccessNode({"final": "done"})
    director_impl = SuccessNode({"plan": "plan"})

    executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
    executor.register_node("writer", writer_impl)
    executor.register_node("output", output_impl)
    executor.register_node("director", director_impl)

    result = await executor.execute(graph, goal, {})

    assert result.success
    # Forward edge (priority 1) wins — output executes, director does NOT
    assert output_impl.execute_count == 1
    assert director_impl.execute_count == 0
    assert result.path == ["writer", "output"]
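
These tests encode three executor rules: every arrival at a node increments its visit count (skipped visits included), `max_node_visits=0` disables the limit, and when several CONDITIONAL edges match, only the highest-priority one fires. Two minimal sketches of those rules follow, assuming only the NodeSpec/EdgeSpec fields used in the tests; the helper names and the bare `eval` are illustrative, not the executor's actual implementation:

def _should_skip(spec, visit_counts: dict[str, int]) -> bool:
    # Count every arrival, even ones that end up skipped
    visit_counts[spec.id] = visit_counts.get(spec.id, 0) + 1
    if spec.max_node_visits == 0:  # 0 means unlimited
        return False
    return visit_counts[spec.id] > spec.max_node_visits

def _conditional_edge_fires(edge, output: dict, memory: dict) -> bool:
    # The expression sees `output` (current node result) and `memory`
    # (accumulated shared state), as the docstring in test 5 notes.
    namespace = {"output": output, "memory": memory}
    return bool(eval(edge.condition_expr, {"__builtins__": {}}, namespace))

Under these rules the arithmetic in `test_visit_limit_skips_node` works out: A's counter reaches 2 on its second arrival, exceeds the limit of 1, and the node is skipped without being appended to the path.
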
@@ -209,6 +209,62 @@ class TestLiteLLMProviderToolUse:
        assert result.output_tokens == 25  # 15 + 10
        assert mock_completion.call_count == 2

    @patch("litellm.completion")
    def test_complete_with_tools_invalid_json_arguments_are_handled(self, mock_completion):
        """Test that invalid JSON tool arguments do not execute the tool."""
        # Mock response with invalid JSON arguments
        tool_call_response = MagicMock()
        tool_call_response.choices = [MagicMock()]
        tool_call_response.choices[0].message.content = None
        tool_call_response.choices[0].message.tool_calls = [MagicMock()]
        tool_call_response.choices[0].message.tool_calls[0].id = "call_123"
        tool_call_response.choices[0].message.tool_calls[0].function.name = "test_tool"
        tool_call_response.choices[0].message.tool_calls[0].function.arguments = "{invalid json"
        tool_call_response.choices[0].finish_reason = "tool_calls"
        tool_call_response.model = "gpt-4o-mini"
        tool_call_response.usage.prompt_tokens = 10
        tool_call_response.usage.completion_tokens = 5

        # Final response (LLM continues after tool error)
        final_response = MagicMock()
        final_response.choices = [MagicMock()]
        final_response.choices[0].message.content = "Handled error"
        final_response.choices[0].message.tool_calls = None
        final_response.choices[0].finish_reason = "stop"
        final_response.model = "gpt-4o-mini"
        final_response.usage.prompt_tokens = 5
        final_response.usage.completion_tokens = 5

        mock_completion.side_effect = [tool_call_response, final_response]

        provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")

        tools = [
            Tool(
                name="test_tool",
                description="Test tool",
                parameters={"properties": {}, "required": []},
            )
        ]

        called = {"value": False}

        def tool_executor(tool_use: ToolUse) -> ToolResult:
            called["value"] = True
            return ToolResult(
                tool_use_id=tool_use.id, content="should not be called", is_error=False
            )

        result = provider.complete_with_tools(
            messages=[{"role": "user", "content": "Run tool"}],
            system="You are a test assistant.",
            tools=tools,
            tool_executor=tool_executor,
        )

        assert called["value"] is False
        assert result.content == "Handled error"


class TestToolConversion:
    """Test tool format conversion."""

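
The invalid-JSON test pins down one safety property: malformed `function.arguments` must never reach the tool executor. A plausible shape for that guard — parse before dispatch, and turn failures into an error result the model can react to — is sketched below; the function name mirrors nothing in the diff and the provider's real internals may differ:

import json

def _parse_tool_arguments(raw: str) -> tuple[dict | None, str | None]:
    """Return (arguments, error). On bad JSON the tool is never invoked."""
    try:
        return json.loads(raw), None
    except json.JSONDecodeError as exc:
        return None, f"Invalid JSON in tool arguments: {exc}"
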
@@ -0,0 +1,389 @@
"""Real-API streaming tests for LiteLLM provider.

Calls live LLM APIs and dumps stream events to JSON files for review.
Results are saved to core/tests/stream_event_dumps/{provider}_{model}_{scenario}.json

Run with:
    cd core && python -m pytest tests/test_litellm_streaming.py -v -s -k "RealAPI"

Requires API keys set in environment:
    ANTHROPIC_API_KEY, OPENAI_API_KEY, GEMINI_API_KEY (or via credential store)
"""

import asyncio
import json
import logging
import os
from dataclasses import asdict
from pathlib import Path

import pytest

from framework.llm.litellm import LiteLLMProvider
from framework.llm.provider import Tool
from framework.llm.stream_events import (
    FinishEvent,
    StreamEvent,
    TextDeltaEvent,
    TextEndEvent,
    ToolCallEvent,
)

logger = logging.getLogger(__name__)

DUMP_DIR = Path(__file__).parent / "stream_event_dumps"


def _serialize_event(index: int, event: StreamEvent) -> dict:
    """Serialize a StreamEvent to a JSON-safe dict."""
    d = asdict(event)  # type: ignore[arg-type]
    # Move index to the front of the dict for readability in the dumps
    return {"index": index, **{k: v for k, v in d.items() if k != "index"}}


def _dump_events(events: list[StreamEvent], filename: str) -> Path:
    """Write stream events to a JSON file in the dump directory."""
    DUMP_DIR.mkdir(parents=True, exist_ok=True)
    filepath = DUMP_DIR / filename
    serialized = [_serialize_event(i, e) for i, e in enumerate(events)]
    filepath.write_text(json.dumps(serialized, indent=2) + "\n")
    logger.info(f"Dumped {len(events)} events to {filepath}")
    return filepath


async def _collect_stream(provider: LiteLLMProvider, **kwargs) -> list[StreamEvent]:
    """Collect all stream events from a provider.stream() call."""
    events: list[StreamEvent] = []
    async for event in provider.stream(**kwargs):
        events.append(event)
        # Log each event type as it arrives
        logger.debug(f"  [{len(events) - 1}] {event.type}: {event}")
    return events


# ---------------------------------------------------------------------------
# Test matrix: (model_id, dump_prefix, env_var_for_skip)
# ---------------------------------------------------------------------------
MODELS = [
    (
        "anthropic/claude-haiku-4-5-20251001",
        "anthropic_claude-haiku-4-5-20251001",
        "ANTHROPIC_API_KEY",
    ),
    ("gpt-4.1-nano", "gpt-4.1-nano", "OPENAI_API_KEY"),
    ("gemini/gemini-2.0-flash", "gemini_gemini-2.0-flash", "GEMINI_API_KEY"),
]

WEATHER_TOOL = Tool(
    name="get_weather",
    description="Get the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {
            "city": {
                "type": "string",
                "description": "City name, e.g. 'Tokyo'",
            }
        },
        "required": ["city"],
    },
)

SEARCH_TOOL = Tool(
    name="web_search",
    description="Search the web for information.",
    parameters={
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query",
            },
            "num_results": {
                "type": "integer",
                "description": "Number of results to return (1-10)",
            },
        },
        "required": ["query"],
    },
)

CALCULATOR_TOOL = Tool(
    name="calculator",
    description="Perform arithmetic calculations.",
    parameters={
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "Math expression to evaluate, e.g. '2 + 2'",
            }
        },
        "required": ["expression"],
    },
)


def _has_api_key(env_var: str) -> bool:
    """Check if an API key is available (env var or credential store)."""
    if os.environ.get(env_var):
        return True
    # Try credential store
    try:
        from aden_tools.credentials import CredentialStoreAdapter

        creds = CredentialStoreAdapter.with_env_storage()
        provider_name = env_var.replace("_API_KEY", "").lower()
        return creds.is_available(provider_name)
    except Exception:  # ImportError or any credential-store failure
        return False


# ---------------------------------------------------------------------------
# Real API tests — text streaming
# ---------------------------------------------------------------------------
class TestRealAPITextStreaming:
    """Stream a simple text response from each provider and dump events."""

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_text_stream(self, model: str, prefix: str, env_var: str):
        """Stream a multi-paragraph response to exercise chunked delivery."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")

        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "Explain in 3 numbered paragraphs how a CPU executes an instruction. "
                        "Cover fetch, decode, and execute stages. Be concise but thorough."
                    ),
                }
            ],
            system="You are a computer science teacher. Give clear, structured explanations.",
            max_tokens=512,
        )

        # Dump to file
        _dump_events(events, f"{prefix}_text.json")

        # Basic structural assertions
        assert len(events) >= 4, f"Expected at least 4 events, got {len(events)}"

        # Must have multiple text deltas for a longer response
        text_deltas = [e for e in events if isinstance(e, TextDeltaEvent)]
        assert len(text_deltas) >= 3, f"Expected 3+ TextDeltaEvents, got {len(text_deltas)}"

        # Snapshot must accumulate monotonically
        for i in range(1, len(text_deltas)):
            assert len(text_deltas[i].snapshot) > len(text_deltas[i - 1].snapshot), (
                f"Snapshot did not grow at index {i}"
            )

        # Must end with TextEndEvent then FinishEvent
        text_ends = [e for e in events if isinstance(e, TextEndEvent)]
        assert len(text_ends) == 1, f"Expected 1 TextEndEvent, got {len(text_ends)}"

        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1, f"Expected 1 FinishEvent, got {len(finish_events)}"
        assert finish_events[0].stop_reason in ("stop", "end_turn")

        # TextEndEvent.full_text should match last snapshot
        assert text_ends[0].full_text == text_deltas[-1].snapshot

        # Response should actually contain multi-paragraph content
        full_text = text_ends[0].full_text
        assert len(full_text) > 200, f"Response too short ({len(full_text)} chars)"


# ---------------------------------------------------------------------------
# Real API tests — tool call streaming
# ---------------------------------------------------------------------------
class TestRealAPIToolCallStreaming:
    """Stream a tool call response from each provider and dump events."""

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_tool_call_stream(self, model: str, prefix: str, env_var: str):
        """Stream a single tool call with complex arguments."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")

        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": "Search the web for 'Python 3.13 release notes'.",
                }
            ],
            system="You have access to tools. Use the appropriate tool.",
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            max_tokens=512,
        )

        # Dump to file
        _dump_events(events, f"{prefix}_tool_call.json")

        # Basic structural assertions
        assert len(events) >= 2, f"Expected at least 2 events, got {len(events)}"

        # Must have a tool call event
        tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
        assert len(tool_calls) >= 1, "No ToolCallEvent received"

        tc = tool_calls[0]
        assert tc.tool_name == "web_search"
        assert "query" in tc.tool_input
        assert tc.tool_use_id != ""

        # Must end with FinishEvent
        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1
        assert finish_events[0].stop_reason in ("tool_calls", "tool_use", "stop")

    @pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
    @pytest.mark.asyncio
    async def test_multi_tool_call_stream(self, model: str, prefix: str, env_var: str):
        """Stream a response that should invoke multiple tool calls."""
        if not _has_api_key(env_var):
            pytest.skip(f"{env_var} not set")

        provider = LiteLLMProvider(model=model)
        events = await _collect_stream(
            provider,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "I need three things done in parallel: "
                        "1) Get the weather in London, "
                        "2) Get the weather in New York, "
                        "3) Calculate 1337 * 42. "
                        "Use the tools for all three."
                    ),
                }
            ],
            system=(
                "You have access to tools. When the user asks for multiple things, "
                "call all the needed tools. Always use tools, never guess results."
            ),
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            max_tokens=512,
        )

        # Dump to file
        _dump_events(events, f"{prefix}_multi_tool.json")

        # Must have multiple tool call events
        tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
        assert len(tool_calls) >= 2, (
            f"Expected 2+ ToolCallEvents for parallel requests, got {len(tool_calls)}"
        )

        # Verify tool names used
        tool_names = {tc.tool_name for tc in tool_calls}
        assert "get_weather" in tool_names, "Expected get_weather tool call"

        # All tool calls should have non-empty IDs
        for tc in tool_calls:
            assert tc.tool_use_id != "", f"Empty tool_use_id on {tc.tool_name}"
            assert tc.tool_input, f"Empty tool_input on {tc.tool_name}"

        # Must end with FinishEvent
        finish_events = [e for e in events if isinstance(e, FinishEvent)]
        assert len(finish_events) == 1


# ---------------------------------------------------------------------------
# Convenience runner for manual invocation
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    """Run all streaming tests and dump results. Usage: python tests/test_litellm_streaming.py"""

    ALL_TOOLS = [WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL]

    async def _run_all():
        for model, prefix, env_var in MODELS:
            if not _has_api_key(env_var):
                print(f"SKIP {prefix}: {env_var} not set")
                continue

            provider = LiteLLMProvider(model=model)

            # Text streaming (multi-paragraph)
            print(f"\n--- {prefix} text ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": (
                            "Explain in 3 numbered paragraphs how a CPU executes an instruction. "
                            "Cover fetch, decode, and execute stages. Be concise but thorough."
                        ),
                    }
                ],
                system="You are a computer science teacher. Give clear, structured explanations.",
                max_tokens=512,
            )
            path = _dump_events(events, f"{prefix}_text.json")
            print(f"  {len(events)} events -> {path}")
            for i, e in enumerate(events):
                print(f"  [{i}] {e.type}: {e}")

            # Tool call streaming
            print(f"\n--- {prefix} tool_call ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": "Search the web for 'Python 3.13 release notes'.",
                    }
                ],
                system="You have access to tools. Use the appropriate tool.",
                tools=ALL_TOOLS,
                max_tokens=512,
            )
            path = _dump_events(events, f"{prefix}_tool_call.json")
            print(f"  {len(events)} events -> {path}")
            for i, e in enumerate(events):
                print(f"  [{i}] {e.type}: {e}")

            # Multi-tool call streaming
            print(f"\n--- {prefix} multi_tool ---")
            events = await _collect_stream(
                provider,
                messages=[
                    {
                        "role": "user",
                        "content": (
                            "I need three things done in parallel: "
                            "1) Get the weather in London, "
                            "2) Get the weather in New York, "
                            "3) Calculate 1337 * 42. "
                            "Use the tools for all three."
                        ),
                    }
                ],
                system=(
                    "You have access to tools. When the user asks for multiple things, "
                    "call all the needed tools. Always use tools, never guess results."
                ),
                tools=ALL_TOOLS,
                max_tokens=512,
            )
            path = _dump_events(events, f"{prefix}_multi_tool.json")
            print(f"  {len(events)} events -> {path}")
            for i, e in enumerate(events):
                print(f"  [{i}] {e.type}: {e}")

    logging.basicConfig(level=logging.DEBUG)
    asyncio.run(_run_all())
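
Because `_dump_events` writes a plain JSON list ordered by `index`, the dumps can be inspected without re-running the live APIs. A small reader sketch — the `"type"` key is an assumption about how `asdict` serializes the event dataclasses:

import json
from pathlib import Path

def summarize_dump(path: Path) -> None:
    """Print one line per recorded stream event in a dump file."""
    for event in json.loads(path.read_text()):
        print(event["index"], event.get("type"))
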
@@ -168,6 +168,68 @@ class TestNodeConversation:
        await conv.add_user_message("a" * 400)
        assert conv.estimate_tokens() == 100

    @pytest.mark.asyncio
    async def test_update_token_count_overrides_estimate(self):
        """When actual API token count is provided, estimate_tokens uses it."""
        conv = NodeConversation()
        await conv.add_user_message("a" * 400)
        assert conv.estimate_tokens() == 100  # chars/4 fallback

        conv.update_token_count(500)
        assert conv.estimate_tokens() == 500  # actual API value

    @pytest.mark.asyncio
    async def test_compact_resets_token_count(self):
        """After compaction, actual token count is cleared (recalibrates on next LLM call)."""
        conv = NodeConversation()
        await conv.add_user_message("a" * 400)
        conv.update_token_count(500)
        assert conv.estimate_tokens() == 500

        await conv.compact("summary", keep_recent=0)
        # Falls back to chars/4 for the summary message
        assert conv.estimate_tokens() == len("summary") // 4

    @pytest.mark.asyncio
    async def test_clear_resets_token_count(self):
        """clear() also resets the actual token count."""
        conv = NodeConversation()
        await conv.add_user_message("hello")
        conv.update_token_count(1000)
        assert conv.estimate_tokens() == 1000

        await conv.clear()
        assert conv.estimate_tokens() == 0

    @pytest.mark.asyncio
    async def test_usage_ratio(self):
        """usage_ratio returns estimate / max_history_tokens."""
        conv = NodeConversation(max_history_tokens=1000)
        await conv.add_user_message("a" * 400)
        assert conv.usage_ratio() == pytest.approx(0.1)  # 100/1000

        conv.update_token_count(800)
        assert conv.usage_ratio() == pytest.approx(0.8)  # 800/1000

    @pytest.mark.asyncio
    async def test_usage_ratio_zero_budget(self):
        """usage_ratio returns 0 when max_history_tokens is 0 (unlimited)."""
        conv = NodeConversation(max_history_tokens=0)
        await conv.add_user_message("a" * 400)
        assert conv.usage_ratio() == 0.0

    @pytest.mark.asyncio
    async def test_needs_compaction_with_actual_tokens(self):
        """needs_compaction uses actual API token count when available."""
        conv = NodeConversation(max_history_tokens=1000, compaction_threshold=0.8)
        await conv.add_user_message("a" * 100)  # chars/4 = 25, well under 800

        assert conv.needs_compaction() is False

        # Simulate API reporting much higher actual token usage
        conv.update_token_count(850)
        assert conv.needs_compaction() is True

    @pytest.mark.asyncio
    async def test_needs_compaction(self):
        conv = NodeConversation(max_history_tokens=100, compaction_threshold=0.8)
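
These tests fully specify the token-accounting precedence: an API-reported count (via `update_token_count`) overrides the chars/4 heuristic until `compact()` or `clear()` discards it, and a zero budget pins `usage_ratio` to 0. A self-contained sketch of that contract — the `_actual_tokens` attribute name is hypothetical, not NodeConversation's real internals:

class TokenEstimateSketch:
    """Sketch of the accounting rules the tests above pin down."""

    def __init__(self, max_history_tokens: int = 0):
        self.max_history_tokens = max_history_tokens
        self.messages: list = []  # objects with a `.content` string
        self._actual_tokens: int | None = None  # set by update_token_count()

    def update_token_count(self, tokens: int) -> None:
        self._actual_tokens = tokens

    def estimate_tokens(self) -> int:
        if self._actual_tokens is not None:  # actual API value wins
            return self._actual_tokens
        return sum(len(m.content) for m in self.messages) // 4  # chars/4 fallback

    def usage_ratio(self) -> float:
        if self.max_history_tokens == 0:  # 0 means unlimited budget
            return 0.0
        return self.estimate_tokens() / self.max_history_tokens
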
@@ -558,3 +620,313 @@ class TestFileConversationStore:
        assert (base / "cursor.json").exists()
        assert (base / "parts" / "0000000000.json").exists()
        assert (base / "parts" / "0000000001.json").exists()


# ===================================================================
# Integration tests — real FileConversationStore, no mocks
# ===================================================================


class TestConversationIntegration:
    """End-to-end tests using real FileConversationStore on disk.

    Every test creates a fresh directory, writes real JSON files,
    and restores from a *new* store instance (simulating process restart).
    """

    @pytest.mark.asyncio
    async def test_multi_turn_agent_conversation(self, tmp_path):
        """Simulate a realistic agent conversation with multiple turns,
        tool calls, and tool results — then restore from disk."""
        base = tmp_path / "agent_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(
            system_prompt="You are a helpful travel agent.",
            max_history_tokens=16000,
            store=store,
        )

        # Turn 1: user asks, assistant responds with tool call
        await conv.add_user_message("Find me flights from NYC to London next Friday.")
        await conv.add_assistant_message(
            "Let me search for flights.",
            tool_calls=[
                {
                    "id": "call_flight_1",
                    "type": "function",
                    "function": {
                        "name": "search_flights",
                        "arguments": '{"origin":"JFK","destination":"LHR","date":"2025-06-13"}',
                    },
                }
            ],
        )
        await conv.add_tool_result(
            "call_flight_1",
            '{"flights":[{"airline":"BA","price":450,"departure":"08:00"},{"airline":"AA","price":520,"departure":"14:30"}]}',
        )

        # Turn 2: assistant presents results, user picks one
        await conv.add_assistant_message(
            "I found 2 flights:\n"
            "1. British Airways at $450, departing 08:00\n"
            "2. American Airlines at $520, departing 14:30\n"
            "Which one would you like?"
        )
        await conv.add_user_message("Book the British Airways one.")
        await conv.add_assistant_message(
            "Booking the BA flight now.",
            tool_calls=[
                {
                    "id": "call_book_1",
                    "type": "function",
                    "function": {
                        "name": "book_flight",
                        "arguments": '{"flight_id":"BA-JFK-LHR-0800","passenger":"user"}',
                    },
                }
            ],
        )
        await conv.add_tool_result(
            "call_book_1",
            '{"confirmation":"BA-12345","status":"confirmed"}',
        )
        await conv.add_assistant_message("Your flight is booked! Confirmation: BA-12345.")

        # Verify in-memory state
        assert conv.turn_count == 2
        assert conv.message_count == 8
        assert conv.next_seq == 8

        # --- Simulate process restart: new store, same path ---
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)

        assert restored is not None
        assert restored.system_prompt == "You are a helpful travel agent."
        assert restored.turn_count == 2
        assert restored.message_count == 8
        assert restored.next_seq == 8

        # Verify message content integrity
        msgs = restored.messages
        assert msgs[0].role == "user"
        assert "NYC to London" in msgs[0].content
        assert msgs[1].role == "assistant"
        assert msgs[1].tool_calls[0]["id"] == "call_flight_1"
        assert msgs[2].role == "tool"
        assert msgs[2].tool_use_id == "call_flight_1"
        assert "BA" in msgs[2].content
        assert msgs[7].content == "Your flight is booked! Confirmation: BA-12345."

        # Verify LLM-format output
        llm_msgs = restored.to_llm_messages()
        assert llm_msgs[0] == {"role": "user", "content": msgs[0].content}
        assert llm_msgs[2]["role"] == "tool"
        assert llm_msgs[2]["tool_call_id"] == "call_flight_1"

    @pytest.mark.asyncio
    async def test_compaction_and_restore_preserves_continuity(self, tmp_path):
        """Build up a long conversation, compact it, continue adding
        messages, then restore — verifying seq continuity and content."""
        base = tmp_path / "compact_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(
            system_prompt="research assistant",
            store=store,
        )

        # Build 10 messages (5 turns)
        for i in range(5):
            await conv.add_user_message(f"question {i}")
            await conv.add_assistant_message(f"answer {i}")

        assert conv.message_count == 10
        assert conv.next_seq == 10

        # Compact: keep last 2 messages (question 4, answer 4)
        await conv.compact("Summary of questions 0-3 and their answers.", keep_recent=2)

        assert conv.message_count == 3  # summary + 2 recent
        assert conv.messages[0].content == "Summary of questions 0-3 and their answers."
        assert conv.messages[1].content == "question 4"
        assert conv.messages[2].content == "answer 4"

        # Continue the conversation post-compaction
        await conv.add_user_message("question 5")
        await conv.add_assistant_message("answer 5")
        assert conv.next_seq == 12

        # Verify disk: old part files (seq 0-7) should be deleted
        parts_dir = base / "parts"
        part_files = sorted(parts_dir.glob("*.json"))
        part_seqs = [int(f.stem) for f in part_files]
        # Should have: summary (seq 7), question 4 (seq 8), answer 4 (seq 9),
        # question 5 (seq 10), answer 5 (seq 11)
        assert all(s >= 7 for s in part_seqs), f"Stale parts found: {part_seqs}"

        # Restore from fresh store
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)

        assert restored is not None
        assert restored.next_seq == 12
        assert restored.message_count == 5
        assert "Summary of questions 0-3" in restored.messages[0].content
        assert restored.messages[-1].content == "answer 5"

        # Verify seq monotonicity across all restored messages
        seqs = [m.seq for m in restored.messages]
        assert seqs == sorted(seqs), f"Seqs not monotonic: {seqs}"

    @pytest.mark.asyncio
    async def test_output_key_preservation_through_compact_and_restore(self, tmp_path):
        """Output keys in compacted messages survive disk persistence."""
        base = tmp_path / "output_key_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(
            system_prompt="classifier",
            output_keys=["classification", "confidence"],
            store=store,
        )

        await conv.add_user_message("Classify this email: 'You won a prize!'")
        await conv.add_assistant_message('{"classification": "spam", "confidence": "0.97"}')
        await conv.add_user_message("What about: 'Meeting at 3pm'")
        await conv.add_assistant_message('{"classification": "ham", "confidence": "0.99"}')
        await conv.add_user_message("And: 'Buy cheap meds now'")
        await conv.add_assistant_message('{"classification": "spam", "confidence": "0.95"}')

        # Compact keeping only the last 2 messages
        await conv.compact("Classified 3 emails.", keep_recent=2)

        # The summary should contain preserved output keys from discarded messages
        summary_content = conv.messages[0].content
        assert "PRESERVED VALUES" in summary_content
        # Most recent values from discarded messages (msgs 0-3) are "ham"/"0.99"
        assert "ham" in summary_content or "spam" in summary_content

        # Restore and verify the preserved values survived
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)
        assert restored is not None
        assert "PRESERVED VALUES" in restored.messages[0].content

    @pytest.mark.asyncio
    async def test_tool_error_roundtrip(self, tmp_path):
        """Tool errors persist and restore with ERROR: prefix in LLM output."""
        base = tmp_path / "error_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(store=store)

        await conv.add_user_message("Calculate 1/0")
        await conv.add_assistant_message(
            "Let me calculate that.",
            tool_calls=[
                {
                    "id": "call_calc",
                    "type": "function",
                    "function": {"name": "calculator", "arguments": '{"expr":"1/0"}'},
                }
            ],
        )
        await conv.add_tool_result(
            "call_calc", "ZeroDivisionError: division by zero", is_error=True
        )
        await conv.add_assistant_message("The calculation failed: division by zero is undefined.")

        # Restore
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)
        assert restored is not None

        tool_msg = restored.messages[2]
        assert tool_msg.role == "tool"
        assert tool_msg.is_error is True
        assert tool_msg.tool_use_id == "call_calc"

        llm_dict = tool_msg.to_llm_dict()
        assert llm_dict["content"].startswith("ERROR: ")
        assert "ZeroDivisionError" in llm_dict["content"]
        assert llm_dict["tool_call_id"] == "call_calc"

    @pytest.mark.asyncio
    async def test_concurrent_conversations_isolated(self, tmp_path):
        """Two conversations in separate directories don't interfere."""
        store_a = FileConversationStore(tmp_path / "conv_a")
        store_b = FileConversationStore(tmp_path / "conv_b")

        conv_a = NodeConversation(system_prompt="Agent A", store=store_a)
        conv_b = NodeConversation(system_prompt="Agent B", store=store_b)

        await conv_a.add_user_message("Hello from A")
        await conv_b.add_user_message("Hello from B")
        await conv_a.add_assistant_message("Response A")
        await conv_b.add_assistant_message("Response B")
        await conv_b.add_user_message("Follow-up B")

        # Restore independently
        restored_a = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_a"))
        restored_b = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_b"))

        assert restored_a.system_prompt == "Agent A"
        assert restored_b.system_prompt == "Agent B"
        assert restored_a.message_count == 2
        assert restored_b.message_count == 3
        assert restored_a.messages[0].content == "Hello from A"
        assert restored_b.messages[2].content == "Follow-up B"

    @pytest.mark.asyncio
    async def test_destroy_removes_all_files(self, tmp_path):
        """destroy() wipes the entire conversation directory."""
        base = tmp_path / "doomed_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(system_prompt="temp", store=store)
        await conv.add_user_message("ephemeral")
        await conv.add_assistant_message("gone soon")

        assert base.exists()
        assert (base / "meta.json").exists()
|
||||
assert (base / "parts").exists()
|
||||
|
||||
await store.destroy()
|
||||
|
||||
assert not base.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_empty_store_returns_none(self, tmp_path):
|
||||
"""Restoring from a path that was never written to returns None."""
|
||||
store = FileConversationStore(tmp_path / "empty")
|
||||
result = await NodeConversation.restore(store)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_then_continue_then_restore(self, tmp_path):
|
||||
"""clear() removes messages but preserves seq counter for new messages."""
|
||||
base = tmp_path / "clear_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(system_prompt="s", store=store)
|
||||
|
||||
await conv.add_user_message("old msg 0")
|
||||
await conv.add_assistant_message("old msg 1")
|
||||
assert conv.next_seq == 2
|
||||
|
||||
await conv.clear()
|
||||
assert conv.message_count == 0
|
||||
assert conv.next_seq == 2 # seq counter preserved
|
||||
|
||||
# Continue with new messages — seqs should start at 2
|
||||
await conv.add_user_message("new msg")
|
||||
await conv.add_assistant_message("new response")
|
||||
assert conv.next_seq == 4
|
||||
assert conv.messages[0].seq == 2
|
||||
assert conv.messages[1].seq == 3
|
||||
|
||||
# Restore
|
||||
store2 = FileConversationStore(base)
|
||||
restored = await NodeConversation.restore(store2)
|
||||
assert restored is not None
|
||||
assert restored.message_count == 2
|
||||
assert restored.next_seq == 4
|
||||
assert restored.messages[0].content == "new msg"
|
||||
assert restored.messages[0].seq == 2
|
||||
|
||||
@@ -0,0 +1,318 @@
"""Tests for stream event dataclasses.

Validates construction, defaults, immutability, serialization, and the
StreamEvent discriminated union type.
"""

from dataclasses import FrozenInstanceError, asdict, fields

import pytest

from framework.llm.stream_events import (
    FinishEvent,
    ReasoningDeltaEvent,
    ReasoningStartEvent,
    StreamErrorEvent,
    StreamEvent,
    TextDeltaEvent,
    TextEndEvent,
    ToolCallEvent,
    ToolResultEvent,
)

# All concrete event classes in the union
ALL_EVENT_CLASSES = [
    TextDeltaEvent,
    TextEndEvent,
    ToolCallEvent,
    ToolResultEvent,
    ReasoningStartEvent,
    ReasoningDeltaEvent,
    FinishEvent,
    StreamErrorEvent,
]


# ---------------------------------------------------------------------------
# Construction & defaults
# ---------------------------------------------------------------------------
class TestEventDefaults:
    """Each event class should be constructible with zero arguments."""

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_default_construction(self, cls):
        event = cls()
        assert event.type != ""

    def test_text_delta_defaults(self):
        e = TextDeltaEvent()
        assert e.type == "text_delta"
        assert e.content == ""
        assert e.snapshot == ""

    def test_text_end_defaults(self):
        e = TextEndEvent()
        assert e.type == "text_end"
        assert e.full_text == ""

    def test_tool_call_defaults(self):
        e = ToolCallEvent()
        assert e.type == "tool_call"
        assert e.tool_use_id == ""
        assert e.tool_name == ""
        assert e.tool_input == {}

    def test_tool_result_defaults(self):
        e = ToolResultEvent()
        assert e.type == "tool_result"
        assert e.tool_use_id == ""
        assert e.content == ""
        assert e.is_error is False

    def test_reasoning_start_defaults(self):
        e = ReasoningStartEvent()
        assert e.type == "reasoning_start"

    def test_reasoning_delta_defaults(self):
        e = ReasoningDeltaEvent()
        assert e.type == "reasoning_delta"
        assert e.content == ""

    def test_finish_defaults(self):
        e = FinishEvent()
        assert e.type == "finish"
        assert e.stop_reason == ""
        assert e.input_tokens == 0
        assert e.output_tokens == 0
        assert e.model == ""

    def test_stream_error_defaults(self):
        e = StreamErrorEvent()
        assert e.type == "error"
        assert e.error == ""
        assert e.recoverable is False


# ---------------------------------------------------------------------------
# Construction with values
# ---------------------------------------------------------------------------
class TestEventConstruction:
    """Events should store provided field values correctly."""

    def test_text_delta_with_values(self):
        e = TextDeltaEvent(content="hello", snapshot="hello world")
        assert e.content == "hello"
        assert e.snapshot == "hello world"

    def test_text_end_with_values(self):
        e = TextEndEvent(full_text="the complete response")
        assert e.full_text == "the complete response"

    def test_tool_call_with_values(self):
        e = ToolCallEvent(
            tool_use_id="call_abc123",
            tool_name="web_search",
            tool_input={"query": "python", "num_results": 5},
        )
        assert e.tool_use_id == "call_abc123"
        assert e.tool_name == "web_search"
        assert e.tool_input == {"query": "python", "num_results": 5}

    def test_tool_result_with_values(self):
        e = ToolResultEvent(
            tool_use_id="call_abc123",
            content="search results here",
            is_error=False,
        )
        assert e.tool_use_id == "call_abc123"
        assert e.content == "search results here"
        assert e.is_error is False

    def test_tool_result_error(self):
        e = ToolResultEvent(
            tool_use_id="call_fail",
            content="timeout",
            is_error=True,
        )
        assert e.is_error is True

    def test_reasoning_delta_with_content(self):
        e = ReasoningDeltaEvent(content="Let me think about this...")
        assert e.content == "Let me think about this..."

    def test_finish_with_values(self):
        e = FinishEvent(
            stop_reason="end_turn",
            input_tokens=150,
            output_tokens=300,
            model="claude-haiku-4-5",
        )
        assert e.stop_reason == "end_turn"
        assert e.input_tokens == 150
        assert e.output_tokens == 300
        assert e.model == "claude-haiku-4-5"

    def test_stream_error_with_values(self):
        e = StreamErrorEvent(error="rate limit exceeded", recoverable=True)
        assert e.error == "rate limit exceeded"
        assert e.recoverable is True


# ---------------------------------------------------------------------------
# Frozen immutability
# ---------------------------------------------------------------------------
class TestEventImmutability:
    """All events are frozen dataclasses — fields cannot be reassigned."""

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_frozen(self, cls):
        event = cls()
        with pytest.raises(FrozenInstanceError):
            event.type = "modified"

    def test_text_delta_frozen_content(self):
        e = TextDeltaEvent(content="hello")
        with pytest.raises(FrozenInstanceError):
            e.content = "modified"

    def test_tool_call_frozen_input(self):
        e = ToolCallEvent(tool_input={"key": "value"})
        with pytest.raises(FrozenInstanceError):
            e.tool_input = {}


# ---------------------------------------------------------------------------
# Type literal values
# ---------------------------------------------------------------------------
class TestTypeLiterals:
    """Each event's `type` field should match its Literal annotation."""

    EXPECTED_TYPES = {
        TextDeltaEvent: "text_delta",
        TextEndEvent: "text_end",
        ToolCallEvent: "tool_call",
        ToolResultEvent: "tool_result",
        ReasoningStartEvent: "reasoning_start",
        ReasoningDeltaEvent: "reasoning_delta",
        FinishEvent: "finish",
        StreamErrorEvent: "error",
    }

    @pytest.mark.parametrize(
        "cls,expected_type",
        EXPECTED_TYPES.items(),
        ids=lambda x: x.__name__ if isinstance(x, type) else x,
    )
    def test_type_value(self, cls, expected_type):
        assert cls().type == expected_type

    def test_all_types_unique(self):
        types = [cls().type for cls in ALL_EVENT_CLASSES]
        assert len(types) == len(set(types)), f"Duplicate type values: {types}"


# ---------------------------------------------------------------------------
# Serialization via dataclasses.asdict
# ---------------------------------------------------------------------------
class TestEventSerialization:
    """Events should round-trip through asdict for JSON serialization."""

    def test_text_delta_asdict(self):
        e = TextDeltaEvent(content="chunk", snapshot="full chunk")
        d = asdict(e)
        assert d == {"type": "text_delta", "content": "chunk", "snapshot": "full chunk"}

    def test_tool_call_asdict(self):
        e = ToolCallEvent(
            tool_use_id="id_1",
            tool_name="calc",
            tool_input={"expression": "2+2"},
        )
        d = asdict(e)
        assert d["tool_name"] == "calc"
        assert d["tool_input"] == {"expression": "2+2"}

    def test_finish_asdict(self):
        e = FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=20, model="gpt-4")
        d = asdict(e)
        assert d == {
            "type": "finish",
            "stop_reason": "stop",
            "input_tokens": 10,
            "output_tokens": 20,
            "model": "gpt-4",
        }

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_asdict_contains_type(self, cls):
        d = asdict(cls())
        assert "type" in d

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_asdict_keys_match_fields(self, cls):
        event = cls()
        d = asdict(event)
        field_names = {f.name for f in fields(cls)}
        assert set(d.keys()) == field_names


# ---------------------------------------------------------------------------
# StreamEvent union type
# ---------------------------------------------------------------------------
class TestStreamEventUnion:
    """The StreamEvent union should include all event classes."""

    def test_union_contains_all_classes(self):
        # StreamEvent is a UnionType (PEP 604 syntax: X | Y | Z)
        union_args = StreamEvent.__args__  # type: ignore[attr-defined]
        for cls in ALL_EVENT_CLASSES:
            assert cls in union_args, f"{cls.__name__} not in StreamEvent union"

    def test_union_has_exactly_expected_members(self):
        union_args = set(StreamEvent.__args__)  # type: ignore[attr-defined]
        expected = set(ALL_EVENT_CLASSES)
        assert union_args == expected

    @pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
    def test_isinstance_check(self, cls):
        """Each event instance should be an instance of its class (basic sanity)."""
        event = cls()
        assert isinstance(event, cls)


# ---------------------------------------------------------------------------
# Equality & hashing (frozen dataclasses support both)
# ---------------------------------------------------------------------------
class TestEventEquality:
    """Frozen dataclasses support equality and hashing."""

    def test_equal_events(self):
        a = TextDeltaEvent(content="hi", snapshot="hi")
        b = TextDeltaEvent(content="hi", snapshot="hi")
        assert a == b

    def test_unequal_events(self):
        a = TextDeltaEvent(content="hi")
        b = TextDeltaEvent(content="bye")
        assert a != b

    def test_different_types_not_equal(self):
        a = TextDeltaEvent(content="hi")
        b = ReasoningDeltaEvent(content="hi")
        assert a != b

    def test_hashable(self):
        e = FinishEvent(stop_reason="stop", model="gpt-4")
        s = {e}  # should be hashable since frozen
        assert e in s

    def test_equal_events_same_hash(self):
        a = FinishEvent(stop_reason="stop", model="gpt-4")
        b = FinishEvent(stop_reason="stop", model="gpt-4")
        assert hash(a) == hash(b)

    def test_events_with_dict_not_hashable(self):
        """Events containing dict fields (e.g. tool_input) are not hashable."""
        e = ToolCallEvent(tool_use_id="x", tool_name="y", tool_input={"key": "val"})
        with pytest.raises(TypeError, match="unhashable type"):
            hash(e)
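
For orientation, here is a minimal sketch of two of the eight event definitions these tests pin down. The real definitions live in `framework/llm/stream_events.py`; the exact field order and any extra helpers there may differ, so treat this as an illustration of the frozen-dataclass pattern rather than the canonical source.

```python
from dataclasses import dataclass, field
from typing import Literal


@dataclass(frozen=True)
class TextDeltaEvent:
    # frozen + eq gives value equality and generated hashing, as the tests check
    type: Literal["text_delta"] = "text_delta"
    content: str = ""
    snapshot: str = ""


@dataclass(frozen=True)
class ToolCallEvent:
    type: Literal["tool_call"] = "tool_call"
    tool_use_id: str = ""
    tool_name: str = ""
    # the dict field makes instances unhashable, matching test_events_with_dict_not_hashable
    tool_input: dict = field(default_factory=dict)


# PEP 604 union; the tests introspect it via StreamEvent.__args__
StreamEvent = TextDeltaEvent | ToolCallEvent  # ...plus the six remaining event classes
```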
@@ -362,7 +362,6 @@ class AgentRequest(BaseModel):
            raise ValueError('max_tokens too high')
        return v
```

### Output Sanitization

> **Note:** The following snippet is illustrative and shows a simplified example
> of output sanitization logic. Actual implementations may differ.

@@ -1757,7 +1757,7 @@ tools/src/aden_tools/credentials/

### Manual Testing

- [ ] Create encrypted credential store
- [ ] Create local encrypted store
- [ ] Save and load multi-key credentials
- [ ] Verify template resolution in tool headers
- [ ] Test OAuth2 token refresh
+1
-3
@@ -14,7 +14,6 @@

[](https://github.com/adenhq/hive/blob/main/LICENSE)
[](https://www.ycombinator.com/companies/aden)
[](https://hub.docker.com/u/adenhq)
[](https://discord.com/invite/MXE49hrKDk)
[](https://x.com/aden_hq)
[](https://www.linkedin.com/company/teamaden/)
@@ -42,7 +41,7 @@ Visita [adenhq.com](https://adenhq.com) para documentación completa, ejemplos y
## ¿Qué es Aden?

<p align="center">
  <img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
  <img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
</p>

Aden es una plataforma para construir, desplegar, operar y adaptar agentes de IA:
@@ -66,7 +65,6 @@ Aden es una plataforma para construir, desplegar, operar y adaptar agentes de IA
### Prerrequisitos

- [Python 3.11+](https://www.python.org/downloads/) - Para desarrollo de agentes
- [Docker](https://docs.docker.com/get-docker/) (v20.10+) - Opcional, para herramientas en contenedores

### Instalación

+1
-1
@@ -44,7 +44,7 @@
# Aden क्या है?

<p align="center">
  <img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
  <img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
</p>

Aden एक ऐसा प्लेटफ़ॉर्म है जो AI एजेंट्स को बनाने, डिप्लॉय करने, ऑपरेट करने और अनुकूलित करने के लिए उपयोग होता है:

+1
-1
@@ -42,7 +42,7 @@
## Adenとは

<p align="center">
  <img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
  <img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
</p>

Adenは、AIエージェントの構築、デプロイ、運用、適応のためのプラットフォームです:

+1
-1
@@ -42,7 +42,7 @@
## Aden이란 무엇인가

<p align="center">
  <img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
  <img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
</p>

Aden은 AI 에이전트를 구축, 배포, 운영, 적응시키기 위한 플랫폼입니다:

+1
-1
@@ -42,7 +42,7 @@ Visite [adenhq.com](https://adenhq.com) para documentação completa, exemplos e
## O que é Aden

<p align="center">
  <img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
  <img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
</p>

Aden é uma plataforma para construir, implantar, operar e adaptar agentes de IA:

+1
-1
@@ -42,7 +42,7 @@
## Что такое Aden

<p align="center">
  <img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
  <img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
</p>

Aden — это платформа для создания, развёртывания, эксплуатации и адаптации ИИ-агентов:

+1
-1
@@ -42,7 +42,7 @@
## 什么是 Aden

<p align="center">
  <img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
  <img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
</p>

Aden 是一个用于构建、部署、运营和适应 AI 智能体的平台:

@@ -0,0 +1,41 @@
# Examples

This directory contains two types of examples to help you build agents with the Hive framework.

## Recipes vs Templates

### [recipes/](recipes/) — "How to make it"

A recipe is a **prompt-only** description of an agent. It tells you the goal, the nodes, the prompts, the edge routing logic, and what tools to wire in — but it's not runnable code. You read the recipe, then build the agent yourself.

Use recipes when you want to:
- Understand a pattern before committing to an implementation
- Adapt an idea to your own codebase or tooling
- Learn how to think about agent design (goals, nodes, edges, prompts)

### [templates/](templates/) — "Ready to eat"

A template is a **working agent scaffold** that follows the standard Hive export structure. Copy the folder, rename it, swap in your own prompts and tools, and run it.

Use templates when you want to:
- Get a new agent running quickly
- Start from a known-good structure instead of from scratch
- See how all the pieces (goal, nodes, edges, config, CLI) fit together in real code

## How to use a template

```bash
# 1. Copy the template
cp -r examples/templates/marketing_agent exports/my_agent

# 2. Edit the goal, nodes, and edges in agent.py and nodes/__init__.py

# 3. Run it
PYTHONPATH=core python -m exports.my_agent --help
```

## How to use a recipe

1. Read the recipe markdown file
2. Use the patterns described to build your own agent — either manually or with the builder agent (`/agent-workflow`)
3. Refer to the [core README](../core/README.md) for framework API details

@@ -0,0 +1,27 @@
# Recipes

A recipe describes an agent's design — the goal, nodes, prompts, edge logic, and tools — without providing runnable code. Think of it as a blueprint: it tells you *how* to build the agent, but you do the building.

## What's in a recipe

Each recipe is a markdown file (or folder with a markdown file) containing:

- **Goal**: What the agent accomplishes, including success criteria and constraints
- **Nodes**: Each step in the workflow, with the system prompt, node type, and input/output keys
- **Edges**: How nodes connect, including conditions and routing logic
- **Tools**: What external tools or MCP servers the agent needs
- **Usage notes**: Tips, gotchas, and suggested variations

## How to use a recipe

1. Read through the recipe to understand the design
2. Create a new agent using the standard export structure (see [templates/](../templates/) for a scaffold)
3. Translate the recipe's goal, nodes, and edges into code (see the sketch after this list)
4. Wire in the tools described
5. Test and iterate
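
To make step 3 concrete, here is a minimal sketch of what that translation can look like, using the `NodeSpec`/`EdgeSpec` API that the templates in this repo use. The node and edge names below are invented for illustration; lift the real ids, prompts, and keys from the recipe you are following.

```python
from framework.graph import EdgeCondition, EdgeSpec, NodeSpec

# One recipe node becomes one NodeSpec: prompt and keys copied from the recipe.
summarize_node = NodeSpec(
    id="summarize",
    name="Summarize",
    description="Condense the input document into a short summary.",
    node_type="llm_generate",
    input_keys=["document"],
    output_keys=["summary"],
    system_prompt="You are an editor. Summarize: {document}",
    tools=[],
)

# One row of the recipe's edge table becomes one EdgeSpec.
edges = [
    EdgeSpec(
        id="summarize-to-review",
        source="summarize",
        target="review",
        condition=EdgeCondition.ON_SUCCESS,
    ),
]
```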

## Available recipes

| Recipe | Description |
|--------|-------------|
| [marketing_agent](marketing_agent/) | Multi-channel marketing content generator with audience analysis and A/B copy variants |

@@ -0,0 +1,156 @@
# Recipe: Marketing Content Agent

A multi-channel marketing content generator. Given a product description and target audience, this agent analyzes the audience, generates tailored copy for multiple channels, and produces A/B variants.

## Goal

```
Name: Marketing Content Generator
Description: Generate targeted marketing content across multiple channels
  for a given product and audience.

Success criteria:
- Audience analysis is produced with demographics and pain points
- At least 2 channel-specific content pieces are generated
- A/B variants are provided for each piece
- All content aligns with the specified brand voice

Constraints:
- (hard) No competitor brand names in generated content
- (soft) Content should be under 280 characters for social media channels
```

## Input / Output

**Input:**
- `product_description` (str) — What the product is and does
- `target_audience` (str) — Who the content is for
- `brand_voice` (str) — Tone and style guidelines (e.g., "professional but approachable")
- `channels` (list[str]) — Target channels, e.g. `["email", "twitter", "linkedin"]`

**Output:**
- `audience_analysis` (dict) — Demographics, pain points, motivations
- `content` (list[dict]) — Per-channel content with A/B variants

## Workflow

```
[analyze_audience] → [generate_content] → [review_and_refine]
                                                  |
                                            (conditional)
                                                  |
                    needs_revision == True  → [generate_content]
                    needs_revision == False → (done)
```

## Nodes

### 1. analyze_audience

| Field | Value |
|-------|-------|
| Type | `llm_generate` |
| Input keys | `product_description`, `target_audience` |
| Output keys | `audience_analysis` |
| Tools | None |

**System prompt:**
```
You are a marketing strategist. Analyze the target audience for a product.

Product: {product_description}
Target audience: {target_audience}

Produce a structured analysis in JSON:
{{
  "audience_analysis": {{
    "demographics": "...",
    "pain_points": ["..."],
    "motivations": ["..."],
    "preferred_channels": ["..."],
    "messaging_angle": "..."
  }}
}}
```

### 2. generate_content

| Field | Value |
|-------|-------|
| Type | `llm_generate` |
| Input keys | `product_description`, `audience_analysis`, `brand_voice`, `channels` |
| Output keys | `content` |
| Tools | None |

**System prompt:**
```
You are a marketing copywriter. Generate content for each channel.

Product: {product_description}
Audience analysis: {audience_analysis}
Brand voice: {brand_voice}
Channels: {channels}

For each channel, produce two variants (A and B).

Output as JSON:
{{
  "content": [
    {{
      "channel": "twitter",
      "variant_a": "...",
      "variant_b": "..."
    }}
  ]
}}
```

### 3. review_and_refine

| Field | Value |
|-------|-------|
| Type | `llm_generate` |
| Input keys | `content`, `brand_voice` |
| Output keys | `content`, `needs_revision` |
| Tools | None |

**System prompt:**
```
You are a senior marketing editor. Review the following content for brand
voice alignment, clarity, and channel appropriateness.

Content: {content}
Brand voice: {brand_voice}

If any piece needs revision, fix it and set needs_revision to true.
If everything looks good, return the content unchanged with needs_revision false.

Output as JSON:
{{
  "content": [...],
  "needs_revision": false
}}
```

## Edges

| Source | Target | Condition | Priority |
|--------|--------|-----------|----------|
| analyze_audience | generate_content | `on_success` | 0 |
| generate_content | review_and_refine | `on_success` | 0 |
| review_and_refine | generate_content | `conditional: needs_revision == True` | 10 |

The `review_and_refine → generate_content` loop has higher priority so it's checked first. If `needs_revision` is false, execution ends at `review_and_refine` (terminal node).
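
Expressed in code, that high-priority loop edge corresponds to the `EdgeSpec` used in the companion template's `agent.py` (shown later in this changeset):

```python
from framework.graph import EdgeCondition, EdgeSpec

# Higher priority means this edge is evaluated before the default
# on-success path, so the revision loop wins while needs_revision is true.
revision_loop = EdgeSpec(
    id="review-to-regenerate",
    source="review-and-refine",
    target="generate-content",
    condition=EdgeCondition.CONDITIONAL,
    condition_expr="needs_revision == True",
    priority=10,
    description="If revision needed, loop back to content generation",
)
```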

## Tools

This recipe uses no external tools — all nodes are `llm_generate`. To extend it, consider adding:
- A web search tool for competitive analysis in the `analyze_audience` node
- A URL shortener tool for social media content
- An image generation tool for visual content variants

## Variations

- **Single-channel mode**: Remove the `channels` input and hardcode to one channel for simpler output
- **With approval gate**: Add a `human_input` node between `review_and_refine` and the terminal to require human sign-off
- **With analytics**: Add a `function` node that logs generated content to a tracking system

@@ -0,0 +1,38 @@
# Templates

A template is a working agent scaffold that follows the standard Hive export structure. Copy it, rename it, customize the goal/nodes/edges, and run it.

## What's in a template

Each template is a complete agent package:

```
template_name/
├── __init__.py        # Package exports
├── __main__.py        # CLI entry point
├── agent.py           # Goal, edges, graph spec, agent class
├── config.py          # Runtime configuration
├── nodes/
│   └── __init__.py    # Node definitions (NodeSpec instances)
└── README.md          # What this template demonstrates
```

## How to use a template

```bash
# 1. Copy to your exports directory
cp -r examples/templates/marketing_agent exports/my_marketing_agent

# 2. Update the module references in __main__.py and __init__.py

# 3. Customize goal, nodes, edges, and prompts

# 4. Run it
PYTHONPATH=core python -m exports.my_marketing_agent --input '{"product_description": "..."}'
```

## Available templates

| Template | Description |
|----------|-------------|
| [marketing_agent](marketing_agent/) | Multi-channel marketing content generator with audience analysis, content generation, and editorial review nodes |

@@ -0,0 +1,57 @@
# Template: Marketing Content Agent

A multi-channel marketing content generator. Given a product and audience, this agent analyzes the audience, generates tailored copy for multiple channels with A/B variants, and reviews the output for quality.

## Workflow

```
[analyze-audience] → [generate-content] → [review-and-refine]
                                                  |
                                            (conditional)
                                                  |
                    needs_revision == True  → [generate-content]
                    needs_revision == False → (done)
```

## Nodes

| Node | Type | Description |
|------|------|-------------|
| `analyze-audience` | `llm_generate` | Produces structured audience analysis |
| `generate-content` | `llm_generate` | Creates per-channel copy with A/B variants |
| `review-and-refine` | `llm_generate` | Reviews and optionally revises content |

## Usage

```bash
# From the repo root
PYTHONPATH=core python -m examples.templates.marketing_agent

# With custom input
PYTHONPATH=core python -m examples.templates.marketing_agent --input '{
  "product_description": "A fitness tracking app",
  "target_audience": "Health-conscious millennials",
  "brand_voice": "Energetic and motivational",
  "channels": ["instagram", "email"]
}'
```

## Customization ideas

- Add a `function` node to call an analytics API and inform audience analysis with real data
- Add a `human_input` pause node before final output for editorial approval (see the sketch after this list)
- Swap `llm_generate` nodes to `llm_tool_use` and add web search tools for competitive research
- Add an image generation tool to produce visual assets alongside copy
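
A minimal sketch of the approval-gate idea from the list above. It assumes `human_input` is an accepted `node_type` and that the node is registered in `pause_nodes` the way `agent.py` declares them; both details are assumptions to verify against the framework docs, and the ids are invented:

```python
from framework.graph import EdgeCondition, EdgeSpec, NodeSpec

# Hypothetical approval node: pauses the run until a human responds.
approval_node = NodeSpec(
    id="editorial-approval",
    name="Editorial Approval",
    description="Pause for a human editor to approve the final content.",
    node_type="human_input",  # assumed node type, not confirmed by this template
    input_keys=["content"],
    output_keys=["approved"],
)

# Route review output through the approval gate before finishing.
extra_edges = [
    EdgeSpec(
        id="review-to-approval",
        source="review-and-refine",
        target="editorial-approval",
        condition=EdgeCondition.ON_SUCCESS,
    ),
]

# Register the pause and move the terminal to the new node.
pause_nodes = ["editorial-approval"]
terminal_nodes = ["editorial-approval"]
```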
## File structure

```
marketing_agent/
├── __init__.py        # Package exports
├── __main__.py        # CLI entry point
├── agent.py           # Goal, edges, graph spec, MarketingAgent class
├── config.py          # RuntimeConfig and AgentMetadata
├── nodes/
│   └── __init__.py    # NodeSpec definitions
└── README.md          # This file
```

@@ -0,0 +1,6 @@
"""Marketing Content Agent — template example."""

from .agent import MarketingAgent, goal, edges, nodes
from .config import default_config

__all__ = ["MarketingAgent", "goal", "edges", "nodes", "default_config"]
@@ -0,0 +1,31 @@
"""CLI entry point for Marketing Content Agent."""

import asyncio
import json
import sys


def main():
    from .agent import MarketingAgent
    from .config import default_config

    # Simple CLI — replace with Click for production use
    input_data = {
        "product_description": "An AI-powered project management tool for remote teams",
        "target_audience": "Engineering managers at mid-size tech companies",
        "brand_voice": "Professional but approachable, concise, data-driven",
        "channels": ["email", "twitter", "linkedin"],
    }

    # Accept JSON input from command line
    if len(sys.argv) > 1 and sys.argv[1] == "--input":
        input_data = json.loads(sys.argv[2])

    agent = MarketingAgent(config=default_config)
    result = asyncio.run(agent.run(input_data))

    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()
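
If you do swap in Click as the comment suggests, a minimal equivalent might look like the sketch below. Click is already installed by the setup script; the `--input` option mirrors the ad-hoc flag above, and everything else is unchanged.

```python
import asyncio
import json

import click


@click.command()
@click.option("--input", "input_json", default=None, help="Agent input as a JSON string.")
def main(input_json: str | None) -> None:
    from .agent import MarketingAgent
    from .config import default_config

    # Fall back to the same demo payload the ad-hoc CLI uses.
    input_data = json.loads(input_json) if input_json else {
        "product_description": "An AI-powered project management tool for remote teams",
        "target_audience": "Engineering managers at mid-size tech companies",
        "brand_voice": "Professional but approachable, concise, data-driven",
        "channels": ["email", "twitter", "linkedin"],
    }

    agent = MarketingAgent(config=default_config)
    result = asyncio.run(agent.run(input_data))
    click.echo(json.dumps(result, indent=2))
```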
@@ -0,0 +1,161 @@
"""Marketing Content Agent — goal, edges, graph spec, and agent class."""

from pathlib import Path

from framework.graph import EdgeCondition, EdgeSpec, Goal, SuccessCriterion, Constraint
from framework.graph.edge import GraphSpec
from framework.graph.executor import GraphExecutor
from framework.runtime.core import Runtime
from framework.llm.anthropic import AnthropicProvider

from .config import default_config, RuntimeConfig
from .nodes import all_nodes

# ---------------------------------------------------------------------------
# Goal
# ---------------------------------------------------------------------------
goal = Goal(
    id="marketing-content",
    name="Marketing Content Generator",
    description=(
        "Generate targeted marketing content across multiple channels "
        "for a given product and audience."
    ),
    success_criteria=[
        SuccessCriterion(
            id="audience-analyzed",
            description="Audience analysis is produced with demographics and pain points",
            metric="output_contains",
            target="audience_analysis",
        ),
        SuccessCriterion(
            id="content-generated",
            description="At least 2 channel-specific content pieces are generated",
            metric="custom",
            target="len(content) >= 2",
        ),
        SuccessCriterion(
            id="variants-provided",
            description="A/B variants are provided for each content piece",
            metric="custom",
            target="all variants present",
        ),
    ],
    constraints=[
        Constraint(
            id="no-competitor-names",
            description="No competitor brand names in generated content",
            constraint_type="hard",
            category="safety",
        ),
        Constraint(
            id="social-length",
            description="Social media content should be under 280 characters",
            constraint_type="soft",
            category="quality",
        ),
    ],
    input_schema={
        "product_description": {"type": "string"},
        "target_audience": {"type": "string"},
        "brand_voice": {"type": "string"},
        "channels": {"type": "array", "items": {"type": "string"}},
    },
    output_schema={
        "audience_analysis": {"type": "object"},
        "content": {"type": "array"},
    },
)

# ---------------------------------------------------------------------------
# Edges
# ---------------------------------------------------------------------------
edges = [
    EdgeSpec(
        id="analyze-to-generate",
        source="analyze-audience",
        target="generate-content",
        condition=EdgeCondition.ON_SUCCESS,
        description="After audience analysis, generate content",
    ),
    EdgeSpec(
        id="generate-to-review",
        source="generate-content",
        target="review-and-refine",
        condition=EdgeCondition.ON_SUCCESS,
        description="After content generation, review and refine",
    ),
    EdgeSpec(
        id="review-to-regenerate",
        source="review-and-refine",
        target="generate-content",
        condition=EdgeCondition.CONDITIONAL,
        condition_expr="needs_revision == True",
        priority=10,
        description="If revision needed, loop back to content generation",
    ),
]

# ---------------------------------------------------------------------------
# Graph structure
# ---------------------------------------------------------------------------
entry_node = "analyze-audience"
entry_points = {"start": "analyze-audience"}
terminal_nodes = ["review-and-refine"]
pause_nodes = []
nodes = all_nodes


# ---------------------------------------------------------------------------
# Agent class
# ---------------------------------------------------------------------------
class MarketingAgent:
    """Multi-channel marketing content generator agent."""

    def __init__(self, config: RuntimeConfig | None = None):
        self.config = config or default_config
        self.goal = goal
        self.nodes = nodes
        self.edges = edges
        self.entry_node = entry_node
        self.terminal_nodes = terminal_nodes
        self.executor = None

    def _build_graph(self) -> GraphSpec:
        return GraphSpec(
            id="marketing-content-graph",
            goal_id=self.goal.id,
            entry_node=self.entry_node,
            entry_points=entry_points,
            terminal_nodes=self.terminal_nodes,
            pause_nodes=pause_nodes,
            nodes=self.nodes,
            edges=self.edges,
            default_model=self.config.model,
            max_tokens=self.config.max_tokens,
            description="Marketing content generation workflow",
        )

    def _create_executor(self):
        runtime = Runtime(storage_path=Path(self.config.storage_path).expanduser())
        llm = AnthropicProvider(model=self.config.model)
        self.executor = GraphExecutor(runtime=runtime, llm=llm)
        return self.executor

    async def run(self, context: dict, mock_mode: bool = False) -> dict:
        graph = self._build_graph()
        executor = self._create_executor()
        result = await executor.execute(
            graph=graph,
            goal=self.goal,
            input_data=context,
        )
        return {
            "success": result.success,
            "output": result.output,
            "steps": result.steps_executed,
            "path": result.path,
        }


default_agent = MarketingAgent()
@@ -0,0 +1,26 @@
"""Runtime configuration for Marketing Content Agent."""

from dataclasses import dataclass, field


@dataclass
class RuntimeConfig:
    model: str = "claude-haiku-4-5-20251001"
    max_tokens: int = 2048
    storage_path: str = "~/.hive/storage"
    mock_mode: bool = False


@dataclass
class AgentMetadata:
    name: str = "marketing_agent"
    version: str = "0.1.0"
    description: str = "Multi-channel marketing content generator"
    author: str = ""
    tags: list[str] = field(
        default_factory=lambda: ["marketing", "content", "template"]
    )


default_config = RuntimeConfig()
metadata = AgentMetadata()
@@ -0,0 +1,106 @@
"""Node definitions for Marketing Content Agent."""

from framework.graph import NodeSpec

# ---------------------------------------------------------------------------
# Node 1: Analyze the target audience
# ---------------------------------------------------------------------------
analyze_audience_node = NodeSpec(
    id="analyze-audience",
    name="Analyze Audience",
    description="Produce a structured audience analysis from the product and target audience description.",
    node_type="llm_generate",
    input_keys=["product_description", "target_audience"],
    output_keys=["audience_analysis"],
    system_prompt="""\
You are a marketing strategist. Analyze the target audience for a product.

Product: {product_description}
Target audience: {target_audience}

Produce a structured analysis as raw JSON (no markdown):
{{
  "audience_analysis": {{
    "demographics": "...",
    "pain_points": ["..."],
    "motivations": ["..."],
    "preferred_channels": ["..."],
    "messaging_angle": "..."
  }}
}}
""",
    tools=[],
    max_retries=2,
)

# ---------------------------------------------------------------------------
# Node 2: Generate channel-specific content with A/B variants
# ---------------------------------------------------------------------------
generate_content_node = NodeSpec(
    id="generate-content",
    name="Generate Content",
    description="Create marketing copy for each requested channel with two variants per channel.",
    node_type="llm_generate",
    input_keys=["product_description", "audience_analysis", "brand_voice", "channels"],
    output_keys=["content"],
    system_prompt="""\
You are a marketing copywriter. Generate content for each channel.

Product: {product_description}
Audience analysis: {audience_analysis}
Brand voice: {brand_voice}
Channels: {channels}

For each channel, produce two variants (A and B).

Output as raw JSON (no markdown):
{{
  "content": [
    {{
      "channel": "twitter",
      "variant_a": "...",
      "variant_b": "..."
    }}
  ]
}}
""",
    tools=[],
    max_retries=2,
)

# ---------------------------------------------------------------------------
# Node 3: Review and refine content
# ---------------------------------------------------------------------------
review_and_refine_node = NodeSpec(
    id="review-and-refine",
    name="Review and Refine",
    description="Review generated content for brand voice alignment and channel fit. Revise if needed.",
    node_type="llm_generate",
    input_keys=["content", "brand_voice"],
    output_keys=["content", "needs_revision"],
    system_prompt="""\
You are a senior marketing editor. Review the following content for brand
voice alignment, clarity, and channel appropriateness.

Content: {content}
Brand voice: {brand_voice}

If any piece needs revision, fix it and set needs_revision to true.
If everything looks good, return the content unchanged with needs_revision false.

Output as raw JSON (no markdown):
{{
  "content": [...],
  "needs_revision": false
}}
""",
    tools=[],
    max_retries=2,
)

# All nodes for easy import
all_nodes = [
    analyze_audience_node,
    generate_content_node,
    review_and_refine_node,
]
@@ -0,0 +1,2 @@
[tool.uv.workspace]
members = ["core", "tools"]
+149
-153
@@ -11,6 +11,14 @@
|
||||
|
||||
set -e
|
||||
|
||||
# Detect Bash version for compatibility
|
||||
BASH_MAJOR_VERSION="${BASH_VERSINFO[0]}"
|
||||
USE_ASSOC_ARRAYS=false
|
||||
if [ "$BASH_MAJOR_VERSION" -ge 4 ]; then
|
||||
USE_ASSOC_ARRAYS=true
|
||||
fi
|
||||
echo "[debug] Bash version: ${BASH_VERSION}"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
@@ -52,7 +60,7 @@ prompt_choice() {
|
||||
echo -e "${BOLD}$prompt${NC}"
|
||||
for opt in "${options[@]}"; do
|
||||
echo -e " ${CYAN}$i)${NC} $opt"
|
||||
((i++))
|
||||
i=$((i + 1))
|
||||
done
|
||||
echo ""
|
||||
|
||||
@@ -60,7 +68,8 @@ prompt_choice() {
|
||||
while true; do
|
||||
read -r -p "Enter choice (1-${#options[@]}): " choice
|
||||
if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le "${#options[@]}" ]; then
|
||||
return $((choice - 1))
|
||||
PROMPT_CHOICE=$((choice - 1))
|
||||
return 0
|
||||
fi
|
||||
echo -e "${RED}Invalid choice. Please enter 1-${#options[@]}${NC}"
|
||||
done
|
||||
@@ -174,63 +183,26 @@ echo ""
|
||||
echo -e "${DIM}This may take a minute...${NC}"
|
||||
echo ""
|
||||
|
||||
# Upgrade pip, setuptools, and wheel
|
||||
echo -n " Upgrading pip... "
|
||||
$PYTHON_CMD -m pip install --upgrade pip setuptools wheel > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Install framework package from core/
|
||||
echo -n " Installing framework... "
|
||||
cd "$SCRIPT_DIR/core"
|
||||
# Install all workspace packages (core + tools) from workspace root
|
||||
echo -n " Installing workspace packages... "
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
if [ -f "pyproject.toml" ]; then
|
||||
uv sync > /dev/null 2>&1
|
||||
if [ $? -eq 0 ]; then
|
||||
echo -e "${GREEN} ✓ framework package installed${NC}"
|
||||
if uv sync > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ workspace packages installed${NC}"
|
||||
else
|
||||
echo -e "${YELLOW} ⚠ framework installation had issues (may be OK)${NC}"
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}failed (no pyproject.toml)${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install aden_tools package from tools/
|
||||
echo -n " Installing tools... "
|
||||
cd "$SCRIPT_DIR/tools"
|
||||
|
||||
if [ -f "pyproject.toml" ]; then
|
||||
uv sync > /dev/null 2>&1
|
||||
if [ $? -eq 0 ]; then
|
||||
echo -e "${GREEN} ✓ aden_tools package installed${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ aden_tools installation failed${NC}"
|
||||
echo -e "${RED} ✗ workspace installation failed${NC}"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}failed${NC}"
|
||||
echo -e "${RED}failed (no root pyproject.toml)${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install MCP dependencies
|
||||
echo -n " Installing MCP... "
|
||||
$PYTHON_CMD -m pip install mcp fastmcp > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Fix openai version compatibility
|
||||
echo -n " Checking openai... "
|
||||
$PYTHON_CMD -m pip install "openai>=1.0.0" > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Install click for CLI
|
||||
echo -n " Installing CLI tools... "
|
||||
$PYTHON_CMD -m pip install click > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Install Playwright browser
|
||||
echo -n " Installing Playwright browser... "
|
||||
if $PYTHON_CMD -c "import playwright" > /dev/null 2>&1; then
|
||||
if $PYTHON_CMD -m playwright install chromium > /dev/null 2>&1; then
|
||||
if uv run python -c "import playwright" > /dev/null 2>&1; then
|
||||
if uv run python -m playwright install chromium > /dev/null 2>&1; then
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
else
|
||||
echo -e "${YELLOW}⏭${NC}"
|
||||
@@ -249,33 +221,6 @@ echo ""
|
||||
# ============================================================
|
||||
|
||||
echo -e "${YELLOW}⬢${NC} ${BLUE}${BOLD}Step 3: Configuring LLM provider...${NC}"
|
||||
# Install MCP dependencies (in tools venv)
|
||||
echo " Installing MCP dependencies..."
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
uv pip install --python "$TOOLS_PYTHON" mcp fastmcp > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ MCP dependencies installed${NC}"
|
||||
|
||||
# Fix openai version compatibility (in tools venv)
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
OPENAI_VERSION=$($TOOLS_PYTHON -c "import openai; print(openai.__version__)" 2>/dev/null || echo "not_installed")
|
||||
if [ "$OPENAI_VERSION" = "not_installed" ]; then
|
||||
echo " Installing openai package..."
|
||||
uv pip install --python "$TOOLS_PYTHON" "openai>=1.0.0" > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ openai installed${NC}"
|
||||
elif [[ "$OPENAI_VERSION" =~ ^0\. ]]; then
|
||||
echo " Upgrading openai to 1.x+ for litellm compatibility..."
|
||||
uv pip install --python "$TOOLS_PYTHON" --upgrade "openai>=1.0.0" > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ openai upgraded${NC}"
|
||||
else
|
||||
echo -e "${GREEN} ✓ openai $OPENAI_VERSION is compatible${NC}"
|
||||
fi
|
||||
|
||||
# Install click for CLI (in tools venv)
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
uv pip install --python "$TOOLS_PYTHON" click > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ click installed${NC}"
|
||||
|
||||
cd "$SCRIPT_DIR"
|
||||
echo ""
|
||||
|
||||
# ============================================================
|
||||
@@ -287,42 +232,28 @@ echo ""
|
||||
|
||||
IMPORT_ERRORS=0
|
||||
|
||||
# Test imports using their respective venvs
|
||||
CORE_PYTHON="$SCRIPT_DIR/core/.venv/bin/python"
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
|
||||
# Test framework import (from core venv)
|
||||
if [ -f "$CORE_PYTHON" ] && $CORE_PYTHON -c "import framework" > /dev/null 2>&1; then
|
||||
# Test imports using workspace venv via uv run
|
||||
if uv run python -c "import framework" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ framework imports OK${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ framework import failed${NC}"
|
||||
IMPORT_ERRORS=$((IMPORT_ERRORS + 1))
|
||||
fi
|
||||
|
||||
# Test aden_tools import (from tools venv)
|
||||
if [ -f "$TOOLS_PYTHON" ] && $TOOLS_PYTHON -c "import aden_tools" > /dev/null 2>&1; then
|
||||
if uv run python -c "import aden_tools" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ aden_tools imports OK${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ aden_tools import failed${NC}"
|
||||
IMPORT_ERRORS=$((IMPORT_ERRORS + 1))
|
||||
fi
|
||||
|
||||
# Test litellm import (from core venv)
|
||||
if [ -f "$CORE_PYTHON" ] && $CORE_PYTHON -c "import litellm" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ litellm imports OK (core)${NC}"
|
||||
if uv run python -c "import litellm" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ litellm imports OK${NC}"
|
||||
else
|
||||
echo -e "${YELLOW} ⚠ litellm import issues in core (may be OK)${NC}"
|
||||
echo -e "${YELLOW} ⚠ litellm import issues (may be OK)${NC}"
|
||||
fi
|
||||
|
||||
# Test litellm import (from tools venv)
|
||||
if [ -f "$TOOLS_PYTHON" ] && $TOOLS_PYTHON -c "import litellm" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ litellm imports OK (tools)${NC}"
|
||||
else
|
||||
echo -e "${YELLOW} ⚠ litellm import issues in tools (may be OK)${NC}"
|
||||
fi
|
||||
|
||||
# Test MCP server module (from core venv)
|
||||
if [ -f "$CORE_PYTHON" ] && $CORE_PYTHON -c "from framework.mcp import agent_builder_server" > /dev/null 2>&1; then
|
||||
if uv run python -c "from framework.mcp import agent_builder_server" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ MCP server module OK${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ MCP server module failed${NC}"
|
||||
@@ -344,53 +275,105 @@ echo ""
|
||||
echo -e "${BLUE}Step 4: Verifying Claude Code skills...${NC}"
|
||||
echo ""
|
||||
|
||||
# Provider data as parallel indexed arrays (Bash 3.2 compatible — no declare -A)
|
||||
PROVIDER_ENV_VARS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GOOGLE_API_KEY GROQ_API_KEY CEREBRAS_API_KEY MISTRAL_API_KEY TOGETHER_API_KEY DEEPSEEK_API_KEY)
|
||||
PROVIDER_DISPLAY_NAMES=("Anthropic (Claude)" "OpenAI (GPT)" "Google Gemini" "Google AI" "Groq" "Cerebras" "Mistral" "Together AI" "DeepSeek")
|
||||
PROVIDER_ID_LIST=(anthropic openai gemini google groq cerebras mistral together deepseek)
|
||||
# Provider configuration - use associative arrays (Bash 4+) or indexed arrays (Bash 3.2)
|
||||
if [ "$USE_ASSOC_ARRAYS" = true ]; then
|
||||
# Bash 4+ - use associative arrays (cleaner and more efficient)
|
||||
declare -A PROVIDER_NAMES=(
|
||||
["ANTHROPIC_API_KEY"]="Anthropic (Claude)"
|
||||
["OPENAI_API_KEY"]="OpenAI (GPT)"
|
||||
["GEMINI_API_KEY"]="Google Gemini"
|
||||
["GOOGLE_API_KEY"]="Google AI"
|
||||
["GROQ_API_KEY"]="Groq"
|
||||
["CEREBRAS_API_KEY"]="Cerebras"
|
||||
["MISTRAL_API_KEY"]="Mistral"
|
||||
["TOGETHER_API_KEY"]="Together AI"
|
||||
["DEEPSEEK_API_KEY"]="DeepSeek"
|
||||
)
|
||||
|
||||
# Default models by provider id (parallel arrays)
|
||||
MODEL_PROVIDER_IDS=(anthropic openai gemini groq cerebras mistral together_ai deepseek)
|
||||
MODEL_DEFAULTS=("claude-sonnet-4-5-20250929" "gpt-4o" "gemini-3.0-flash-preview" "moonshotai/kimi-k2-instruct-0905" "zai-glm-4.7" "mistral-large-latest" "meta-llama/Llama-3.3-70B-Instruct-Turbo" "deepseek-chat")
|
||||
declare -A PROVIDER_IDS=(
|
||||
["ANTHROPIC_API_KEY"]="anthropic"
|
||||
["OPENAI_API_KEY"]="openai"
|
||||
["GEMINI_API_KEY"]="gemini"
|
||||
["GOOGLE_API_KEY"]="google"
|
||||
["GROQ_API_KEY"]="groq"
|
||||
["CEREBRAS_API_KEY"]="cerebras"
|
||||
["MISTRAL_API_KEY"]="mistral"
|
||||
["TOGETHER_API_KEY"]="together"
|
||||
["DEEPSEEK_API_KEY"]="deepseek"
|
||||
)
|
||||
|
||||
# Helper: get provider display name for an env var
|
||||
get_provider_name() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_DISPLAY_NAMES[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
declare -A DEFAULT_MODELS=(
|
||||
["anthropic"]="claude-sonnet-4-5-20250929"
|
||||
["openai"]="gpt-4o"
|
||||
["gemini"]="gemini-3.0-flash-preview"
|
||||
["groq"]="moonshotai/kimi-k2-instruct-0905"
|
||||
["cerebras"]="zai-glm-4.7"
|
||||
["mistral"]="mistral-large-latest"
|
||||
["together_ai"]="meta-llama/Llama-3.3-70B-Instruct-Turbo"
|
||||
["deepseek"]="deepseek-chat"
|
||||
)
|
||||
|
||||
# Helper: get provider id for an env var
|
||||
get_provider_id() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_ID_LIST[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
# Helper functions for Bash 4+
|
||||
get_provider_name() {
|
||||
echo "${PROVIDER_NAMES[$1]}"
|
||||
}
|
||||
|
||||
# Helper: get default model for a provider id
|
||||
get_default_model() {
|
||||
local provider_id="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#MODEL_PROVIDER_IDS[@]} ]; do
|
||||
if [ "${MODEL_PROVIDER_IDS[$i]}" = "$provider_id" ]; then
|
||||
echo "${MODEL_DEFAULTS[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
get_provider_id() {
|
||||
echo "${PROVIDER_IDS[$1]}"
|
||||
}
|
||||
|
||||
get_default_model() {
|
||||
echo "${DEFAULT_MODELS[$1]}"
|
||||
}
|
||||
else
|
||||
# Bash 3.2 - use parallel indexed arrays
|
||||
PROVIDER_ENV_VARS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GOOGLE_API_KEY GROQ_API_KEY CEREBRAS_API_KEY MISTRAL_API_KEY TOGETHER_API_KEY DEEPSEEK_API_KEY)
|
||||
PROVIDER_DISPLAY_NAMES=("Anthropic (Claude)" "OpenAI (GPT)" "Google Gemini" "Google AI" "Groq" "Cerebras" "Mistral" "Together AI" "DeepSeek")
|
||||
PROVIDER_ID_LIST=(anthropic openai gemini google groq cerebras mistral together deepseek)
|
||||
|
||||
# Default models by provider id (parallel arrays)
|
||||
MODEL_PROVIDER_IDS=(anthropic openai gemini groq cerebras mistral together_ai deepseek)
|
||||
MODEL_DEFAULTS=("claude-sonnet-4-5-20250929" "gpt-4o" "gemini-3.0-flash-preview" "moonshotai/kimi-k2-instruct-0905" "zai-glm-4.7" "mistral-large-latest" "meta-llama/Llama-3.3-70B-Instruct-Turbo" "deepseek-chat")
|
||||
|
||||
# Helper: get provider display name for an env var
|
||||
get_provider_name() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_DISPLAY_NAMES[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
|
||||
# Helper: get provider id for an env var
|
||||
get_provider_id() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_ID_LIST[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
|
||||
# Helper: get default model for a provider id
|
||||
get_default_model() {
|
||||
local provider_id="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#MODEL_PROVIDER_IDS[@]} ]; do
|
||||
if [ "${MODEL_PROVIDER_IDS[$i]}" = "$provider_id" ]; then
|
||||
echo "${MODEL_DEFAULTS[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
fi
|
||||
|
||||
# Configuration directory
|
||||
HIVE_CONFIG_DIR="$HOME/.hive"
|
||||
@@ -413,7 +396,7 @@ config = {
        'model': '$model',
        'api_key_env_var': '$env_var'
    },
    'created_at': '$(date -Iseconds)'
    'created_at': '$(date -u +"%Y-%m-%dT%H:%M:%S+00:00")'
}
with open('$HIVE_CONFIG_FILE', 'w') as f:
    json.dump(config, f, indent=2)
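The timestamp change above trades `date -Iseconds`, which is not available on all platforms' `date`, for an explicit UTC format string. A quick Python check of the emitted shape (a sketch for illustration, not part of the commit):

from datetime import datetime, timezone

# Prints the same ISO-8601 UTC shape the new shell command emits,
# e.g. 2025-01-01T12:00:00+00:00.
print(datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S+00:00"))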
@@ -442,13 +425,25 @@ FOUND_ENV_VARS=() # Corresponding env var names
SELECTED_PROVIDER_ID=""  # Will hold the chosen provider ID
SELECTED_ENV_VAR=""      # Will hold the chosen env var

for env_var in "${PROVIDER_ENV_VARS[@]}"; do
    value="${!env_var}"
    if [ -n "$value" ]; then
        FOUND_PROVIDERS+=("$(get_provider_name "$env_var")")
        FOUND_ENV_VARS+=("$env_var")
    fi
done
if [ "$USE_ASSOC_ARRAYS" = true ]; then
    # Bash 4+ - iterate over associative array keys
    for env_var in "${!PROVIDER_NAMES[@]}"; do
        value="${!env_var}"
        if [ -n "$value" ]; then
            FOUND_PROVIDERS+=("$(get_provider_name "$env_var")")
            FOUND_ENV_VARS+=("$env_var")
        fi
    done
else
    # Bash 3.2 - iterate over indexed array
    for env_var in "${PROVIDER_ENV_VARS[@]}"; do
        value="${!env_var}"
        if [ -n "$value" ]; then
            FOUND_PROVIDERS+=("$(get_provider_name "$env_var")")
            FOUND_ENV_VARS+=("$env_var")
        fi
    done
fi

if [ ${#FOUND_PROVIDERS[@]} -gt 0 ]; then
    echo "Found API keys:"
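Both branches of the detection loop above do the same thing: scan the process environment for known key names and collect matches. A minimal Python rendering of that scan, using a subset of the provider table from the diff (illustrative, not from the commit):

import os

# Subset of the PROVIDER_NAMES table shown above.
PROVIDER_NAMES = {
    "ANTHROPIC_API_KEY": "Anthropic (Claude)",
    "OPENAI_API_KEY": "OpenAI (GPT)",
    "GROQ_API_KEY": "Groq",
}

found = [(var, name) for var, name in PROVIDER_NAMES.items() if os.environ.get(var)]
for var, name in found:
    print(f"Found API key: {name} ({var})")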
@@ -476,7 +471,7 @@ if [ ${#FOUND_PROVIDERS[@]} -gt 0 ]; then
    i=1
    for provider in "${FOUND_PROVIDERS[@]}"; do
        echo -e " ${CYAN}$i)${NC} $provider"
        ((i++))
        i=$((i + 1))
    done
    echo ""
@@ -507,7 +502,7 @@ if [ -z "$SELECTED_PROVIDER_ID" ]; then
        "Groq - Fast, free tier" \
        "Cerebras - Fast, free tier" \
        "Skip for now"
    choice=$?
    choice=$PROMPT_CHOICE

    case $choice in
        0)
@@ -542,7 +537,8 @@ if [ -z "$SELECTED_PROVIDER_ID" ]; then
            ;;
        5)
            echo ""
            echo -e "${YELLOW}Skipped.${NC} Add your API key later:"
            echo -e "${YELLOW}Skipped.${NC} An LLM API key is required to test and use worker agents."
            echo -e "Add your API key later by running:"
            echo ""
            echo -e " ${CYAN}echo 'ANTHROPIC_API_KEY=your-key' >> .env${NC}"
            echo ""
@@ -595,7 +591,7 @@ ERRORS=0

# Test imports
echo -n " ⬡ framework... "
if $PYTHON_CMD -c "import framework" > /dev/null 2>&1; then
if uv run python -c "import framework" > /dev/null 2>&1; then
    echo -e "${GREEN}ok${NC}"
else
    echo -e "${RED}failed${NC}"
@@ -603,7 +599,7 @@ else
fi

echo -n " ⬡ aden_tools... "
if $PYTHON_CMD -c "import aden_tools" > /dev/null 2>&1; then
if uv run python -c "import aden_tools" > /dev/null 2>&1; then
    echo -e "${GREEN}ok${NC}"
else
    echo -e "${RED}failed${NC}"
@@ -611,7 +607,7 @@ else
fi

echo -n " ⬡ litellm... "
if $PYTHON_CMD -c "import litellm" > /dev/null 2>&1; then
if uv run python -c "import litellm" > /dev/null 2>&1; then
    echo -e "${GREEN}ok${NC}"
else
    echo -e "${YELLOW}--${NC}"
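The same three smoke tests can be scripted in Python; a sketch that assumes `uv` is on PATH and mirrors the required/optional split above (framework and aden_tools must import, litellm is allowed to fail):

import subprocess

for mod, required in (("framework", True), ("aden_tools", True), ("litellm", False)):
    # Run each import in the uv-managed environment, as the shell checks do.
    ok = subprocess.run(
        ["uv", "run", "python", "-c", f"import {mod}"],
        capture_output=True,
    ).returncode == 0
    print(f"{mod}: {'ok' if ok else ('failed' if required else '--')}")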
@@ -0,0 +1,251 @@
<#
setup-python.ps1 - Python Environment Setup for Aden Agent Framework

This script sets up the Python environment with all required packages
for building and running goal-driven agents.
#>

$ErrorActionPreference = "Stop"

# Colors for output
$RED = "Red"
$GREEN = "Green"
$YELLOW = "Yellow"
$BLUE = "Cyan"

# Get the directory where this script is located
$SCRIPT_DIR = Split-Path -Parent $MyInvocation.MyCommand.Path
$PROJECT_ROOT = Split-Path -Parent $SCRIPT_DIR

Write-Host ""
Write-Host "=================================================="
Write-Host " Aden Agent Framework - Python Setup"
Write-Host "=================================================="
Write-Host ""

# Check for Python
$pythonCmd = $null
if (Get-Command python -ErrorAction SilentlyContinue) {
    $pythonCmd = "python"
}

if (-not $pythonCmd) {
    Write-Host "Error: Python is not installed." -ForegroundColor $RED
    Write-Host "Please install Python 3.11+ from https://python.org"
    exit 1
}

# Check Python version ([int] casts keep the comparison numeric; a string
# comparison would wrongly let e.g. minor version "9" pass against 11)
$versionInfo = & $pythonCmd -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')"
$major = [int](& $pythonCmd -c "import sys; print(sys.version_info.major)")
$minor = [int](& $pythonCmd -c "import sys; print(sys.version_info.minor)")

Write-Host "Detected Python: $versionInfo" -ForegroundColor $BLUE

if ($major -lt 3 -or ($major -eq 3 -and $minor -lt 11)) {
    Write-Host "Error: Python 3.11+ is required (found $versionInfo)" -ForegroundColor $RED
    Write-Host "Please upgrade your Python installation"
    exit 1
}

Write-Host "[OK] Python version check passed" -ForegroundColor $GREEN
Write-Host ""

# Create and activate virtual environment
Write-Host "=================================================="
Write-Host "Setting up Python Virtual Environment"
Write-Host "=================================================="
Write-Host ""

$VENV_PATH = Join-Path $PROJECT_ROOT ".venv"
$VENV_PYTHON = Join-Path $VENV_PATH "Scripts\python.exe"
$VENV_ACTIVATE = Join-Path $VENV_PATH "Scripts\Activate.ps1"

if (-not (Test-Path $VENV_PYTHON)) {
    Write-Host "Creating virtual environment at .venv..."
    & $pythonCmd -m venv $VENV_PATH
    Write-Host "[OK] Virtual environment created" -ForegroundColor $GREEN
}
else {
    Write-Host "[OK] Virtual environment already exists" -ForegroundColor $GREEN
}

# Activate venv
Write-Host "Activating virtual environment..."
& $VENV_ACTIVATE
Write-Host "[OK] Virtual environment activated" -ForegroundColor $GREEN

# From here on, always use venv python
$pythonCmd = $VENV_PYTHON

Write-Host ""

# Check for pip
try {
    & $pythonCmd -m pip --version | Out-Null
}
catch {
    Write-Host "Error: pip is not installed" -ForegroundColor $RED
    Write-Host "Please install pip for Python $versionInfo"
    exit 1
}

Write-Host "[OK] pip detected" -ForegroundColor $GREEN
Write-Host ""

# Upgrade pip, setuptools, and wheel
Write-Host "Upgrading pip, setuptools, and wheel..."
& $pythonCmd -m pip install --upgrade pip setuptools wheel
Write-Host "[OK] Core packages upgraded" -ForegroundColor $GREEN
Write-Host ""

# Install core framework package
Write-Host "=================================================="
Write-Host "Installing Core Framework Package"
Write-Host "=================================================="
Write-Host ""

Set-Location "$PROJECT_ROOT\core"

if (Test-Path "pyproject.toml") {
    Write-Host "Installing framework from core/ (editable mode)..."
    & $pythonCmd -m pip install -e . | Out-Null
    Write-Host "[OK] Framework package installed" -ForegroundColor $GREEN
}
else {
    Write-Host "[WARN] No pyproject.toml found in core/, skipping framework installation" -ForegroundColor $YELLOW
}

Write-Host ""

# Install tools package
Write-Host "=================================================="
Write-Host "Installing Tools Package (aden_tools)"
Write-Host "=================================================="
Write-Host ""

Set-Location "$PROJECT_ROOT\tools"

if (Test-Path "pyproject.toml") {
    Write-Host "Installing aden_tools from tools/ (editable mode)..."
    & $pythonCmd -m pip install -e . | Out-Null
    Write-Host "[OK] Tools package installed" -ForegroundColor $GREEN
}
else {
    Write-Host "Error: No pyproject.toml found in tools/" -ForegroundColor $RED
    exit 1
}

Write-Host ""

# Fix openai version compatibility with litellm
Write-Host "=================================================="
Write-Host "Fixing Package Compatibility"
Write-Host "=================================================="
Write-Host ""

try {
    $openaiVersion = & $pythonCmd -c "import openai; print(openai.__version__)"
}
catch {
    $openaiVersion = "not_installed"
}

if ($openaiVersion -eq "not_installed") {
    Write-Host "Installing openai package..."
    & $pythonCmd -m pip install "openai>=1.0.0" | Out-Null
    Write-Host "[OK] openai package installed" -ForegroundColor $GREEN
}
elseif ($openaiVersion.StartsWith("0.")) {
    Write-Host "Found old openai version: $openaiVersion" -ForegroundColor $YELLOW
    Write-Host "Upgrading to openai 1.x+ for litellm compatibility..."
    & $pythonCmd -m pip install --upgrade "openai>=1.0.0" | Out-Null
    $openaiVersion = & $pythonCmd -c "import openai; print(openai.__version__)"
    Write-Host "[OK] openai upgraded to $openaiVersion" -ForegroundColor $GREEN
}
else {
    Write-Host "[OK] openai $openaiVersion is compatible" -ForegroundColor $GREEN
}

Write-Host ""

# Verify installations
Write-Host "=================================================="
Write-Host "Verifying Installation"
Write-Host "=================================================="
Write-Host ""

Set-Location $PROJECT_ROOT

# Test framework import
& $pythonCmd -c "import framework" 2>$null
if ($LASTEXITCODE -eq 0) {
    Write-Host "[OK] framework package imports successfully" -ForegroundColor $GREEN
}
else {
    Write-Host "[FAIL] framework package import failed" -ForegroundColor $RED
}

# Test aden_tools import
& $pythonCmd -c "import aden_tools" 2>$null
if ($LASTEXITCODE -eq 0) {
    Write-Host "[OK] aden_tools package imports successfully" -ForegroundColor $GREEN
}
else {
    Write-Host "[FAIL] aden_tools package import failed" -ForegroundColor $RED
    exit 1
}

# Test litellm
& $pythonCmd -c "import litellm" 2>$null
if ($LASTEXITCODE -eq 0) {
    Write-Host "[OK] litellm package imports successfully" -ForegroundColor $GREEN
}
else {
    Write-Host "[WARN] litellm import had issues (may be OK if not using LLM features)" -ForegroundColor $YELLOW
}

Write-Host ""

# Print agent commands
Write-Host "=================================================="
Write-Host " Setup Complete!"
Write-Host "=================================================="
Write-Host ""
Write-Host "Python packages installed:"
Write-Host " - framework (core agent runtime)"
Write-Host " - aden_tools (tools and MCP servers)"
Write-Host " - All dependencies and compatibility fixes applied"
Write-Host ""
Write-Host "To run agents on Windows (PowerShell):"
Write-Host ""
Write-Host "1. From the project root, set PYTHONPATH:"
Write-Host " `$env:PYTHONPATH=`"core;exports`""
Write-Host ""
Write-Host "2. Run an agent command:"
Write-Host " python -m agent_name validate"
Write-Host " python -m agent_name info"
Write-Host " python -m agent_name run --input '{...}'"
Write-Host ""
Write-Host "Example (support_ticket_agent):"
Write-Host " python -m support_ticket_agent validate"
Write-Host " python -m support_ticket_agent info"
Write-Host " python -m support_ticket_agent run --input '{""ticket_content"":""..."",""customer_id"":""..."",""ticket_id"":""...""}'"
Write-Host ""
Write-Host "Notes:"
Write-Host " - Ensure the virtual environment is activated (.venv)"
Write-Host " - PYTHONPATH must be set in each new PowerShell session"
Write-Host ""
Write-Host "Documentation:"
Write-Host " $PROJECT_ROOT\README.md"
Write-Host ""
Write-Host "Agent Examples:"
Write-Host " $PROJECT_ROOT\exports\"
Write-Host ""
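The closing notes reduce to one environment tweak. Here is a sketch of launching an agent programmatically with the same PYTHONPATH, using the module name from the example above (run from the project root; not part of the commit):

import os
import subprocess

# Equivalent of `$env:PYTHONPATH = "core;exports"` from the notes above;
# os.pathsep picks ";" on Windows and ":" elsewhere.
env = dict(os.environ, PYTHONPATH=os.pathsep.join(["core", "exports"]))
subprocess.run(["python", "-m", "support_ticket_agent", "info"], env=env, check=True)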
+1
-12
@@ -68,18 +68,7 @@ from starlette.responses import PlainTextResponse # noqa: E402
from aden_tools.credentials import CredentialError, CredentialStoreAdapter # noqa: E402
from aden_tools.tools import register_all_tools # noqa: E402

# Create credential store with access to both env vars AND encrypted store
# This allows using Aden-synced credentials from ~/.hive/credentials
try:
    from framework.credentials import CredentialStore

    store = CredentialStore.with_encrypted_storage()  # ~/.hive/credentials
    credentials = CredentialStoreAdapter(store)
    logger.info("Using CredentialStoreAdapter with encrypted storage")
except Exception as e:
    # Fall back to env-only adapter if encrypted storage fails
    credentials = CredentialStoreAdapter.with_env_storage()
    logger.warning(f"Falling back to env-only CredentialStoreAdapter: {e}")
credentials = CredentialStoreAdapter.default()

# Tier 1: Validate startup-required credentials (if any)
try:
@@ -30,6 +30,7 @@ dependencies = [
    "playwright-stealth>=1.0.5",
    "litellm>=1.81.0",
    "resend>=2.0.0",
    "framework",
]

[project.optional-dependencies]
@@ -54,6 +55,9 @@ all = [
    "duckdb>=1.0.0",
]

[tool.uv.sources]
framework = { workspace = true }

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
@@ -1,13 +0,0 @@
# MCP Server
fastmcp

# Tool dependencies
diff-match-patch
pypdf
beautifulsoup4
lxml
playwright
playwright-stealth
requests

# Note: After installing, run `playwright install` to download browser binaries
@@ -10,7 +10,7 @@ Usage:
    from aden_tools.credentials import CredentialStoreAdapter

    mcp = FastMCP("my-server")
    credentials = CredentialStoreAdapter.with_env_storage()
    credentials = CredentialStoreAdapter.default()
    register_all_tools(mcp, credentials=credentials)
"""
@@ -9,14 +9,14 @@ Philosophy: Google Strictness + Apple UX

Usage:
    from aden_tools.credentials import CredentialStoreAdapter
    from core.framework.credentials import CredentialStore
    from framework.credentials import CredentialStore

    # With encrypted storage (production)
    store = CredentialStore.with_encrypted_storage()  # defaults to ~/.hive/credentials
    credentials = CredentialStoreAdapter(store)

    # With env vars only (simple setup)
    credentials = CredentialStoreAdapter.with_env_storage()
    # With composite storage (encrypted primary + env fallback)
    credentials = CredentialStoreAdapter.default()

    # In agent runner (validate at agent load time)
    credentials.validate_for_tools(["web_search", "file_read"])
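The lookup order that `default()` implies, per the docstring above, is encrypted store first, then environment. A rough sketch of that composite behavior follows; the record layout and env-var naming here are assumptions for illustration, not the adapter's actual API:

import os

def composite_get(credential_id: str, key: str, encrypted: dict) -> str | None:
    # Encrypted store first (e.g. records loaded from ~/.hive/credentials)...
    record = encrypted.get(credential_id, {})
    if key in record:
        return record[key]
    # ...then fall back to an environment variable; the naming scheme is a guess.
    return os.environ.get(f"{credential_id}_{key}".upper())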
@@ -70,6 +70,9 @@ class CredentialSpec:
    credential_key: str = "access_token"
    """Key name within the credential (e.g., 'access_token', 'api_key')"""

    credential_group: str = ""
    """Group name for credentials that must be configured together (e.g., 'google_custom_search')"""


class CredentialError(Exception):
    """Raised when required credentials are missing."""
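A plausible use of the new `credential_group` field, following the 'google_custom_search' hint in its docstring; the spec names and keys below are illustrative, and the remaining CredentialSpec fields are assumed to take their defaults:

# Hypothetical: an API key and a search-engine id that are only useful together.
GOOGLE_SEARCH_CREDENTIALS = {
    "GOOGLE_SEARCH_API_KEY": CredentialSpec(
        credential_key="api_key",
        credential_group="google_custom_search",
    ),
    "GOOGLE_SEARCH_CX": CredentialSpec(
        credential_key="cx",
        credential_group="google_custom_search",
    ),
}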
@@ -15,5 +15,22 @@ EMAIL_CREDENTIALS = {
        startup_required=False,
        help_url="https://resend.com/api-keys",
        description="API key for Resend email service",
        # Auth method support
        direct_api_key_supported=True,
        api_key_instructions="""To get a Resend API key:
1. Go to https://resend.com and create an account (or sign in)
2. Navigate to API Keys in the dashboard
3. Click "Create API Key"
4. Give it a name (e.g., "Hive Agent") and choose permissions:
   - "Sending access" is sufficient for most use cases
   - "Full access" if you also need to manage domains
5. Copy the API key (starts with re_)
6. Store it securely - you won't be able to see it again!
7. Note: You'll also need to verify a domain to send emails from custom addresses""",
        # Health check configuration
        health_check_endpoint="https://api.resend.com/domains",
        # Credential store mapping
        credential_id="resend",
        credential_key="api_key",
    ),
}
@@ -231,10 +231,213 @@ class GoogleSearchHealthChecker:
        )


class AnthropicHealthChecker:
    """Health checker for Anthropic API credentials."""

    ENDPOINT = "https://api.anthropic.com/v1/messages"
    TIMEOUT = 10.0

    def check(self, api_key: str) -> HealthCheckResult:
        """
        Validate Anthropic API key without consuming tokens.

        Sends a deliberately invalid request (empty messages) to the messages
        endpoint. A 401 means invalid key; 400 (bad request) means the key
        authenticated but the payload was rejected, confirming the key is valid
        without generating any tokens. 429 (rate limited) also indicates a valid key.
        """
        try:
            with httpx.Client(timeout=self.TIMEOUT) as client:
                response = client.post(
                    self.ENDPOINT,
                    headers={
                        "x-api-key": api_key,
                        "anthropic-version": "2023-06-01",
                        "Content-Type": "application/json",
                    },
                    # Empty messages triggers 400 (not 200), so no tokens are consumed.
                    json={
                        "model": "claude-sonnet-4-20250514",
                        "max_tokens": 1,
                        "messages": [],
                    },
                )

            if response.status_code == 200:
                return HealthCheckResult(
                    valid=True,
                    message="Anthropic API key valid",
                )
            elif response.status_code == 401:
                return HealthCheckResult(
                    valid=False,
                    message="Anthropic API key is invalid",
                    details={"status_code": 401},
                )
            elif response.status_code == 429:
                # Rate limited but key is valid
                return HealthCheckResult(
                    valid=True,
                    message="Anthropic API key valid (rate limited)",
                    details={"status_code": 429, "rate_limited": True},
                )
            elif response.status_code == 400:
                # Bad request but key authenticated - key is valid
                return HealthCheckResult(
                    valid=True,
                    message="Anthropic API key valid",
                    details={"status_code": 400},
                )
            else:
                return HealthCheckResult(
                    valid=False,
                    message=f"Anthropic API returned status {response.status_code}",
                    details={"status_code": response.status_code},
                )
        except httpx.TimeoutException:
            return HealthCheckResult(
                valid=False,
                message="Anthropic API request timed out",
                details={"error": "timeout"},
            )
        except httpx.RequestError as e:
            return HealthCheckResult(
                valid=False,
                message=f"Failed to connect to Anthropic API: {e}",
                details={"error": str(e)},
            )


class GitHubHealthChecker:
    """Health checker for GitHub Personal Access Token."""

    ENDPOINT = "https://api.github.com/user"
    TIMEOUT = 10.0

    def check(self, access_token: str) -> HealthCheckResult:
        """
        Validate GitHub token by fetching the authenticated user.

        Returns the authenticated username on success.
        """
        try:
            with httpx.Client(timeout=self.TIMEOUT) as client:
                response = client.get(
                    self.ENDPOINT,
                    headers={
                        "Authorization": f"Bearer {access_token}",
                        "Accept": "application/vnd.github+json",
                        "X-GitHub-Api-Version": "2022-11-28",
                    },
                )

            if response.status_code == 200:
                data = response.json()
                username = data.get("login", "unknown")
                return HealthCheckResult(
                    valid=True,
                    message=f"GitHub token valid (authenticated as {username})",
                    details={"username": username},
                )
            elif response.status_code == 401:
                return HealthCheckResult(
                    valid=False,
                    message="GitHub token is invalid or expired",
                    details={"status_code": 401},
                )
            elif response.status_code == 403:
                return HealthCheckResult(
                    valid=False,
                    message="GitHub token lacks required permissions",
                    details={"status_code": 403},
                )
            else:
                return HealthCheckResult(
                    valid=False,
                    message=f"GitHub API returned status {response.status_code}",
                    details={"status_code": response.status_code},
                )
        except httpx.TimeoutException:
            return HealthCheckResult(
                valid=False,
                message="GitHub API request timed out",
                details={"error": "timeout"},
            )
        except httpx.RequestError as e:
            return HealthCheckResult(
                valid=False,
                message=f"Failed to connect to GitHub API: {e}",
                details={"error": str(e)},
            )


class ResendHealthChecker:
    """Health checker for Resend API credentials."""

    ENDPOINT = "https://api.resend.com/domains"
    TIMEOUT = 10.0

    def check(self, api_key: str) -> HealthCheckResult:
        """
        Validate Resend API key by listing domains.

        A successful response confirms the key is valid.
        """
        try:
            with httpx.Client(timeout=self.TIMEOUT) as client:
                response = client.get(
                    self.ENDPOINT,
                    headers={
                        "Authorization": f"Bearer {api_key}",
                        "Accept": "application/json",
                    },
                )

            if response.status_code == 200:
                return HealthCheckResult(
                    valid=True,
                    message="Resend API key valid",
                )
            elif response.status_code == 401:
                return HealthCheckResult(
                    valid=False,
                    message="Resend API key is invalid",
                    details={"status_code": 401},
                )
            elif response.status_code == 403:
                return HealthCheckResult(
                    valid=False,
                    message="Resend API key lacks required permissions",
                    details={"status_code": 403},
                )
            else:
                return HealthCheckResult(
                    valid=False,
                    message=f"Resend API returned status {response.status_code}",
                    details={"status_code": response.status_code},
                )
        except httpx.TimeoutException:
            return HealthCheckResult(
                valid=False,
                message="Resend API request timed out",
                details={"error": "timeout"},
            )
        except httpx.RequestError as e:
            return HealthCheckResult(
                valid=False,
                message=f"Failed to connect to Resend API: {e}",
                details={"error": str(e)},
            )


# Registry of health checkers
HEALTH_CHECKERS: dict[str, CredentialHealthChecker] = {
    "hubspot": HubSpotHealthChecker(),
    "brave_search": BraveSearchHealthChecker(),
    "google_search": GoogleSearchHealthChecker(),
    "anthropic": AnthropicHealthChecker(),
    "github": GitHubHealthChecker(),
    "resend": ResendHealthChecker(),
}
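A minimal sketch of driving a checker through the registry; the env-var name below is the obvious choice but is an assumption here, not taken from the commit:

import os

checker = HEALTH_CHECKERS.get("resend")
api_key = os.environ.get("RESEND_API_KEY")  # assumed env-var name
if checker and api_key:
    result = checker.check(api_key)
    print(result.valid, result.message)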