Compare commits
108 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 119280da1a | |||
| 4d49f74d5a | |||
| 6a42b9c66b | |||
| fc4a39480a | |||
| b98afb01c8 | |||
| ccd6bb7656 | |||
| ea30e5c631 | |||
| d16a3c3b22 | |||
| a03bd78c2e | |||
| 3cca41aab1 | |||
| d19aaed946 | |||
| 9a7db8cf94 | |||
| f50630c551 | |||
| 0ef2e64733 | |||
| 3a8e121d43 | |||
| 51d341b88c | |||
| 7dd70b8e31 | |||
| 84b332d989 | |||
| bcc6848275 | |||
| 75dd053a40 | |||
| 20f2aa09f2 | |||
| fb8c810b3d | |||
| ad21cf4243 | |||
| 1e45cfff67 | |||
| 0280600a47 | |||
| 571ad518dc | |||
| fe37a25cf1 | |||
| e06138628c | |||
| 1ed0edd158 | |||
| 49dbc46082 | |||
| a16a4adc09 | |||
| b4ab1cbd56 | |||
| 6faa63f0d0 | |||
| 2b44af427f | |||
| 11f7401bc2 | |||
| db7b5180dd | |||
| 5b4e56252c | |||
| e3c71f77de | |||
| b09824faec | |||
| c69bc24598 | |||
| 0cf17e1c63 | |||
| feac803491 | |||
| 4aacec30d8 | |||
| b459a2f7a9 | |||
| ca8ede65f0 | |||
| 9a177c46e1 | |||
| d49e858d32 | |||
| 20bea9cd7f | |||
| d7afa5dcf2 | |||
| 22e816bf86 | |||
| a7709d489c | |||
| 3240616808 | |||
| 18dfc997b8 | |||
| 92d0b6addf | |||
| b9f83d4d61 | |||
| 9c16826ad3 | |||
| f305745295 | |||
| df4d0ad3fd | |||
| 9034d1dc71 | |||
| 537172d8ce | |||
| 20b2e4b3dd | |||
| fc22586752 | |||
| 646440eba3 | |||
| 53e5579326 | |||
| 29a1630d0f | |||
| 171f4ab2ae | |||
| a86043a2ec | |||
| 3947da2cf1 | |||
| 17caab6563 | |||
| a5ae071a03 | |||
| 9c33da7b8d | |||
| 94d31743b0 | |||
| 70db618c6e | |||
| 363a650dfa | |||
| b6e2634537 | |||
| 23146c8dae | |||
| 9f424f2fc0 | |||
| 0715fc5498 | |||
| f9fddd6663 | |||
| 58b60b84fd | |||
| 86aef3319f | |||
| 63d017fc21 | |||
| 0015b3d43d | |||
| 9c4d44c057 | |||
| 800c7fbe11 | |||
| 291ba24229 | |||
| c52ce6bb49 | |||
| ffa4096390 | |||
| bcddd4ce77 | |||
| 017872f71b | |||
| f2b6fc6948 | |||
| acff8a0ece | |||
| bfb660275e | |||
| 472cfe1437 | |||
| 8b7efe27c1 | |||
| 71249f4f88 | |||
| d6ae48bc58 | |||
| a963d49306 | |||
| 7e670ce0a8 | |||
| ec3be40ddd | |||
| 94197cbcb9 | |||
| 83f77af2ab | |||
| 3ee6d98905 | |||
| 967cbf814b | |||
| a96cd546c8 | |||
| eb33d4f1c2 | |||
| 4253956326 | |||
| d6b05bf337 |
@@ -1,40 +0,0 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(npm install:*)",
|
||||
"Bash(npm test:*)",
|
||||
"Skill(building-agents-construction)",
|
||||
"Skill(building-agents-construction:*)",
|
||||
"Bash(PYTHONPATH=core:exports pytest:*)",
|
||||
"mcp__agent-builder__create_session",
|
||||
"mcp__agent-builder__get_session_status",
|
||||
"mcp__agent-builder__set_goal",
|
||||
"mcp__agent-builder__list_mcp_servers",
|
||||
"mcp__agent-builder__test_node",
|
||||
"mcp__agent-builder__add_node",
|
||||
"mcp__agent-builder__add_edge",
|
||||
"mcp__agent-builder__validate_graph",
|
||||
"Bash(ruff check:*)",
|
||||
"Bash(PYTHONPATH=core:exports python:*)",
|
||||
"mcp__agent-builder__list_tests",
|
||||
"mcp__agent-builder__generate_constraint_tests",
|
||||
"Bash(python -m agent:*)",
|
||||
"Bash(python agent.py:*)",
|
||||
"Bash(python -c:*)",
|
||||
"Bash(done)",
|
||||
"Bash(xargs cat:*)",
|
||||
"mcp__agent-builder__list_mcp_tools",
|
||||
"mcp__agent-builder__add_mcp_server",
|
||||
"mcp__agent-builder__check_missing_credentials",
|
||||
"mcp__agent-builder__store_credential",
|
||||
"mcp__agent-builder__list_stored_credentials",
|
||||
"mcp__agent-builder__delete_stored_credential",
|
||||
"mcp__agent-builder__verify_credentials",
|
||||
"Bash(PYTHONPATH=/home/timothy/oss/hive/core:/home/timothy/oss/hive/exports python:*)",
|
||||
"Bash(PYTHONPATH=core:exports:tools/src python -m hubspot_input:*)",
|
||||
"mcp__agent-builder__export_graph"
|
||||
]
|
||||
},
|
||||
"enabledMcpjsonServers": ["agent-builder", "tools"],
|
||||
"enableAllProjectMcpServers": true
|
||||
}
|
||||
+2
-12
@@ -218,19 +218,9 @@ class OnlineResearchAgent:
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
# Load MCP servers (always load, needed for tool validation)
|
||||
agent_dir = Path(__file__).parent
|
||||
mcp_config_path = agent_dir / "mcp_servers.json"
|
||||
|
||||
mcp_config_path = Path(__file__).parent / "mcp_servers.json"
|
||||
if mcp_config_path.exists():
|
||||
with open(mcp_config_path) as f:
|
||||
mcp_servers = json.load(f)
|
||||
|
||||
for server_config in mcp_servers.get("servers", []):
|
||||
# Resolve relative cwd paths
|
||||
cwd = server_config.get("cwd")
|
||||
if cwd and not Path(cwd).is_absolute():
|
||||
server_config["cwd"] = str(agent_dir / cwd)
|
||||
tool_registry.register_mcp_server(server_config)
|
||||
tool_registry.load_mcp_config(mcp_config_path)
|
||||
|
||||
llm = None
|
||||
if not mock_mode:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
---
|
||||
name: setup-credentials
|
||||
description: Set up and install credentials for an agent. Detects missing credentials from agent config, collects them from the user, and stores them securely in the encrypted credential store at ~/.hive/credentials.
|
||||
description: Set up and install credentials for an agent. Detects missing credentials from agent config, collects them from the user, and stores them securely in the local encrypted store at ~/.hive/credentials.
|
||||
license: Apache-2.0
|
||||
metadata:
|
||||
author: hive
|
||||
version: "2.1"
|
||||
version: "2.2"
|
||||
type: utility
|
||||
---
|
||||
|
||||
@@ -31,48 +31,96 @@ Determine which agent needs credentials. The user will either:
|
||||
|
||||
Locate the agent's directory under `exports/{agent_name}/`.
|
||||
|
||||
### Step 2: Detect Required Credentials
|
||||
### Step 2: Detect Required Credentials (Bash-First)
|
||||
|
||||
Read the agent's configuration to determine which tools and node types it uses:
|
||||
Use bash commands to determine what the agent needs and what's already configured. This avoids Python import issues and works even when `HIVE_CREDENTIAL_KEY` is not set.
|
||||
|
||||
```python
|
||||
from core.framework.runner import AgentRunner
|
||||
#### Step 2a: Read Agent Requirements
|
||||
|
||||
runner = AgentRunner.load("exports/{agent_name}")
|
||||
validation = runner.validate()
|
||||
Extract `required_tools` and node types from the agent config:
|
||||
|
||||
# validation.missing_credentials contains env var names
|
||||
# validation.warnings contains detailed messages with help URLs
|
||||
```bash
|
||||
# Get required tools
|
||||
jq -r '.required_tools[]?' exports/{agent_name}/agent.json 2>/dev/null
|
||||
|
||||
# Get node types from graph nodes
|
||||
jq -r '.graph.nodes[]?.node_type' exports/{agent_name}/agent.json 2>/dev/null | sort -u
|
||||
```
|
||||
|
||||
Alternatively, check the credential store directly:
|
||||
Map the extracted tools and node types to credentials by reading the spec files directly:
|
||||
|
||||
```python
|
||||
from core.framework.credentials import CredentialStore
|
||||
|
||||
# Use encrypted storage (default: ~/.hive/credentials)
|
||||
store = CredentialStore.with_encrypted_storage()
|
||||
|
||||
# Check what's available
|
||||
available = store.list_credentials()
|
||||
print(f"Available credentials: {available}")
|
||||
|
||||
# Check if specific credential exists
|
||||
if store.is_available("hubspot"):
|
||||
print("HubSpot credential found")
|
||||
else:
|
||||
print("HubSpot credential missing")
|
||||
```bash
|
||||
# Read all credential specs — each file defines tools, node_types, env_var, and credential_id
|
||||
cat tools/src/aden_tools/credentials/llm.py tools/src/aden_tools/credentials/search.py tools/src/aden_tools/credentials/email.py tools/src/aden_tools/credentials/integrations.py
|
||||
```
|
||||
|
||||
To see all known credential specs (for help URLs and setup instructions):
|
||||
For each `CredentialSpec`, match its `tools` and `node_types` lists against the agent's required tools and node types. Extract the `env_var`, `credential_id`, and `credential_group` for every match. This is the list of needed credentials.
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
#### Step 2b: Check Existing Credential Sources
|
||||
|
||||
for name, spec in CREDENTIAL_SPECS.items():
|
||||
print(f"{name}: env_var={spec.env_var}, aden={spec.aden_supported}")
|
||||
For each needed credential, check three sources. A credential is "found" if it exists in ANY of them:
|
||||
|
||||
**1. Encrypted store metadata index** (unencrypted JSON — no decryption key needed):
|
||||
|
||||
```bash
|
||||
cat ~/.hive/credentials/metadata/index.json 2>/dev/null | jq -r '.credentials | keys[]'
|
||||
```
|
||||
|
||||
If a credential ID appears in this list, it is stored in the encrypted store.
|
||||
|
||||
**2. Environment variables:**
|
||||
|
||||
```bash
|
||||
# Check each needed env var, e.g.:
|
||||
printenv ANTHROPIC_API_KEY > /dev/null 2>&1 && echo "ANTHROPIC_API_KEY: set" || echo "ANTHROPIC_API_KEY: not set"
|
||||
printenv BRAVE_SEARCH_API_KEY > /dev/null 2>&1 && echo "BRAVE_SEARCH_API_KEY: set" || echo "BRAVE_SEARCH_API_KEY: not set"
|
||||
```
|
||||
|
||||
**3. Project `.env` file:**
|
||||
|
||||
```bash
|
||||
# Check each needed env var, e.g.:
|
||||
grep -q '^ANTHROPIC_API_KEY=' .env 2>/dev/null && echo "ANTHROPIC_API_KEY: in .env" || echo "ANTHROPIC_API_KEY: not in .env"
|
||||
grep -q '^BRAVE_SEARCH_API_KEY=' .env 2>/dev/null && echo "BRAVE_SEARCH_API_KEY: in .env" || echo "BRAVE_SEARCH_API_KEY: not in .env"
|
||||
```
|
||||
|
||||
#### Step 2c: HIVE_CREDENTIAL_KEY Check
|
||||
|
||||
If any credentials were found in the encrypted store metadata index, verify the encryption key is available. The key is typically persisted to shell config by a previous setup-credentials run.
|
||||
|
||||
Check both the current session AND shell config files:
|
||||
|
||||
```bash
|
||||
# Check 1: Current session
|
||||
printenv HIVE_CREDENTIAL_KEY > /dev/null 2>&1 && echo "session: set" || echo "session: not set"
|
||||
|
||||
# Check 2: Shell config files (where setup-credentials persists it)
|
||||
# Note: check each file individually to avoid non-zero exit when one doesn't exist
|
||||
for f in ~/.zshrc ~/.bashrc ~/.profile; do [ -f "$f" ] && grep -q 'HIVE_CREDENTIAL_KEY' "$f" && echo "$f"; done
|
||||
```
|
||||
|
||||
Decision logic:
|
||||
- **In current session** — no action needed, credentials in the store are usable
|
||||
- **In shell config but NOT in current session** — the key is persisted but this shell hasn't sourced it. Run `source ~/.zshrc` (or `~/.bashrc`), then re-check. Credentials in the store are usable after sourcing.
|
||||
- **Not in session AND not in shell config** — the key was never persisted. Warn the user that credentials in the store cannot be decrypted. Help fix the key situation (recover/re-persist), do NOT re-collect credential values that are already stored.
|
||||
|
||||
#### Step 2d: Compute Missing & Group
|
||||
|
||||
Diff the "needed" credentials against the "found" credentials to get the truly missing list.
|
||||
|
||||
Group related credentials by their `credential_group` field from the spec files. Credentials that share the same non-empty `credential_group` value should be presented as a single setup step rather than asking for each one individually.
|
||||
|
||||
**If nothing is missing and there's no HIVE_CREDENTIAL_KEY issue:** Report all credentials as configured and skip Steps 3-5. Example:
|
||||
|
||||
```
|
||||
All required credentials are already configured:
|
||||
✓ anthropic (ANTHROPIC_API_KEY) — found in encrypted store
|
||||
✓ brave_search (BRAVE_SEARCH_API_KEY) — found in environment
|
||||
Your agent is ready to run!
|
||||
```
|
||||
|
||||
**If credentials are missing:** Continue to Step 3 with only the missing ones.
|
||||
|
||||
### Step 3: Present Auth Options for Each Missing Credential
|
||||
|
||||
For each missing credential, check what authentication methods are available:
|
||||
@@ -104,7 +152,7 @@ Present the available options using AskUserQuestion:
|
||||
```
|
||||
Choose how to configure HUBSPOT_ACCESS_TOKEN:
|
||||
|
||||
1) Aden Authorization Server (Recommended)
|
||||
1) Aden Platform (OAuth) (Recommended)
|
||||
Secure OAuth2 flow via integration.adenhq.com
|
||||
- Quick setup with automatic token refresh
|
||||
- No need to manage API keys manually
|
||||
@@ -114,7 +162,7 @@ Choose how to configure HUBSPOT_ACCESS_TOKEN:
|
||||
- Requires creating a HubSpot Private App
|
||||
- Full control over scopes and permissions
|
||||
|
||||
3) Custom Credential Store (Advanced)
|
||||
3) Local Credential Setup (Advanced)
|
||||
Programmatic configuration for CI/CD
|
||||
- For automated deployments
|
||||
- Requires manual API calls
|
||||
@@ -122,7 +170,7 @@ Choose how to configure HUBSPOT_ACCESS_TOKEN:
|
||||
|
||||
### Step 4: Execute Auth Flow Based on User Choice
|
||||
|
||||
#### Option 1: Aden Authorization Server
|
||||
#### Option 1: Aden Platform (OAuth)
|
||||
|
||||
This is the recommended flow for supported integrations (HubSpot, etc.).
|
||||
|
||||
@@ -174,7 +222,7 @@ shell_type = detect_shell() # 'bash', 'zsh', or 'unknown'
|
||||
success, config_path = add_env_var_to_shell_config(
|
||||
"ADEN_API_KEY",
|
||||
user_provided_key,
|
||||
comment="Aden authorization server API key"
|
||||
comment="Aden Platform (OAuth) API key"
|
||||
)
|
||||
|
||||
if success:
|
||||
@@ -313,7 +361,7 @@ if not result.valid:
|
||||
# 2. Continue anyway (not recommended)
|
||||
```
|
||||
|
||||
**4.2d. Store in Encrypted Credential Store**
|
||||
**4.2d. Store in Local Encrypted Store**
|
||||
|
||||
```python
|
||||
from core.framework.credentials import CredentialStore, CredentialObject, CredentialKey
|
||||
@@ -340,7 +388,7 @@ store.save_credential(cred)
|
||||
export HUBSPOT_ACCESS_TOKEN="the-value"
|
||||
```
|
||||
|
||||
#### Option 3: Custom Credential Store (Advanced)
|
||||
#### Option 3: Local Credential Setup (Advanced)
|
||||
|
||||
For programmatic/CI/CD setups.
|
||||
|
||||
@@ -408,10 +456,14 @@ Report the result to the user.
|
||||
|
||||
Health checks validate credentials by making lightweight API calls:
|
||||
|
||||
| Credential | Endpoint | What It Checks |
|
||||
| -------------- | --------------------------------------- | --------------------------------- |
|
||||
| `hubspot` | `GET /crm/v3/objects/contacts?limit=1` | Bearer token validity, CRM scopes |
|
||||
| `brave_search` | `GET /res/v1/web/search?q=test&count=1` | API key validity |
|
||||
| Credential | Endpoint | What It Checks |
|
||||
| --------------- | --------------------------------------- | ---------------------------------- |
|
||||
| `anthropic` | `POST /v1/messages` | API key validity |
|
||||
| `brave_search` | `GET /res/v1/web/search?q=test&count=1` | API key validity |
|
||||
| `google_search` | `GET /customsearch/v1?q=test&num=1` | API key + CSE ID validity |
|
||||
| `github` | `GET /user` | Token validity, user identity |
|
||||
| `hubspot` | `GET /crm/v3/objects/contacts?limit=1` | Bearer token validity, CRM scopes |
|
||||
| `resend` | `GET /domains` | API key validity |
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import check_credential_health, HealthCheckResult
|
||||
@@ -424,7 +476,7 @@ result: HealthCheckResult = check_credential_health("hubspot", token_value)
|
||||
|
||||
## Encryption Key (HIVE_CREDENTIAL_KEY)
|
||||
|
||||
The encrypted credential store requires `HIVE_CREDENTIAL_KEY` to encrypt/decrypt credentials.
|
||||
The local encrypted store requires `HIVE_CREDENTIAL_KEY` to encrypt/decrypt credentials.
|
||||
|
||||
- If the user doesn't have one, `EncryptedFileStorage` will auto-generate one and log it
|
||||
- The user MUST persist this key (e.g., in `~/.bashrc` or a secrets manager)
|
||||
@@ -443,7 +495,7 @@ If `HIVE_CREDENTIAL_KEY` is not set:
|
||||
- **NEVER** store credentials in plaintext files, git-tracked files, or agent configs
|
||||
- **NEVER** hardcode credentials in source code
|
||||
- **ALWAYS** use `SecretStr` from Pydantic when handling credential values in Python
|
||||
- **ALWAYS** use the encrypted credential store (`~/.hive/credentials`) for persistence
|
||||
- **ALWAYS** use the local encrypted store (`~/.hive/credentials`) for persistence
|
||||
- **ALWAYS** run health checks before storing credentials (when possible)
|
||||
- **ALWAYS** verify credentials were stored by re-running validation, not by reading them back
|
||||
- When modifying `~/.bashrc` or `~/.zshrc`, confirm with the user first
|
||||
@@ -456,7 +508,8 @@ All credential specs are defined in `tools/src/aden_tools/credentials/`:
|
||||
| ----------------- | ------------- | --------------------------------------------- | -------------- |
|
||||
| `llm.py` | LLM Providers | `anthropic` | No |
|
||||
| `search.py` | Search Tools | `brave_search`, `google_search`, `google_cse` | No |
|
||||
| `integrations.py` | Integrations | `hubspot` | Yes |
|
||||
| `email.py` | Email | `resend` | No |
|
||||
| `integrations.py` | Integrations | `github`, `hubspot` | No / Yes |
|
||||
|
||||
**Note:** Additional LLM providers (Cerebras, Groq, OpenAI) are handled by LiteLLM via environment
|
||||
variables (`CEREBRAS_API_KEY`, `GROQ_API_KEY`, `OPENAI_API_KEY`) but are not yet in CREDENTIAL_SPECS.
|
||||
@@ -507,66 +560,90 @@ token = store.get_key("hubspot", "access_token")
|
||||
## Example Session
|
||||
|
||||
```
|
||||
User: /setup-credentials for my hubspot-agent
|
||||
User: /setup-credentials for my research-agent
|
||||
|
||||
Agent: Let me check what credentials your hubspot-agent needs.
|
||||
Agent: Let me check what credentials your research-agent needs.
|
||||
|
||||
[Runs validation, finds ANTHROPIC_API_KEY and HUBSPOT_ACCESS_TOKEN missing]
|
||||
[Reads agent config]
|
||||
$ jq -r '.required_tools[]?' exports/research-agent/agent.json
|
||||
web_search
|
||||
google_search
|
||||
|
||||
Agent: Your hubspot-agent requires 2 credentials:
|
||||
$ jq -r '.graph.nodes[]?.node_type' exports/research-agent/agent.json | sort -u
|
||||
llm_tool_use
|
||||
|
||||
1. ANTHROPIC_API_KEY - Anthropic API key for LLM calls
|
||||
2. HUBSPOT_ACCESS_TOKEN - HubSpot access token
|
||||
[Maps tools/nodes to credentials using lookup table]
|
||||
Needed: anthropic, brave_search, google_search, google_cse
|
||||
|
||||
Let's set these up one at a time.
|
||||
[Checks encrypted store metadata index]
|
||||
$ cat ~/.hive/credentials/metadata/index.json | jq -r '.credentials | keys[]'
|
||||
anthropic
|
||||
brave_search
|
||||
|
||||
--- Setting up ANTHROPIC_API_KEY ---
|
||||
[Checks environment variables]
|
||||
$ printenv ANTHROPIC_API_KEY > /dev/null 2>&1 && echo "set" || echo "not set"
|
||||
not set
|
||||
$ printenv BRAVE_SEARCH_API_KEY > /dev/null 2>&1 && echo "set" || echo "not set"
|
||||
not set
|
||||
$ printenv GOOGLE_API_KEY > /dev/null 2>&1 && echo "set" || echo "not set"
|
||||
not set
|
||||
$ printenv GOOGLE_CSE_ID > /dev/null 2>&1 && echo "set" || echo "not set"
|
||||
not set
|
||||
|
||||
This credential only supports direct API key entry.
|
||||
[Checks HIVE_CREDENTIAL_KEY since credentials found in store]
|
||||
$ printenv HIVE_CREDENTIAL_KEY > /dev/null 2>&1 && echo "session: set" || echo "session: not set"
|
||||
session: not set
|
||||
$ for f in ~/.zshrc ~/.bashrc ~/.profile; do [ -f "$f" ] && grep -q 'HIVE_CREDENTIAL_KEY' "$f" && echo "$f"; done
|
||||
/Users/user/.zshrc
|
||||
|
||||
[Shows instructions for getting Anthropic API key]
|
||||
[AskUserQuestion: "Please provide your Anthropic API key:"]
|
||||
[User provides key]
|
||||
[Stores in credential store, exports to env]
|
||||
[Key is in shell config but not current session — sourcing it]
|
||||
$ source ~/.zshrc
|
||||
|
||||
--- Setting up HUBSPOT_ACCESS_TOKEN ---
|
||||
[Computes missing credentials]
|
||||
Found:
|
||||
✓ anthropic (ANTHROPIC_API_KEY) — in encrypted store
|
||||
✓ brave_search (BRAVE_SEARCH_API_KEY) — in encrypted store
|
||||
Missing:
|
||||
✗ google_search (GOOGLE_API_KEY)
|
||||
✗ google_cse (GOOGLE_CSE_ID)
|
||||
|
||||
[AskUserQuestion with 3 options]
|
||||
Agent: 2 of 4 required credentials are already configured. Only Google Custom
|
||||
Search needs setup (2 values as a single group).
|
||||
|
||||
Choose how to configure HUBSPOT_ACCESS_TOKEN:
|
||||
1) Aden Authorization Server (Recommended)
|
||||
2) Direct API Key
|
||||
3) Custom Credential Store (Advanced)
|
||||
--- Setting up Google Custom Search (google_search + google_cse) ---
|
||||
|
||||
User: 1 (Aden)
|
||||
This requires two values that work together.
|
||||
|
||||
Agent: Great! Let me check if you have an Aden API key configured...
|
||||
First, the Google API Key:
|
||||
1. Go to https://console.cloud.google.com/apis/credentials
|
||||
2. Create a new project (or select an existing one)
|
||||
3. Enable the "Custom Search API" from the API Library
|
||||
4. Go to Credentials > Create Credentials > API Key
|
||||
5. Copy the generated API key
|
||||
|
||||
[Checks for ADEN_API_KEY, not found]
|
||||
[Opens browser to integration.adenhq.com/setup]
|
||||
|
||||
Agent: I've opened Aden's setup page. Please:
|
||||
1. Sign in or create an account
|
||||
2. Connect your HubSpot account (OAuth happens on Aden's site)
|
||||
3. Copy the API key shown after setup
|
||||
|
||||
[AskUserQuestion: "Please provide your Aden API key:"]
|
||||
[AskUserQuestion: "Please provide your Google API key:"]
|
||||
[User provides key]
|
||||
|
||||
Agent: Would you like me to save this to your shell config (~/.zshrc) for future sessions?
|
||||
Now, the Custom Search Engine ID:
|
||||
1. Go to https://programmablesearchengine.google.com/controlpanel/all
|
||||
2. Click "Add" to create a new search engine
|
||||
3. Under "What to search", select "Search the entire web"
|
||||
4. Give your search engine a name
|
||||
5. Click "Create"
|
||||
6. Copy the Search Engine ID (cx value)
|
||||
|
||||
[User approves]
|
||||
[Saves to ~/.zshrc and ~/.hive/configuration.json]
|
||||
[AskUserQuestion: "Please provide your Google CSE ID:"]
|
||||
[User provides ID]
|
||||
|
||||
Agent: Let me sync your credentials from Aden...
|
||||
[Runs health check with both values - GET /customsearch/v1?q=test&num=1 → 200 OK]
|
||||
[Stores both in local encrypted store, exports to env]
|
||||
|
||||
[Syncs credentials from Aden server - OAuth already done on Aden's side]
|
||||
[Runs health check]
|
||||
|
||||
Agent: HubSpot credentials validated successfully!
|
||||
✓ Google Custom Search credentials valid
|
||||
|
||||
All credentials are now configured:
|
||||
- ANTHROPIC_API_KEY: Stored in encrypted credential store
|
||||
- HUBSPOT_ACCESS_TOKEN: Synced from Aden (OAuth completed on Aden)
|
||||
- Validation passed - your agent is ready to run!
|
||||
✓ anthropic (ANTHROPIC_API_KEY) — already in encrypted store
|
||||
✓ brave_search (BRAVE_SEARCH_API_KEY) — already in encrypted store
|
||||
✓ google_search (GOOGLE_API_KEY) — stored in encrypted store
|
||||
✓ google_cse (GOOGLE_CSE_ID) — stored in encrypted store
|
||||
Your agent is ready to run!
|
||||
```
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
---
|
||||
name: Bug Report
|
||||
about: Report a bug to help us improve
|
||||
title: '[Bug]: '
|
||||
labels: bug
|
||||
title: "[Bug]: "
|
||||
labels: bug, enhancement
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Describe the Bug
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
---
|
||||
name: Feature Request
|
||||
about: Suggest a new feature or enhancement
|
||||
title: '[Feature]: '
|
||||
title: "[Feature]: "
|
||||
labels: enhancement
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Problem Statement
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
---
|
||||
name: Integration Request
|
||||
about: Suggest a new integration
|
||||
title: "[Integration]:"
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Service
|
||||
|
||||
Name and brief description of the service and what it enables agents to do.
|
||||
|
||||
**Description:** [e.g., "API key for Slack Bot" — short one-liner for the credential spec]
|
||||
|
||||
## Credential Identity
|
||||
|
||||
- **credential_id:** [e.g., `slack`]
|
||||
- **env_var:** [e.g., `SLACK_BOT_TOKEN`]
|
||||
- **credential_key:** [e.g., `access_token`, `api_key`, `bot_token`]
|
||||
|
||||
## Tools
|
||||
|
||||
Tool function names that require this credential:
|
||||
|
||||
- [e.g., `slack_send_message`]
|
||||
- [e.g., `slack_list_channels`]
|
||||
|
||||
## Auth Methods
|
||||
|
||||
- **Direct API key supported:** Yes / No
|
||||
- **Aden OAuth supported:** Yes / No
|
||||
|
||||
If Aden OAuth is supported, describe the OAuth scopes/permissions required.
|
||||
|
||||
## How to Get the Credential
|
||||
|
||||
Link where users obtain the key/token:
|
||||
|
||||
[e.g., https://api.slack.com/apps]
|
||||
|
||||
Step-by-step instructions:
|
||||
|
||||
1. Go to ...
|
||||
2. Create a ...
|
||||
3. Select scopes/permissions: ...
|
||||
4. Copy the key/token
|
||||
|
||||
## Health Check
|
||||
|
||||
A lightweight API call to validate the credential (no writes, no charges).
|
||||
|
||||
- **Endpoint:** [e.g., `https://slack.com/api/auth.test`]
|
||||
- **Method:** [e.g., `GET` or `POST`]
|
||||
- **Auth header:** [e.g., `Authorization: Bearer {token}` or `X-Api-Key: {key}`]
|
||||
- **Parameters (if any):** [e.g., `?limit=1`]
|
||||
- **200 means:** [e.g., key is valid]
|
||||
- **401 means:** [e.g., invalid or expired]
|
||||
- **429 means:** [e.g., rate limited but key is valid]
|
||||
|
||||
## Credential Group
|
||||
|
||||
Does this require multiple credentials configured together? (e.g., Google Custom Search needs
|
||||
both an API key and a CSE ID)
|
||||
|
||||
- [ ] No, single credential
|
||||
- [ ] Yes — list the other credential IDs in the group:
|
||||
|
||||
## Additional Context
|
||||
|
||||
Links to API docs, rate limits, free tier availability, or anything else relevant.
|
||||
+40
-19
@@ -21,23 +21,22 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd core
|
||||
pip install -e .
|
||||
pip install -r requirements-dev.txt
|
||||
run: uv sync --project core --group dev
|
||||
|
||||
- name: Ruff lint
|
||||
run: |
|
||||
ruff check core/
|
||||
ruff check tools/
|
||||
uv run --project core ruff check core/
|
||||
uv run --project core ruff check tools/
|
||||
|
||||
- name: Ruff format
|
||||
run: |
|
||||
ruff format --check core/
|
||||
ruff format --check tools/
|
||||
uv run --project core ruff format --check core/
|
||||
uv run --project core ruff format --check tools/
|
||||
|
||||
test:
|
||||
name: Test Python Framework
|
||||
@@ -52,23 +51,23 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd core
|
||||
pip install -e .
|
||||
pip install -r requirements-dev.txt
|
||||
uv sync
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
cd core
|
||||
pytest tests/ -v
|
||||
uv run pytest tests/ -v
|
||||
|
||||
validate:
|
||||
name: Validate Agent Exports
|
||||
test-tools:
|
||||
name: Test Tools
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint, test]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
@@ -76,13 +75,35 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies and run tests
|
||||
run: |
|
||||
cd tools
|
||||
uv sync --extra dev
|
||||
uv run pytest tests/ -v
|
||||
|
||||
validate:
|
||||
name: Validate Agent Exports
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint, test, test-tools]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd core
|
||||
pip install -e .
|
||||
pip install -r requirements-dev.txt
|
||||
uv sync
|
||||
|
||||
- name: Validate exported agents
|
||||
run: |
|
||||
|
||||
@@ -80,7 +80,13 @@ jobs:
|
||||
- help wanted: Extra attention is needed (if issue needs community input)
|
||||
- backlog: Tracked for the future, but not currently planned or prioritized
|
||||
|
||||
You may apply multiple labels if appropriate (e.g., "bug" and "help wanted").
|
||||
### 6. Estimate size (if NOT a duplicate, spam, or invalid)
|
||||
Apply exactly ONE size label to help contributors match their capacity to the task:
|
||||
- "size: small": Docs, typos, single-file fixes, config changes
|
||||
- "size: medium": Bug fixes with tests, adding a single tool, changes within one package
|
||||
- "size: large": Cross-package changes (core + tools), new modules, complex logic, architectural refactors
|
||||
|
||||
You may apply multiple labels if appropriate (e.g., "bug", "size: small", and "good first issue").
|
||||
|
||||
## Tools Available:
|
||||
- mcp__github__get_issue: Get issue details
|
||||
|
||||
@@ -21,18 +21,19 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd core
|
||||
pip install -e .
|
||||
pip install -r requirements-dev.txt
|
||||
uv sync
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
cd core
|
||||
pytest tests/ -v
|
||||
uv run pytest tests/ -v
|
||||
|
||||
- name: Generate changelog
|
||||
id: changelog
|
||||
|
||||
+6
-1
@@ -69,4 +69,9 @@ exports/*
|
||||
|
||||
.agent-builder-sessions/*
|
||||
|
||||
.venv
|
||||
.claude/settings.local.json
|
||||
|
||||
.venv
|
||||
|
||||
docs/github-issues/*
|
||||
core/tests/*dumps/*
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"agent-builder": {
|
||||
"command": ".venv/bin/python",
|
||||
"args": ["-m", "framework.mcp.agent_builder_server"],
|
||||
"cwd": "core",
|
||||
"env": {
|
||||
"PYTHONPATH": "../tools/src"
|
||||
}
|
||||
"command": "uv",
|
||||
"args": ["run", "-m", "framework.mcp.agent_builder_server"],
|
||||
"cwd": "core"
|
||||
},
|
||||
"tools": {
|
||||
"command": ".venv/bin/python",
|
||||
"args": ["mcp_server.py", "--stdio"],
|
||||
"cwd": "tools",
|
||||
"env": {
|
||||
"PYTHONPATH": "src:../core"
|
||||
}
|
||||
"command": "uv",
|
||||
"args": ["run", "mcp_server.py", "--stdio"],
|
||||
"cwd": "tools"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+20
-22
@@ -44,7 +44,7 @@ Aden Agent Framework is a Python-based system for building goal-driven, self-imp
|
||||
Ensure you have installed:
|
||||
|
||||
- **Python 3.11+** - [Download](https://www.python.org/downloads/) (3.12 or 3.13 recommended)
|
||||
- **pip** - Package installer for Python (comes with Python)
|
||||
- **uv** - Python package manager ([Install](https://docs.astral.sh/uv/getting-started/installation/))
|
||||
- **git** - Version control
|
||||
- **Claude Code** - [Install](https://docs.anthropic.com/claude/docs/claude-code) (optional, for using building skills)
|
||||
|
||||
@@ -52,7 +52,7 @@ Verify installation:
|
||||
|
||||
```bash
|
||||
python --version # Should be 3.11+
|
||||
pip --version # Should be latest
|
||||
uv --version # Should be latest
|
||||
git --version # Any recent version
|
||||
```
|
||||
|
||||
@@ -128,8 +128,12 @@ hive/ # Repository root
|
||||
│
|
||||
├── .github/ # GitHub configuration
|
||||
│ ├── workflows/
|
||||
│ │ ├── ci.yml # Runs on every PR
|
||||
│ │ └── release.yml # Runs on tags
|
||||
│ │ ├── ci.yml # Lint, test, validate on every PR
|
||||
│ │ ├── release.yml # Runs on tags
|
||||
│ │ ├── pr-requirements.yml # PR requirement checks
|
||||
│ │ ├── pr-check-command.yml # PR check commands
|
||||
│ │ ├── claude-issue-triage.yml # Automated issue triage
|
||||
│ │ └── auto-close-duplicates.yml # Close duplicate issues
|
||||
│ ├── ISSUE_TEMPLATE/ # Bug report & feature request templates
|
||||
│ ├── PULL_REQUEST_TEMPLATE.md # PR description template
|
||||
│ └── CODEOWNERS # Auto-assign reviewers
|
||||
@@ -166,7 +170,6 @@ hive/ # Repository root
|
||||
│ │ ├── testing/ # Testing utilities
|
||||
│ │ └── __init__.py
|
||||
│ ├── pyproject.toml # Package metadata and dependencies
|
||||
│ ├── requirements.txt # Python dependencies
|
||||
│ ├── README.md # Framework documentation
|
||||
│ ├── MCP_INTEGRATION_GUIDE.md # MCP server integration guide
|
||||
│ └── docs/ # Protocol documentation
|
||||
@@ -182,7 +185,6 @@ hive/ # Repository root
|
||||
│ │ ├── mcp_server.py # HTTP MCP server
|
||||
│ │ └── __init__.py
|
||||
│ ├── pyproject.toml # Package metadata
|
||||
│ ├── requirements.txt # Python dependencies
|
||||
│ └── README.md # Tools documentation
|
||||
│
|
||||
├── exports/ # AGENT PACKAGES (user-created, gitignored)
|
||||
@@ -191,14 +193,16 @@ hive/ # Repository root
|
||||
├── docs/ # Documentation
|
||||
│ ├── getting-started.md # Quick start guide
|
||||
│ ├── configuration.md # Configuration reference
|
||||
│ ├── architecture.md # System architecture
|
||||
│ └── articles/ # Technical articles
|
||||
│ ├── architecture/ # System architecture
|
||||
│ ├── articles/ # Technical articles
|
||||
│ ├── quizzes/ # Developer quizzes
|
||||
│ └── i18n/ # Translations
|
||||
│
|
||||
├── scripts/ # Build & utility scripts
|
||||
│ ├── setup-python.sh # Python environment setup
|
||||
│ └── setup.sh # Legacy setup script
|
||||
│
|
||||
├── quickstart.sh # Install Claude Code skills
|
||||
├── quickstart.sh # Interactive setup wizard
|
||||
├── ENVIRONMENT_SETUP.md # Complete Python setup guide
|
||||
├── README.md # Project overview
|
||||
├── DEVELOPER.md # This file
|
||||
@@ -375,7 +379,7 @@ def test_ticket_categorization():
|
||||
- **PEP 8** - Follow Python style guide
|
||||
- **Type hints** - Use for function signatures and class attributes
|
||||
- **Docstrings** - Document classes and public functions
|
||||
- **Black** - Code formatter (run with `black .`)
|
||||
- **Ruff** - Linter and formatter (run with `make check`)
|
||||
|
||||
```python
|
||||
# Good
|
||||
@@ -509,8 +513,8 @@ chore(deps): update React to 18.2.0
|
||||
|
||||
1. Create a feature branch from `main`
|
||||
2. Make your changes with clear commits
|
||||
3. Run tests locally: `PYTHONPATH=core:exports python -m pytest`
|
||||
4. Run linting: `black --check .`
|
||||
3. Run tests locally: `make test`
|
||||
4. Run linting: `make check`
|
||||
5. Push and create a PR
|
||||
6. Fill out the PR template
|
||||
7. Request review from CODEOWNERS
|
||||
@@ -528,16 +532,11 @@ chore(deps): update React to 18.2.0
|
||||
```bash
|
||||
# Add to core framework
|
||||
cd core
|
||||
pip install <package>
|
||||
# Then add to requirements.txt or pyproject.toml
|
||||
uv add <package>
|
||||
|
||||
# Add to tools package
|
||||
cd tools
|
||||
pip install <package>
|
||||
# Then add to requirements.txt or pyproject.toml
|
||||
|
||||
# Reinstall in editable mode
|
||||
pip install -e .
|
||||
uv add <package>
|
||||
```
|
||||
|
||||
### Creating a New Agent
|
||||
@@ -670,9 +669,8 @@ cat .env
|
||||
# Or check shell environment
|
||||
echo $ANTHROPIC_API_KEY
|
||||
|
||||
# Copy from .env.example if needed
|
||||
cp .env.example .env
|
||||
# Then edit .env with your API keys
|
||||
# Create .env if needed
|
||||
# Then add your API keys
|
||||
```
|
||||
|
||||
|
||||
|
||||
+86
-3
@@ -21,6 +21,43 @@ This will:
|
||||
- Fix package compatibility issues (openai + litellm)
|
||||
- Verify all installations
|
||||
|
||||
## Quick Setup (Windows – PowerShell)
|
||||
|
||||
Windows users can use the native PowerShell setup script.
|
||||
|
||||
Before running the script, allow script execution for the current session:
|
||||
|
||||
```powershell
|
||||
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
|
||||
```
|
||||
|
||||
Run setup from the project root:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
This will:
|
||||
|
||||
- Check Python version (requires 3.11+)
|
||||
- Create a local `.venv` virtual environment
|
||||
- Install the core framework package (`framework`)
|
||||
- Install the tools package (`aden_tools`)
|
||||
- Fix package compatibility issues (openai + litellm)
|
||||
- Verify all installations
|
||||
|
||||
After setup, activate the virtual environment:
|
||||
|
||||
```powershell
|
||||
.\.venv\Scripts\Activate.ps1
|
||||
```
|
||||
|
||||
Set `PYTHONPATH` (required in every new PowerShell session):
|
||||
|
||||
```powershell
|
||||
$env:PYTHONPATH="core;exports"
|
||||
```
|
||||
|
||||
## Alpine Linux Setup
|
||||
|
||||
If you are using Alpine Linux (e.g., inside a Docker container), you must install system dependencies and use a virtual environment before running the setup script:
|
||||
@@ -100,6 +137,12 @@ For running agents with real LLMs:
|
||||
export ANTHROPIC_API_KEY="your-key-here"
|
||||
```
|
||||
|
||||
Windows (PowerShell):
|
||||
|
||||
```powershell
|
||||
$env:ANTHROPIC_API_KEY="your-key-here"
|
||||
```
|
||||
|
||||
## Running Agents
|
||||
|
||||
All agent commands must be run from the project root with `PYTHONPATH` set:
|
||||
@@ -109,9 +152,14 @@ All agent commands must be run from the project root with `PYTHONPATH` set:
|
||||
PYTHONPATH=core:exports python -m agent_name COMMAND
|
||||
```
|
||||
|
||||
### Example Commands
|
||||
Windows (PowerShell):
|
||||
|
||||
After building an agent via `/building-agents-construction`, use these commands:
|
||||
```powershell
|
||||
$env:PYTHONPATH="core;exports"
|
||||
python -m agent_name COMMAND
|
||||
```
|
||||
|
||||
### Example: Support Ticket Agent
|
||||
|
||||
```bash
|
||||
# Validate agent structure
|
||||
@@ -248,6 +296,14 @@ source .venv/bin/activate
|
||||
PYTHONPATH=core:exports python -m your_agent_name demo
|
||||
```
|
||||
|
||||
### PowerShell: “running scripts is disabled on this system”
|
||||
|
||||
Run once per session:
|
||||
|
||||
```powershell
|
||||
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
|
||||
```
|
||||
|
||||
### "ModuleNotFoundError: No module named 'framework'"
|
||||
|
||||
**Solution:** Install the core package:
|
||||
@@ -270,6 +326,12 @@ Or run the setup script:
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
### "ModuleNotFoundError: No module named 'openai.\_models'"
|
||||
|
||||
**Cause:** Outdated `openai` package (0.27.x) incompatible with `litellm`
|
||||
@@ -284,12 +346,21 @@ pip install --upgrade "openai>=1.0.0"
|
||||
|
||||
**Cause:** Not running from project root, missing PYTHONPATH, or agent not yet created
|
||||
|
||||
**Solution:** Ensure you're in the project root directory, have built an agent, and use:
|
||||
**Solution:** Ensure you're in `/hive/` and use:
|
||||
|
||||
Linux/macOS:
|
||||
|
||||
```bash
|
||||
PYTHONPATH=core:exports python -m your_agent_name validate
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
$env:PYTHONPATH="core;exports"
|
||||
python -m support_ticket_agent validate
|
||||
```
|
||||
|
||||
### Agent imports fail with "broken installation"
|
||||
|
||||
**Symptom:** `pip list` shows packages pointing to non-existent directories
|
||||
@@ -304,6 +375,12 @@ pip uninstall -y framework tools
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
## Package Structure
|
||||
|
||||
The Hive framework consists of three Python packages:
|
||||
@@ -402,6 +479,12 @@ This design allows agents in `exports/` to be:
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
### 2. Build Agent (Claude Code)
|
||||
|
||||
```
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
|
||||
[](https://github.com/adenhq/hive/blob/main/LICENSE)
|
||||
[](https://www.ycombinator.com/companies/aden)
|
||||
[](https://hub.docker.com/u/adenhq)
|
||||
[](https://discord.com/invite/MXE49hrKDk)
|
||||
[](https://x.com/aden_hq)
|
||||
[](https://www.linkedin.com/company/teamaden/)
|
||||
@@ -40,6 +39,31 @@ Build reliable, self-improving AI agents without hardcoding workflows. Define yo
|
||||
|
||||
Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.
|
||||
|
||||
## Who Is Hive For?
|
||||
|
||||
Hive is designed for developers and teams who want to build **production-grade AI agents** without manually wiring complex workflows.
|
||||
|
||||
Hive is a good fit if you:
|
||||
|
||||
- Want AI agents that **execute real business processes**, not demos
|
||||
- Prefer **goal-driven development** over hardcoded workflows
|
||||
- Need **self-healing and adaptive agents** that improve over time
|
||||
- Require **human-in-the-loop control**, observability, and cost limits
|
||||
- Plan to run agents in **production environments**
|
||||
|
||||
Hive may not be the best fit if you’re only experimenting with simple agent chains or one-off scripts.
|
||||
|
||||
## When Should You Use Hive?
|
||||
|
||||
Use Hive when you need:
|
||||
|
||||
- Long-running, autonomous agents
|
||||
- Multi-agent coordination
|
||||
- Continuous improvement based on failures
|
||||
- Strong monitoring, safety, and budget controls
|
||||
- A framework that evolves with your goals
|
||||
|
||||
|
||||
## What is Aden
|
||||
|
||||
<p align="center">
|
||||
@@ -64,11 +88,13 @@ Aden is a platform for building, deploying, operating, and adapting AI agents:
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Prerequisites
|
||||
## Prerequisites
|
||||
|
||||
- [Python 3.11+](https://www.python.org/downloads/) for agent development
|
||||
- Python 3.11+ for agent development
|
||||
- Claude Code or Cursor for utilizing agent skills
|
||||
|
||||
> **Note for Windows Users:** It is strongly recommended to use **WSL (Windows Subsystem for Linux)** or **Git Bash** to run this framework. Some core automation scripts may not execute correctly in standard Command Prompt or PowerShell.
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
@@ -112,7 +138,7 @@ Skills are also available in Cursor. To enable:
|
||||
## Features
|
||||
|
||||
- **Goal-Driven Development** - Define objectives in natural language; the coding agent generates the agent graph and connection code to achieve them
|
||||
- **Adaptiveness** - Framework captures failures, calibrates accourding to the objectives, and evolves the agent graph
|
||||
- **Adaptiveness** - Framework captures failures, calibrates according to the objectives, and evolves the agent graph
|
||||
- **Dynamic Node Connections** - No predefined edges; connection code is generated by any capable LLM based on your goals
|
||||
- **SDK-Wrapped Nodes** - Every node gets shared memory, local RLM memory, monitoring, tools, and LLM access out of the box
|
||||
- **Human-in-the-Loop** - Intervention nodes that pause execution for human input with configurable timeouts and escalation
|
||||
@@ -122,7 +148,7 @@ Skills are also available in Cursor. To enable:
|
||||
|
||||
## Why Aden
|
||||
|
||||
Traditional agent frameworks require you to manually design workflows, define agent interactions, and handle failures reactively. Aden flips this paradigm—**you describe outcomes, and the system builds itself**.
|
||||
Hive focuses on generating agents that run real business processes rather than generic agents. Instead of requiring you to manually design workflows, define agent interactions, and handle failures reactively, Hive flips the paradigm: **you describe outcomes, and the system builds itself**—delivering an outcome-driven, adaptive experience with an easy-to-use set of tools and integrations.
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
@@ -162,7 +188,7 @@ flowchart LR
|
||||
| -------------------------- | -------------------------------------- |
|
||||
| Hardcode agent workflows | Describe goals in natural language |
|
||||
| Manual graph definition | Auto-generated agent graphs |
|
||||
| Reactive error handling | Proactive self-evolution |
|
||||
| Reactive error handling | Outcome-evaluation and adaptiveness |
|
||||
| Static tool configurations | Dynamic SDK-wrapped nodes |
|
||||
| Separate monitoring setup | Built-in real-time observability |
|
||||
| DIY budget management | Integrated cost controls & degradation |
|
||||
@@ -175,9 +201,14 @@ flowchart LR
|
||||
4. **Control Plane Monitors** → Real-time metrics, budget enforcement, policy management
|
||||
5. **Adaptiveness** → On failure, the system evolves the graph and redeploys automatically
|
||||
|
||||
## Development
|
||||
## Run pre-built Agents (Coming Soon)
|
||||
|
||||
### Run a sample agent
|
||||
Aden Hive provides a list of featured agents that you can use and build on top of.
|
||||
|
||||
### Run an agent shared by others
|
||||
Put the agent in `exports/` and run `PYTHONPATH=core:exports python -m your_agent_name run --input '{...}'`
|
||||
|
||||
### Python Agent Development
|
||||
|
||||
For building and running goal-driven agents with the framework:
|
||||
|
||||
@@ -211,7 +242,7 @@ See [ENVIRONMENT_SETUP.md](ENVIRONMENT_SETUP.md) for complete setup instructions
|
||||
|
||||
## Roadmap
|
||||
|
||||
Aden Agent Framework aims to help developers build outcome-oriented, self-adaptive agents. See [ROADMAP.md](ROADMAP.md) for details.
|
||||
Aden Hive Agent Framework aims to help developers build outcome-oriented, self-adaptive agents. See [ROADMAP.md](ROADMAP.md) for details.
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
@@ -338,47 +369,47 @@ This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENS
|
||||
|
||||
## Frequently Asked Questions (FAQ)
|
||||
|
||||
**Q: Does Aden depend on LangChain or other agent frameworks?**
|
||||
**Q: Does Hive depend on LangChain or other agent frameworks?**
|
||||
|
||||
No. Aden is built from the ground up with no dependencies on LangChain, CrewAI, or other agent frameworks. The framework is designed to be lean and flexible, generating agent graphs dynamically rather than relying on predefined components.
|
||||
No. Hive is built from the ground up with no dependencies on LangChain, CrewAI, or other agent frameworks. The framework is designed to be lean and flexible, generating agent graphs dynamically rather than relying on predefined components.
|
||||
|
||||
**Q: What LLM providers does Aden support?**
|
||||
**Q: What LLM providers does Hive support?**
|
||||
|
||||
Aden supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
|
||||
Hive supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
|
||||
|
||||
**Q: Can I use Aden with local AI models like Ollama?**
|
||||
**Q: Can I use Hive with local AI models like Ollama?**
|
||||
|
||||
Yes! Aden supports local models through LiteLLM. Simply use the model name format `ollama/model-name` (e.g., `ollama/llama3`, `ollama/mistral`) and ensure Ollama is running locally.
|
||||
Yes! Hive supports local models through LiteLLM. Simply use the model name format `ollama/model-name` (e.g., `ollama/llama3`, `ollama/mistral`) and ensure Ollama is running locally.
|
||||
|
||||
**Q: What makes Aden different from other agent frameworks?**
|
||||
**Q: What makes Hive different from other agent frameworks?**
|
||||
|
||||
Aden generates your entire agent system from natural language goals using a coding agent—you don't hardcode workflows or manually define graphs. When agents fail, the framework automatically captures failure data, evolves the agent graph, and redeploys. This self-improving loop is unique to Aden.
|
||||
Hive generates your entire agent system from natural language goals using a coding agent—you don't hardcode workflows or manually define graphs. When agents fail, the framework automatically captures failure data, evolves the agent graph, and redeploys. This self-improving loop is unique to Aden.
|
||||
|
||||
**Q: Is Aden open-source?**
|
||||
**Q: Is Hive open-source?**
|
||||
|
||||
Yes, Aden is fully open-source under the Apache License 2.0. We actively encourage community contributions and collaboration.
|
||||
Yes, Hive is fully open-source under the Apache License 2.0. We actively encourage community contributions and collaboration.
|
||||
|
||||
**Q: Does Aden collect data from users?**
|
||||
**Q: Does Hive collect data from users?**
|
||||
|
||||
Aden collects telemetry data for monitoring and observability purposes, including token usage, latency metrics, and cost tracking. Content capture (prompts and responses) is configurable and stored with team-scoped data isolation. All data stays within your infrastructure when self-hosted.
|
||||
Hive collects telemetry data for monitoring and observability purposes, including token usage, latency metrics, and cost tracking. Content capture (prompts and responses) is configurable and stored with team-scoped data isolation. All data stays within your infrastructure when self-hosted.
|
||||
|
||||
**Q: What deployment options does Aden support?**
|
||||
**Q: What deployment options does Hive support?**
|
||||
|
||||
Aden supports self-hosted deployments via Python packages. See the [Environment Setup Guide](ENVIRONMENT_SETUP.md) for installation instructions. Cloud deployment options and Kubernetes-ready configurations are on the roadmap.
|
||||
Hive supports self-hosted deployments via Python packages. See the [Environment Setup Guide](ENVIRONMENT_SETUP.md) for installation instructions. Cloud deployment options and Kubernetes-ready configurations are on the roadmap.
|
||||
|
||||
**Q: Can Aden handle complex, production-scale use cases?**
|
||||
**Q: Can Hive handle complex, production-scale use cases?**
|
||||
|
||||
Yes. Aden is explicitly designed for production environments with features like automatic failure recovery, real-time observability, cost controls, and horizontal scaling support. The framework handles both simple automations and complex multi-agent workflows.
|
||||
Yes. Hive is explicitly designed for production environments with features like automatic failure recovery, real-time observability, cost controls, and horizontal scaling support. The framework handles both simple automations and complex multi-agent workflows.
|
||||
|
||||
**Q: Does Aden support human-in-the-loop workflows?**
|
||||
**Q: Does Hive support human-in-the-loop workflows?**
|
||||
|
||||
Yes, Aden fully supports human-in-the-loop workflows through intervention nodes that pause execution for human input. These include configurable timeouts and escalation policies, allowing seamless collaboration between human experts and AI agents.
|
||||
Yes, Hive fully supports human-in-the-loop workflows through intervention nodes that pause execution for human input. These include configurable timeouts and escalation policies, allowing seamless collaboration between human experts and AI agents.
|
||||
|
||||
**Q: What monitoring and debugging tools does Aden provide?**
|
||||
**Q: What monitoring and debugging tools does Hive provide?**
|
||||
|
||||
Aden includes comprehensive observability features: real-time WebSocket streaming for live agent execution monitoring, TimescaleDB-powered analytics for cost and performance metrics, health check endpoints for Kubernetes integration, and MCP tools for agent execution, including file operations, web search, data processing, and more.
|
||||
Hive includes comprehensive observability features: real-time WebSocket streaming for live agent execution monitoring, TimescaleDB-powered analytics for cost and performance metrics, health check endpoints for Kubernetes integration, and MCP tools for agent execution, including file operations, web search, data processing, and more.
|
||||
|
||||
**Q: What programming languages does Aden support?**
|
||||
**Q: What programming languages does Hive support?**
|
||||
|
||||
The Hive framework is built in Python. A JavaScript/TypeScript SDK is on the roadmap.
|
||||
|
||||
@@ -386,9 +417,9 @@ The Hive framework is built in Python. A JavaScript/TypeScript SDK is on the roa
|
||||
|
||||
Yes. Aden's SDK-wrapped nodes provide built-in tool access, and the framework supports flexible tool ecosystems. Agents can integrate with external APIs, databases, and services through the node architecture.
|
||||
|
||||
**Q: How does cost control work in Aden?**
|
||||
**Q: How does cost control work in Hive?**
|
||||
|
||||
Aden provides granular budget controls including spending limits, throttles, and automatic model degradation policies. You can set budgets at the team, agent, or workflow level, with real-time cost tracking and alerts.
|
||||
Hive provides granular budget controls including spending limits, throttles, and automatic model degradation policies. You can set budgets at the team, agent, or workflow level, with real-time cost tracking and alerts.
|
||||
|
||||
**Q: Where can I find examples and documentation?**
|
||||
|
||||
@@ -398,6 +429,14 @@ Visit [docs.adenhq.com](https://docs.adenhq.com/) for complete guides, API refer
|
||||
|
||||
Contributions are welcome! Fork the repository, create your feature branch, implement your changes, and submit a pull request. See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed guidelines.
|
||||
|
||||
**Q: When will my team start seeing results from Aden's adaptive agents?**
|
||||
|
||||
Aden's adaptation loop begins working from the first execution. When an agent fails, the framework captures the failure data, helping developers evolve the agent graph through the coding agent. How quickly this translates to measurable results depends on the complexity of your use case, the quality of your goal definitions, and the volume of executions generating feedback.
|
||||
|
||||
**Q: How does Hive compare to other agent frameworks?**
|
||||
|
||||
Hive focuses on generating agents that run real business processes, rather than generic agents. This vision emphasizes outcome-driven design, adaptability, and an easy-to-use set of tools and integrations.
|
||||
|
||||
**Q: Does Aden offer enterprise support?**
|
||||
|
||||
For enterprise inquiries, contact the Aden team through [adenhq.com](https://adenhq.com) or join our [Discord community](https://discord.com/invite/MXE49hrKDk) for support and discussions.
|
||||
|
||||
+1
-1
@@ -268,7 +268,7 @@ classDef done fill:#9e9e9e,color:#fff,stroke:#757575
|
||||
- [ ] Wake-up Tool (resume agent tasks)
|
||||
|
||||
### Deployment (Self-Hosted)
|
||||
- [ ] Docker container standardization
|
||||
- [ ] Workder agent docker container standardization
|
||||
- [ ] Headless backend execution
|
||||
- [ ] Exposed API for frontend attachment
|
||||
- [ ] Local monitoring & observability
|
||||
|
||||
@@ -0,0 +1,740 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
EventLoopNode WebSocket Demo
|
||||
|
||||
Real LLM, real FileConversationStore, real EventBus.
|
||||
Streams EventLoopNode execution to a browser via WebSocket.
|
||||
|
||||
Usage:
|
||||
cd /home/timothy/oss/hive/core
|
||||
python demos/event_loop_wss_demo.py
|
||||
|
||||
Then open http://localhost:8765 in your browser.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import tempfile
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from bs4 import BeautifulSoup
|
||||
from websockets.http11 import Request, Response
|
||||
|
||||
# Add core, tools, and hive root to path
|
||||
_CORE_DIR = Path(__file__).resolve().parent.parent
|
||||
_HIVE_DIR = _CORE_DIR.parent
|
||||
sys.path.insert(0, str(_CORE_DIR)) # framework.*
|
||||
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
|
||||
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
|
||||
|
||||
import os # noqa: E402
|
||||
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
|
||||
from core.framework.credentials import CredentialStore # noqa: E402
|
||||
|
||||
from framework.credentials.storage import ( # noqa: E402
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.provider import Tool # noqa: E402
|
||||
from framework.runner.tool_registry import ToolRegistry # noqa: E402
|
||||
from framework.runtime.core import Runtime # noqa: E402
|
||||
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
|
||||
from framework.storage.conversation_store import FileConversationStore # noqa: E402
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
||||
logger = logging.getLogger("demo")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Persistent state (shared across WebSocket connections)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_demo_"))
|
||||
STORE = FileConversationStore(STORE_DIR / "conversation")
|
||||
RUNTIME = Runtime(STORE_DIR / "runtime")
|
||||
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tool Registry — real tools via ToolRegistry (same pattern as GraphExecutor)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
# Credential store: Aden sync (OAuth2 tokens) + encrypted files + env var fallback
|
||||
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
|
||||
_local_storage = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
|
||||
)
|
||||
|
||||
if os.environ.get("ADEN_API_KEY"):
|
||||
try:
|
||||
from framework.credentials.aden import ( # noqa: E402
|
||||
AdenCachedStorage,
|
||||
AdenClientConfig,
|
||||
AdenCredentialClient,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
_client = AdenCredentialClient(AdenClientConfig(base_url="https://api.adenhq.com"))
|
||||
_provider = AdenSyncProvider(client=_client)
|
||||
_storage = AdenCachedStorage(
|
||||
local_storage=_local_storage,
|
||||
aden_provider=_provider,
|
||||
)
|
||||
_cred_store = CredentialStore(storage=_storage, providers=[_provider], auto_refresh=True)
|
||||
_synced = _provider.sync_all(_cred_store)
|
||||
logger.info("Synced %d credentials from Aden", _synced)
|
||||
except Exception as e:
|
||||
logger.warning("Aden sync unavailable: %s", e)
|
||||
_cred_store = CredentialStore(storage=_local_storage)
|
||||
else:
|
||||
logger.info("ADEN_API_KEY not set, using local credential storage")
|
||||
_cred_store = CredentialStore(storage=_local_storage)
|
||||
|
||||
CREDENTIALS = CredentialStoreAdapter(_cred_store)
|
||||
|
||||
# Debug: log which credentials resolved
|
||||
for _name in ["brave_search", "hubspot", "anthropic"]:
|
||||
_val = CREDENTIALS.get(_name)
|
||||
if _val:
|
||||
logger.debug("credential %s: OK (len=%d)", _name, len(_val))
|
||||
else:
|
||||
logger.debug("credential %s: not found", _name)
|
||||
|
||||
# --- web_search (Brave Search API) ---
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_search",
|
||||
tool=Tool(
|
||||
name="web_search",
|
||||
description=(
|
||||
"Search the web for current information. "
|
||||
"Returns titles, URLs, and snippets from search results."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query (1-500 characters)",
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (1-20, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_search(inputs),
|
||||
)
|
||||
|
||||
|
||||
def _exec_web_search(inputs: dict) -> dict:
|
||||
api_key = CREDENTIALS.get("brave_search")
|
||||
if not api_key:
|
||||
return {"error": "brave_search credential not configured"}
|
||||
query = inputs.get("query", "")
|
||||
num_results = min(inputs.get("num_results", 10), 20)
|
||||
resp = httpx.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
params={"q": query, "count": num_results},
|
||||
headers={"X-Subscription-Token": api_key, "Accept": "application/json"},
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"Brave API HTTP {resp.status_code}"}
|
||||
data = resp.json()
|
||||
results = [
|
||||
{
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("description", ""),
|
||||
}
|
||||
for item in data.get("web", {}).get("results", [])[:num_results]
|
||||
]
|
||||
return {"query": query, "results": results, "total": len(results)}
|
||||
|
||||
|
||||
# --- web_scrape (httpx + BeautifulSoup, no playwright for sync compat) ---
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_scrape",
|
||||
tool=Tool(
|
||||
name="web_scrape",
|
||||
description=(
|
||||
"Scrape and extract text content from a webpage URL. "
|
||||
"Returns the page title and main text content."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the webpage to scrape",
|
||||
},
|
||||
"max_length": {
|
||||
"type": "integer",
|
||||
"description": "Maximum text length (default 50000)",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_scrape(inputs),
|
||||
)
|
||||
|
||||
_SCRAPE_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "text/html,application/xhtml+xml",
|
||||
}
|
||||
|
||||
|
||||
def _exec_web_scrape(inputs: dict) -> dict:
|
||||
url = inputs.get("url", "")
|
||||
max_length = max(1000, min(inputs.get("max_length", 50000), 500000))
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
try:
|
||||
resp = httpx.get(url, timeout=30.0, follow_redirects=True, headers=_SCRAPE_HEADERS)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HTTP {resp.status_code}"}
|
||||
soup = BeautifulSoup(resp.text, "html.parser")
|
||||
for tag in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
|
||||
tag.decompose()
|
||||
title = soup.title.get_text(strip=True) if soup.title else ""
|
||||
main = (
|
||||
soup.find("article")
|
||||
or soup.find("main")
|
||||
or soup.find(attrs={"role": "main"})
|
||||
or soup.find("body")
|
||||
)
|
||||
text = main.get_text(separator=" ", strip=True) if main else ""
|
||||
text = " ".join(text.split())
|
||||
if len(text) > max_length:
|
||||
text = text[:max_length] + "..."
|
||||
return {"url": url, "title": title, "content": text, "length": len(text)}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"Scrape failed: {e}"}
|
||||
|
||||
|
||||
# --- HubSpot CRM tools (optional, requires HUBSPOT_ACCESS_TOKEN) ---
|
||||
|
||||
_HUBSPOT_API = "https://api.hubapi.com"
|
||||
|
||||
|
||||
def _hubspot_headers() -> dict | None:
|
||||
token = CREDENTIALS.get("hubspot")
|
||||
if token:
|
||||
logger.debug("HubSpot token: %s...%s (len=%d)", token[:8], token[-4:], len(token))
|
||||
else:
|
||||
logger.debug("HubSpot token: not found")
|
||||
if not token:
|
||||
return None
|
||||
return {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def _exec_hubspot_search(inputs: dict) -> dict:
|
||||
headers = _hubspot_headers()
|
||||
if not headers:
|
||||
return {"error": "HUBSPOT_ACCESS_TOKEN not set"}
|
||||
object_type = inputs.get("object_type", "contacts")
|
||||
query = inputs.get("query", "")
|
||||
limit = min(inputs.get("limit", 10), 100)
|
||||
body: dict = {"limit": limit}
|
||||
if query:
|
||||
body["query"] = query
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{_HUBSPOT_API}/crm/v3/objects/{object_type}/search",
|
||||
headers=headers,
|
||||
json=body,
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HubSpot API HTTP {resp.status_code}: {resp.text[:200]}"}
|
||||
return resp.json()
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"HubSpot error: {e}"}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="hubspot_search",
|
||||
tool=Tool(
|
||||
name="hubspot_search",
|
||||
description=(
|
||||
"Search HubSpot CRM objects (contacts, companies, or deals). "
|
||||
"Returns matching records with their properties."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"object_type": {
|
||||
"type": "string",
|
||||
"description": "CRM object type: 'contacts', 'companies', or 'deals'",
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query (name, email, domain, etc.)",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results (1-100, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["object_type"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_hubspot_search(inputs),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ToolRegistry loaded: %s",
|
||||
", ".join(TOOL_REGISTRY.get_registered_names()),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTML page (embedded)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
HTML_PAGE = ( # noqa: E501
|
||||
"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>EventLoopNode Live Demo</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body {
|
||||
font-family: 'SF Mono', 'Fira Code', monospace;
|
||||
background: #0d1117; color: #c9d1d9;
|
||||
height: 100vh; display: flex; flex-direction: column;
|
||||
}
|
||||
header {
|
||||
background: #161b22; padding: 12px 20px;
|
||||
border-bottom: 1px solid #30363d;
|
||||
display: flex; align-items: center; gap: 16px;
|
||||
}
|
||||
header h1 { font-size: 16px; color: #58a6ff; font-weight: 600; }
|
||||
.status {
|
||||
font-size: 12px; padding: 3px 10px; border-radius: 12px;
|
||||
background: #21262d; color: #8b949e;
|
||||
}
|
||||
.status.running { background: #1a4b2e; color: #3fb950; }
|
||||
.status.done { background: #1a3a5c; color: #58a6ff; }
|
||||
.status.error { background: #4b1a1a; color: #f85149; }
|
||||
.chat { flex: 1; overflow-y: auto; padding: 16px; }
|
||||
.msg {
|
||||
margin: 8px 0; padding: 10px 14px; border-radius: 8px;
|
||||
line-height: 1.6; white-space: pre-wrap; word-wrap: break-word;
|
||||
}
|
||||
.msg.user { background: #1a3a5c; color: #58a6ff; }
|
||||
.msg.assistant { background: #161b22; color: #c9d1d9; }
|
||||
.msg.event {
|
||||
background: transparent; color: #8b949e; font-size: 11px;
|
||||
padding: 4px 14px; border-left: 3px solid #30363d;
|
||||
}
|
||||
.msg.event.loop { border-left-color: #58a6ff; }
|
||||
.msg.event.tool { border-left-color: #d29922; }
|
||||
.msg.event.stall { border-left-color: #f85149; }
|
||||
.input-bar {
|
||||
padding: 12px 16px; background: #161b22;
|
||||
border-top: 1px solid #30363d; display: flex; gap: 8px;
|
||||
}
|
||||
.input-bar input {
|
||||
flex: 1; background: #0d1117; border: 1px solid #30363d;
|
||||
color: #c9d1d9; padding: 8px 12px; border-radius: 6px;
|
||||
font-family: inherit; font-size: 14px; outline: none;
|
||||
}
|
||||
.input-bar input:focus { border-color: #58a6ff; }
|
||||
.input-bar button {
|
||||
background: #238636; color: #fff; border: none;
|
||||
padding: 8px 20px; border-radius: 6px; cursor: pointer;
|
||||
font-family: inherit; font-weight: 600;
|
||||
}
|
||||
.input-bar button:hover { background: #2ea043; }
|
||||
.input-bar button:disabled {
|
||||
background: #21262d; color: #484f58; cursor: not-allowed;
|
||||
}
|
||||
.input-bar button.clear { background: #da3633; }
|
||||
.input-bar button.clear:hover { background: #f85149; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>EventLoopNode Live</h1>
|
||||
<span id="status" class="status">Idle</span>
|
||||
<span id="iter" class="status" style="display:none">Step 0</span>
|
||||
</header>
|
||||
<div id="chat" class="chat"></div>
|
||||
<div class="input-bar">
|
||||
<input id="input" type="text"
|
||||
placeholder="Ask anything..." autofocus />
|
||||
<button id="go" onclick="run()">Send</button>
|
||||
<button class="clear"
|
||||
onclick="clearConversation()">Clear</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let currentAssistantEl = null;
|
||||
let iterCount = 0;
|
||||
const chat = document.getElementById('chat');
|
||||
const status = document.getElementById('status');
|
||||
const iterEl = document.getElementById('iter');
|
||||
const goBtn = document.getElementById('go');
|
||||
const inputEl = document.getElementById('input');
|
||||
|
||||
inputEl.addEventListener('keydown', e => {
|
||||
if (e.key === 'Enter') run();
|
||||
});
|
||||
|
||||
function setStatus(text, cls) {
|
||||
status.textContent = text;
|
||||
status.className = 'status ' + cls;
|
||||
}
|
||||
|
||||
function addMsg(text, cls) {
|
||||
const el = document.createElement('div');
|
||||
el.className = 'msg ' + cls;
|
||||
el.textContent = text;
|
||||
chat.appendChild(el);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
return el;
|
||||
}
|
||||
|
||||
function connect() {
|
||||
ws = new WebSocket('ws://' + location.host + '/ws');
|
||||
ws.onopen = () => {
|
||||
setStatus('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
};
|
||||
ws.onmessage = handleEvent;
|
||||
ws.onerror = () => { setStatus('Error', 'error'); };
|
||||
ws.onclose = () => {
|
||||
setStatus('Reconnecting...', '');
|
||||
goBtn.disabled = true;
|
||||
setTimeout(connect, 2000);
|
||||
};
|
||||
}
|
||||
|
||||
function handleEvent(msg) {
|
||||
const evt = JSON.parse(msg.data);
|
||||
|
||||
if (evt.type === 'llm_text_delta') {
|
||||
if (currentAssistantEl) {
|
||||
currentAssistantEl.textContent += evt.content;
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'ready') {
|
||||
setStatus('Ready', 'done');
|
||||
if (currentAssistantEl && !currentAssistantEl.textContent)
|
||||
currentAssistantEl.remove();
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_loop_iteration') {
|
||||
iterCount = evt.iteration || (iterCount + 1);
|
||||
iterEl.textContent = 'Step ' + iterCount;
|
||||
iterEl.style.display = '';
|
||||
}
|
||||
else if (evt.type === 'tool_call_started') {
|
||||
var info = evt.tool_name + '('
|
||||
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
|
||||
addMsg('TOOL ' + info, 'event tool');
|
||||
}
|
||||
else if (evt.type === 'tool_call_completed') {
|
||||
var preview = (evt.result || '').slice(0, 200);
|
||||
var cls = evt.is_error ? 'stall' : 'tool';
|
||||
addMsg('RESULT ' + evt.tool_name + ': ' + preview,
|
||||
'event ' + cls);
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
}
|
||||
else if (evt.type === 'result') {
|
||||
setStatus('Session ended', evt.success ? 'done' : 'error');
|
||||
if (evt.error) addMsg('ERROR ' + evt.error, 'event stall');
|
||||
if (currentAssistantEl && !currentAssistantEl.textContent)
|
||||
currentAssistantEl.remove();
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_stalled') {
|
||||
addMsg('STALLED ' + evt.reason, 'event stall');
|
||||
}
|
||||
else if (evt.type === 'cleared') {
|
||||
chat.innerHTML = '';
|
||||
iterCount = 0;
|
||||
iterEl.textContent = 'Step 0';
|
||||
iterEl.style.display = 'none';
|
||||
setStatus('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
function run() {
|
||||
const text = inputEl.value.trim();
|
||||
if (!text || !ws || ws.readyState !== 1) return;
|
||||
addMsg(text, 'user');
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
inputEl.value = '';
|
||||
setStatus('Running', 'running');
|
||||
goBtn.disabled = true;
|
||||
ws.send(JSON.stringify({ topic: text }));
|
||||
}
|
||||
|
||||
function clearConversation() {
|
||||
if (ws && ws.readyState === 1) {
|
||||
ws.send(JSON.stringify({ command: 'clear' }));
|
||||
}
|
||||
}
|
||||
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# WebSocket handler
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_ws(websocket):
|
||||
"""Persistent WebSocket: long-lived EventLoopNode with client_facing blocking."""
|
||||
global STORE
|
||||
|
||||
# -- Event forwarding (WebSocket ← EventBus) ----------------------------
|
||||
bus = EventBus()
|
||||
|
||||
async def forward_event(event):
|
||||
try:
|
||||
payload = {"type": event.type.value, **event.data}
|
||||
if event.node_id:
|
||||
payload["node_id"] = event.node_id
|
||||
await websocket.send(json.dumps(payload))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[
|
||||
EventType.NODE_LOOP_STARTED,
|
||||
EventType.NODE_LOOP_ITERATION,
|
||||
EventType.NODE_LOOP_COMPLETED,
|
||||
EventType.LLM_TEXT_DELTA,
|
||||
EventType.TOOL_CALL_STARTED,
|
||||
EventType.TOOL_CALL_COMPLETED,
|
||||
EventType.NODE_STALLED,
|
||||
],
|
||||
handler=forward_event,
|
||||
)
|
||||
|
||||
# -- Per-connection state -----------------------------------------------
|
||||
node = None
|
||||
loop_task = None
|
||||
|
||||
tools = list(TOOL_REGISTRY.get_tools().values())
|
||||
tool_executor = TOOL_REGISTRY.get_executor()
|
||||
|
||||
node_spec = NodeSpec(
|
||||
id="assistant",
|
||||
name="Chat Assistant",
|
||||
description="A conversational assistant that remembers context across messages",
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
system_prompt=(
|
||||
"You are a helpful assistant with access to tools. "
|
||||
"You can search the web, scrape webpages, and query HubSpot CRM. "
|
||||
"Use tools when the user asks for current information or external data. "
|
||||
"You have full conversation history, so you can reference previous messages."
|
||||
),
|
||||
)
|
||||
|
||||
# -- Ready callback: subscribe to CLIENT_INPUT_REQUESTED on the bus ---
|
||||
async def on_input_requested(event):
|
||||
try:
|
||||
await websocket.send(json.dumps({"type": "ready"}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.CLIENT_INPUT_REQUESTED],
|
||||
handler=on_input_requested,
|
||||
)
|
||||
|
||||
async def start_loop(first_message: str):
|
||||
"""Create an EventLoopNode and run it as a background task."""
|
||||
nonlocal node, loop_task
|
||||
|
||||
memory = SharedMemory()
|
||||
ctx = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="assistant",
|
||||
node_spec=node_spec,
|
||||
memory=memory,
|
||||
input_data={},
|
||||
llm=LLM,
|
||||
available_tools=tools,
|
||||
)
|
||||
node = EventLoopNode(
|
||||
event_bus=bus,
|
||||
config=LoopConfig(max_iterations=10_000, max_history_tokens=32_000),
|
||||
conversation_store=STORE,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
await node.inject_event(first_message)
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
result = await node.execute(ctx)
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "result",
|
||||
"success": result.success,
|
||||
"output": result.output,
|
||||
"error": result.error,
|
||||
"tokens": result.tokens_used,
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"Loop ended: success={result.success}, tokens={result.tokens_used}")
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("Loop stopped: WebSocket closed")
|
||||
except Exception as e:
|
||||
logger.exception("Loop error")
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "result",
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"output": {},
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
loop_task = asyncio.create_task(_run())
|
||||
|
||||
async def stop_loop():
|
||||
"""Signal the node and wait for the loop task to finish."""
|
||||
nonlocal node, loop_task
|
||||
if loop_task and not loop_task.done():
|
||||
if node:
|
||||
node.signal_shutdown()
|
||||
try:
|
||||
await asyncio.wait_for(loop_task, timeout=5.0)
|
||||
except (TimeoutError, asyncio.CancelledError):
|
||||
loop_task.cancel()
|
||||
node = None
|
||||
loop_task = None
|
||||
|
||||
# -- Message loop (runs for the lifetime of this WebSocket) -------------
|
||||
try:
|
||||
async for raw in websocket:
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Clear command
|
||||
if msg.get("command") == "clear":
|
||||
import shutil
|
||||
|
||||
await stop_loop()
|
||||
await STORE.close()
|
||||
conv_dir = STORE_DIR / "conversation"
|
||||
if conv_dir.exists():
|
||||
shutil.rmtree(conv_dir)
|
||||
STORE = FileConversationStore(conv_dir)
|
||||
await websocket.send(json.dumps({"type": "cleared"}))
|
||||
logger.info("Conversation cleared")
|
||||
continue
|
||||
|
||||
topic = msg.get("topic", "")
|
||||
if not topic:
|
||||
continue
|
||||
|
||||
if node is None:
|
||||
# First message — spin up the loop
|
||||
logger.info(f"Starting persistent loop: {topic}")
|
||||
await start_loop(topic)
|
||||
else:
|
||||
# Subsequent message — inject into the running loop
|
||||
logger.info(f"Injecting message: {topic}")
|
||||
await node.inject_event(topic)
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
await stop_loop()
|
||||
logger.info("WebSocket closed, loop stopped")
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP handler for serving the HTML page
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_request(connection, request: Request):
|
||||
"""Serve HTML on GET /, upgrade to WebSocket on /ws."""
|
||||
if request.path == "/ws":
|
||||
return None # let websockets handle the upgrade
|
||||
# Serve the HTML page for any other path
|
||||
return Response(
|
||||
HTTPStatus.OK,
|
||||
"OK",
|
||||
websockets.Headers({"Content-Type": "text/html; charset=utf-8"}),
|
||||
HTML_PAGE.encode(),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Main
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
|
||||
port = 8765
|
||||
async with websockets.serve(
|
||||
handle_ws,
|
||||
"0.0.0.0",
|
||||
port,
|
||||
process_request=process_request,
|
||||
):
|
||||
logger.info(f"Demo running at http://localhost:{port}")
|
||||
logger.info("Open in your browser and enter a topic to research.")
|
||||
await asyncio.Future() # run forever
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,930 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Two-Node ContextHandoff Demo
|
||||
|
||||
Demonstrates ContextHandoff between two EventLoopNode instances:
|
||||
Node A (Researcher) → ContextHandoff → Node B (Analyst)
|
||||
|
||||
Real LLM, real FileConversationStore, real EventBus.
|
||||
Streams both nodes to a browser via WebSocket.
|
||||
|
||||
Usage:
|
||||
cd /home/timothy/oss/hive/core
|
||||
python demos/handoff_demo.py
|
||||
|
||||
Then open http://localhost:8766 in your browser.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import tempfile
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from bs4 import BeautifulSoup
|
||||
from websockets.http11 import Request, Response
|
||||
|
||||
# Add core, tools, and hive root to path
|
||||
_CORE_DIR = Path(__file__).resolve().parent.parent
|
||||
_HIVE_DIR = _CORE_DIR.parent
|
||||
sys.path.insert(0, str(_CORE_DIR)) # framework.*
|
||||
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
|
||||
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
|
||||
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
|
||||
from core.framework.credentials import CredentialStore # noqa: E402
|
||||
|
||||
from framework.credentials.storage import ( # noqa: E402
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
from framework.graph.context_handoff import ContextHandoff # noqa: E402
|
||||
from framework.graph.conversation import NodeConversation # noqa: E402
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.provider import Tool # noqa: E402
|
||||
from framework.runner.tool_registry import ToolRegistry # noqa: E402
|
||||
from framework.runtime.core import Runtime # noqa: E402
|
||||
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
|
||||
from framework.storage.conversation_store import FileConversationStore # noqa: E402
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
||||
logger = logging.getLogger("handoff_demo")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Persistent state
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_handoff_"))
|
||||
RUNTIME = Runtime(STORE_DIR / "runtime")
|
||||
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Credentials
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Composite credential store: encrypted files (primary) + env vars (fallback)
|
||||
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
|
||||
_composite = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
|
||||
)
|
||||
CREDENTIALS = CredentialStoreAdapter(CredentialStore(storage=_composite))
|
||||
|
||||
for _name in ["brave_search", "hubspot"]:
|
||||
_val = CREDENTIALS.get(_name)
|
||||
if _val:
|
||||
logger.debug("credential %s: OK (len=%d)", _name, len(_val))
|
||||
else:
|
||||
logger.debug("credential %s: not found", _name)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tool Registry — web_search + web_scrape for Node A (Researcher)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
|
||||
def _exec_web_search(inputs: dict) -> dict:
|
||||
api_key = CREDENTIALS.get("brave_search")
|
||||
if not api_key:
|
||||
return {"error": "brave_search credential not configured"}
|
||||
query = inputs.get("query", "")
|
||||
num_results = min(inputs.get("num_results", 10), 20)
|
||||
resp = httpx.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
params={"q": query, "count": num_results},
|
||||
headers={
|
||||
"X-Subscription-Token": api_key,
|
||||
"Accept": "application/json",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"Brave API HTTP {resp.status_code}"}
|
||||
data = resp.json()
|
||||
results = [
|
||||
{
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("description", ""),
|
||||
}
|
||||
for item in data.get("web", {}).get("results", [])[:num_results]
|
||||
]
|
||||
return {"query": query, "results": results, "total": len(results)}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_search",
|
||||
tool=Tool(
|
||||
name="web_search",
|
||||
description=(
|
||||
"Search the web for current information. "
|
||||
"Returns titles, URLs, and snippets from search results."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query (1-500 characters)",
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (1-20, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_search(inputs),
|
||||
)
|
||||
|
||||
_SCRAPE_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "text/html,application/xhtml+xml",
|
||||
}
|
||||
|
||||
|
||||
def _exec_web_scrape(inputs: dict) -> dict:
|
||||
url = inputs.get("url", "")
|
||||
max_length = max(1000, min(inputs.get("max_length", 50000), 500000))
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
try:
|
||||
resp = httpx.get(
|
||||
url,
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
headers=_SCRAPE_HEADERS,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HTTP {resp.status_code}"}
|
||||
soup = BeautifulSoup(resp.text, "html.parser")
|
||||
for tag in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
|
||||
tag.decompose()
|
||||
title = soup.title.get_text(strip=True) if soup.title else ""
|
||||
main = (
|
||||
soup.find("article")
|
||||
or soup.find("main")
|
||||
or soup.find(attrs={"role": "main"})
|
||||
or soup.find("body")
|
||||
)
|
||||
text = main.get_text(separator=" ", strip=True) if main else ""
|
||||
text = " ".join(text.split())
|
||||
if len(text) > max_length:
|
||||
text = text[:max_length] + "..."
|
||||
return {
|
||||
"url": url,
|
||||
"title": title,
|
||||
"content": text,
|
||||
"length": len(text),
|
||||
}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"Scrape failed: {e}"}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_scrape",
|
||||
tool=Tool(
|
||||
name="web_scrape",
|
||||
description=(
|
||||
"Scrape and extract text content from a webpage URL. "
|
||||
"Returns the page title and main text content."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the webpage to scrape",
|
||||
},
|
||||
"max_length": {
|
||||
"type": "integer",
|
||||
"description": "Maximum text length (default 50000)",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_scrape(inputs),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ToolRegistry loaded: %s",
|
||||
", ".join(TOOL_REGISTRY.get_registered_names()),
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Node Specs
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
RESEARCHER_SPEC = NodeSpec(
|
||||
id="researcher",
|
||||
name="Researcher",
|
||||
description="Researches a topic using web search and scraping tools",
|
||||
node_type="event_loop",
|
||||
input_keys=["topic"],
|
||||
output_keys=["research_summary"],
|
||||
system_prompt=(
|
||||
"You are a thorough research assistant. Your job is to research "
|
||||
"the given topic using the web_search and web_scrape tools.\n\n"
|
||||
"1. Search for relevant information on the topic\n"
|
||||
"2. Scrape 1-2 of the most promising URLs for details\n"
|
||||
"3. Synthesize your findings into a comprehensive summary\n"
|
||||
"4. Use set_output with key='research_summary' to save your "
|
||||
"findings\n\n"
|
||||
"Be thorough but efficient. Aim for 2-4 search/scrape calls, "
|
||||
"then summarize and set_output."
|
||||
),
|
||||
)
|
||||
|
||||
ANALYST_SPEC = NodeSpec(
|
||||
id="analyst",
|
||||
name="Analyst",
|
||||
description="Analyzes research findings and provides insights",
|
||||
node_type="event_loop",
|
||||
input_keys=["context"],
|
||||
output_keys=["analysis"],
|
||||
system_prompt=(
|
||||
"You are a strategic analyst. You receive research findings from "
|
||||
"a previous researcher and must:\n\n"
|
||||
"1. Identify key themes and patterns\n"
|
||||
"2. Assess the reliability and significance of the findings\n"
|
||||
"3. Provide actionable insights and recommendations\n"
|
||||
"4. Use set_output with key='analysis' to save your analysis\n\n"
|
||||
"Be concise but insightful. Focus on what matters most."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTML page
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
HTML_PAGE = ( # noqa: E501
|
||||
"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>ContextHandoff Demo</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
body {
|
||||
font-family: 'SF Mono', 'Fira Code', monospace;
|
||||
background: #0d1117;
|
||||
color: #c9d1d9;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
header {
|
||||
background: #161b22;
|
||||
padding: 12px 20px;
|
||||
border-bottom: 1px solid #30363d;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 16px;
|
||||
}
|
||||
header h1 {
|
||||
font-size: 16px;
|
||||
color: #58a6ff;
|
||||
font-weight: 600;
|
||||
}
|
||||
.badge {
|
||||
font-size: 12px;
|
||||
padding: 3px 10px;
|
||||
border-radius: 12px;
|
||||
background: #21262d;
|
||||
color: #8b949e;
|
||||
}
|
||||
.badge.researcher {
|
||||
background: #1a3a5c;
|
||||
color: #58a6ff;
|
||||
}
|
||||
.badge.analyst {
|
||||
background: #1a4b2e;
|
||||
color: #3fb950;
|
||||
}
|
||||
.badge.handoff {
|
||||
background: #3d1f00;
|
||||
color: #d29922;
|
||||
}
|
||||
.badge.done {
|
||||
background: #21262d;
|
||||
color: #8b949e;
|
||||
}
|
||||
.badge.error {
|
||||
background: #4b1a1a;
|
||||
color: #f85149;
|
||||
}
|
||||
.chat {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 16px;
|
||||
}
|
||||
.msg {
|
||||
margin: 8px 0;
|
||||
padding: 10px 14px;
|
||||
border-radius: 8px;
|
||||
line-height: 1.6;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
.msg.user {
|
||||
background: #1a3a5c;
|
||||
color: #58a6ff;
|
||||
}
|
||||
.msg.assistant {
|
||||
background: #161b22;
|
||||
color: #c9d1d9;
|
||||
}
|
||||
.msg.assistant.analyst-msg {
|
||||
border-left: 3px solid #3fb950;
|
||||
}
|
||||
.msg.event {
|
||||
background: transparent;
|
||||
color: #8b949e;
|
||||
font-size: 11px;
|
||||
padding: 4px 14px;
|
||||
border-left: 3px solid #30363d;
|
||||
}
|
||||
.msg.event.loop {
|
||||
border-left-color: #58a6ff;
|
||||
}
|
||||
.msg.event.tool {
|
||||
border-left-color: #d29922;
|
||||
}
|
||||
.msg.event.stall {
|
||||
border-left-color: #f85149;
|
||||
}
|
||||
.handoff-banner {
|
||||
margin: 16px 0;
|
||||
padding: 16px;
|
||||
background: #1c1200;
|
||||
border: 1px solid #d29922;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.handoff-banner h3 {
|
||||
color: #d29922;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.handoff-banner p, .result-banner p {
|
||||
color: #8b949e;
|
||||
font-size: 12px;
|
||||
line-height: 1.5;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap;
|
||||
text-align: left;
|
||||
}
|
||||
.result-banner {
|
||||
margin: 16px 0;
|
||||
padding: 16px;
|
||||
background: #0a2614;
|
||||
border: 1px solid #3fb950;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.result-banner h3 {
|
||||
color: #3fb950;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.result-banner .label {
|
||||
color: #58a6ff;
|
||||
font-size: 11px;
|
||||
font-weight: 600;
|
||||
margin-top: 10px;
|
||||
margin-bottom: 2px;
|
||||
}
|
||||
.result-banner .tokens {
|
||||
color: #484f58;
|
||||
font-size: 11px;
|
||||
text-align: center;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.input-bar {
|
||||
padding: 12px 16px;
|
||||
background: #161b22;
|
||||
border-top: 1px solid #30363d;
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
}
|
||||
.input-bar input {
|
||||
flex: 1;
|
||||
background: #0d1117;
|
||||
border: 1px solid #30363d;
|
||||
color: #c9d1d9;
|
||||
padding: 8px 12px;
|
||||
border-radius: 6px;
|
||||
font-family: inherit;
|
||||
font-size: 14px;
|
||||
outline: none;
|
||||
}
|
||||
.input-bar input:focus {
|
||||
border-color: #58a6ff;
|
||||
}
|
||||
.input-bar button {
|
||||
background: #238636;
|
||||
color: #fff;
|
||||
border: none;
|
||||
padding: 8px 20px;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
font-weight: 600;
|
||||
}
|
||||
.input-bar button:hover {
|
||||
background: #2ea043;
|
||||
}
|
||||
.input-bar button:disabled {
|
||||
background: #21262d;
|
||||
color: #484f58;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>ContextHandoff Demo</h1>
|
||||
<span id="phase" class="badge">Idle</span>
|
||||
<span id="iter" class="badge" style="display:none">Step 0</span>
|
||||
</header>
|
||||
<div id="chat" class="chat"></div>
|
||||
<div class="input-bar">
|
||||
<input id="input" type="text"
|
||||
placeholder="Enter a research topic..." autofocus />
|
||||
<button id="go" onclick="run()">Research</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let currentAssistantEl = null;
|
||||
let iterCount = 0;
|
||||
let currentPhase = 'idle';
|
||||
const chat = document.getElementById('chat');
|
||||
const phase = document.getElementById('phase');
|
||||
const iterEl = document.getElementById('iter');
|
||||
const goBtn = document.getElementById('go');
|
||||
const inputEl = document.getElementById('input');
|
||||
|
||||
inputEl.addEventListener('keydown', e => {
|
||||
if (e.key === 'Enter') run();
|
||||
});
|
||||
|
||||
function setPhase(text, cls) {
|
||||
phase.textContent = text;
|
||||
phase.className = 'badge ' + cls;
|
||||
currentPhase = cls;
|
||||
}
|
||||
|
||||
function addMsg(text, cls) {
|
||||
const el = document.createElement('div');
|
||||
el.className = 'msg ' + cls;
|
||||
el.textContent = text;
|
||||
chat.appendChild(el);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
return el;
|
||||
}
|
||||
|
||||
function addHandoffBanner(summary) {
|
||||
const banner = document.createElement('div');
|
||||
banner.className = 'handoff-banner';
|
||||
const h3 = document.createElement('h3');
|
||||
h3.textContent = 'Context Handoff: Researcher -> Analyst';
|
||||
const p = document.createElement('p');
|
||||
p.textContent = summary || 'Passing research context...';
|
||||
banner.appendChild(h3);
|
||||
banner.appendChild(p);
|
||||
chat.appendChild(banner);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
|
||||
function addResultBanner(researcher, analyst, tokens) {
|
||||
const banner = document.createElement('div');
|
||||
banner.className = 'result-banner';
|
||||
const h3 = document.createElement('h3');
|
||||
h3.textContent = 'Pipeline Complete';
|
||||
banner.appendChild(h3);
|
||||
|
||||
if (researcher && researcher.research_summary) {
|
||||
const lbl = document.createElement('div');
|
||||
lbl.className = 'label';
|
||||
lbl.textContent = 'RESEARCH SUMMARY';
|
||||
banner.appendChild(lbl);
|
||||
const p = document.createElement('p');
|
||||
p.textContent = researcher.research_summary;
|
||||
banner.appendChild(p);
|
||||
}
|
||||
|
||||
if (analyst && analyst.analysis) {
|
||||
const lbl = document.createElement('div');
|
||||
lbl.className = 'label';
|
||||
lbl.textContent = 'ANALYSIS';
|
||||
lbl.style.color = '#3fb950';
|
||||
banner.appendChild(lbl);
|
||||
const p = document.createElement('p');
|
||||
p.textContent = analyst.analysis;
|
||||
banner.appendChild(p);
|
||||
}
|
||||
|
||||
if (tokens) {
|
||||
const t = document.createElement('div');
|
||||
t.className = 'tokens';
|
||||
t.textContent = 'Total tokens: ' + tokens.toLocaleString();
|
||||
banner.appendChild(t);
|
||||
}
|
||||
|
||||
chat.appendChild(banner);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
|
||||
function connect() {
|
||||
ws = new WebSocket('ws://' + location.host + '/ws');
|
||||
ws.onopen = () => {
|
||||
setPhase('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
};
|
||||
ws.onmessage = handleEvent;
|
||||
ws.onerror = () => { setPhase('Error', 'error'); };
|
||||
ws.onclose = () => {
|
||||
setPhase('Reconnecting...', '');
|
||||
goBtn.disabled = true;
|
||||
setTimeout(connect, 2000);
|
||||
};
|
||||
}
|
||||
|
||||
function handleEvent(msg) {
|
||||
const evt = JSON.parse(msg.data);
|
||||
|
||||
if (evt.type === 'phase') {
|
||||
if (evt.phase === 'researcher') {
|
||||
setPhase('Researcher', 'researcher');
|
||||
} else if (evt.phase === 'handoff') {
|
||||
setPhase('Handoff', 'handoff');
|
||||
} else if (evt.phase === 'analyst') {
|
||||
setPhase('Analyst', 'analyst');
|
||||
}
|
||||
iterCount = 0;
|
||||
iterEl.style.display = 'none';
|
||||
}
|
||||
else if (evt.type === 'llm_text_delta') {
|
||||
if (currentAssistantEl) {
|
||||
currentAssistantEl.textContent += evt.content;
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'node_loop_iteration') {
|
||||
iterCount = evt.iteration || (iterCount + 1);
|
||||
iterEl.textContent = 'Step ' + iterCount;
|
||||
iterEl.style.display = '';
|
||||
}
|
||||
else if (evt.type === 'tool_call_started') {
|
||||
var info = evt.tool_name + '('
|
||||
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
|
||||
addMsg('TOOL ' + info, 'event tool');
|
||||
}
|
||||
else if (evt.type === 'tool_call_completed') {
|
||||
var preview = (evt.result || '').slice(0, 200);
|
||||
var cls = evt.is_error ? 'stall' : 'tool';
|
||||
addMsg(
|
||||
'RESULT ' + evt.tool_name + ': ' + preview,
|
||||
'event ' + cls
|
||||
);
|
||||
var assistCls = currentPhase === 'analyst'
|
||||
? 'assistant analyst-msg' : 'assistant';
|
||||
currentAssistantEl = addMsg('', assistCls);
|
||||
}
|
||||
else if (evt.type === 'handoff_context') {
|
||||
addHandoffBanner(evt.summary);
|
||||
var assistCls = 'assistant analyst-msg';
|
||||
currentAssistantEl = addMsg('', assistCls);
|
||||
}
|
||||
else if (evt.type === 'node_result') {
|
||||
if (evt.node_id === 'researcher') {
|
||||
if (currentAssistantEl
|
||||
&& !currentAssistantEl.textContent) {
|
||||
currentAssistantEl.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'done') {
|
||||
setPhase('Done', 'done');
|
||||
iterEl.style.display = 'none';
|
||||
if (currentAssistantEl
|
||||
&& !currentAssistantEl.textContent) {
|
||||
currentAssistantEl.remove();
|
||||
}
|
||||
currentAssistantEl = null;
|
||||
addResultBanner(
|
||||
evt.researcher, evt.analyst, evt.total_tokens
|
||||
);
|
||||
goBtn.disabled = false;
|
||||
inputEl.placeholder = 'Enter another topic...';
|
||||
}
|
||||
else if (evt.type === 'error') {
|
||||
setPhase('Error', 'error');
|
||||
addMsg('ERROR ' + evt.message, 'event stall');
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_stalled') {
|
||||
addMsg('STALLED ' + evt.reason, 'event stall');
|
||||
}
|
||||
}
|
||||
|
||||
function run() {
|
||||
const text = inputEl.value.trim();
|
||||
if (!text || !ws || ws.readyState !== 1) return;
|
||||
chat.innerHTML = '';
|
||||
addMsg(text, 'user');
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
inputEl.value = '';
|
||||
goBtn.disabled = true;
|
||||
ws.send(JSON.stringify({ topic: text }));
|
||||
}
|
||||
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# WebSocket handler — sequential Node A → Handoff → Node B
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_ws(websocket):
|
||||
"""Run the two-node handoff pipeline per user message."""
|
||||
try:
|
||||
async for raw in websocket:
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
topic = msg.get("topic", "")
|
||||
if not topic:
|
||||
continue
|
||||
|
||||
logger.info(f"Starting handoff pipeline for: {topic}")
|
||||
|
||||
try:
|
||||
await _run_pipeline(websocket, topic)
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("WebSocket closed during pipeline")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Pipeline error")
|
||||
try:
|
||||
await websocket.send(json.dumps({"type": "error", "message": str(e)}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
|
||||
|
||||
async def _run_pipeline(websocket, topic: str):
|
||||
"""Execute: Node A (research) → ContextHandoff → Node B (analysis)."""
|
||||
import shutil
|
||||
|
||||
# Fresh stores for each run
|
||||
run_dir = Path(tempfile.mkdtemp(prefix="hive_run_", dir=STORE_DIR))
|
||||
store_a = FileConversationStore(run_dir / "node_a")
|
||||
store_b = FileConversationStore(run_dir / "node_b")
|
||||
|
||||
# Shared event bus
|
||||
bus = EventBus()
|
||||
|
||||
async def forward_event(event):
|
||||
try:
|
||||
payload = {"type": event.type.value, **event.data}
|
||||
if event.node_id:
|
||||
payload["node_id"] = event.node_id
|
||||
await websocket.send(json.dumps(payload))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[
|
||||
EventType.NODE_LOOP_STARTED,
|
||||
EventType.NODE_LOOP_ITERATION,
|
||||
EventType.NODE_LOOP_COMPLETED,
|
||||
EventType.LLM_TEXT_DELTA,
|
||||
EventType.TOOL_CALL_STARTED,
|
||||
EventType.TOOL_CALL_COMPLETED,
|
||||
EventType.NODE_STALLED,
|
||||
],
|
||||
handler=forward_event,
|
||||
)
|
||||
|
||||
tools = list(TOOL_REGISTRY.get_tools().values())
|
||||
tool_executor = TOOL_REGISTRY.get_executor()
|
||||
|
||||
# ---- Phase 1: Researcher ------------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "researcher"}))
|
||||
|
||||
node_a = EventLoopNode(
|
||||
event_bus=bus,
|
||||
judge=None, # implicit judge: accept when output_keys filled
|
||||
config=LoopConfig(
|
||||
max_iterations=20,
|
||||
max_tool_calls_per_turn=10,
|
||||
max_history_tokens=32_000,
|
||||
),
|
||||
conversation_store=store_a,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
|
||||
ctx_a = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="researcher",
|
||||
node_spec=RESEARCHER_SPEC,
|
||||
memory=SharedMemory(),
|
||||
input_data={"topic": topic},
|
||||
llm=LLM,
|
||||
available_tools=tools,
|
||||
)
|
||||
|
||||
result_a = await node_a.execute(ctx_a)
|
||||
logger.info(
|
||||
"Researcher done: success=%s, tokens=%s",
|
||||
result_a.success,
|
||||
result_a.tokens_used,
|
||||
)
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "node_result",
|
||||
"node_id": "researcher",
|
||||
"success": result_a.success,
|
||||
"output": result_a.output,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if not result_a.success:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": f"Researcher failed: {result_a.error}",
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# ---- Phase 2: Context Handoff -------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "handoff"}))
|
||||
|
||||
# Restore the researcher's conversation from store
|
||||
conversation_a = await NodeConversation.restore(store_a)
|
||||
if conversation_a is None:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Failed to restore researcher conversation",
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
handoff_engine = ContextHandoff(llm=LLM)
|
||||
handoff_context = handoff_engine.summarize_conversation(
|
||||
conversation=conversation_a,
|
||||
node_id="researcher",
|
||||
output_keys=["research_summary"],
|
||||
)
|
||||
|
||||
formatted_handoff = ContextHandoff.format_as_input(handoff_context)
|
||||
logger.info(
|
||||
"Handoff: %d turns, ~%d tokens, keys=%s",
|
||||
handoff_context.turn_count,
|
||||
handoff_context.total_tokens_used,
|
||||
list(handoff_context.key_outputs.keys()),
|
||||
)
|
||||
|
||||
# Send handoff context to browser
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "handoff_context",
|
||||
"summary": handoff_context.summary[:500],
|
||||
"turn_count": handoff_context.turn_count,
|
||||
"tokens": handoff_context.total_tokens_used,
|
||||
"key_outputs": handoff_context.key_outputs,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# ---- Phase 3: Analyst ---------------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "analyst"}))
|
||||
|
||||
node_b = EventLoopNode(
|
||||
event_bus=bus,
|
||||
judge=None, # implicit judge
|
||||
config=LoopConfig(
|
||||
max_iterations=10,
|
||||
max_tool_calls_per_turn=5,
|
||||
max_history_tokens=32_000,
|
||||
),
|
||||
conversation_store=store_b,
|
||||
)
|
||||
|
||||
ctx_b = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="analyst",
|
||||
node_spec=ANALYST_SPEC,
|
||||
memory=SharedMemory(),
|
||||
input_data={"context": formatted_handoff},
|
||||
llm=LLM,
|
||||
available_tools=[],
|
||||
)
|
||||
|
||||
result_b = await node_b.execute(ctx_b)
|
||||
logger.info(
|
||||
"Analyst done: success=%s, tokens=%s",
|
||||
result_b.success,
|
||||
result_b.tokens_used,
|
||||
)
|
||||
|
||||
# ---- Done ---------------------------------------------------------------
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "done",
|
||||
"researcher": result_a.output,
|
||||
"analyst": result_b.output,
|
||||
"total_tokens": ((result_a.tokens_used or 0) + (result_b.tokens_used or 0)),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Clean up temp stores
|
||||
try:
|
||||
shutil.rmtree(run_dir)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP handler
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_request(connection, request: Request):
|
||||
"""Serve HTML on GET /, upgrade to WebSocket on /ws."""
|
||||
if request.path == "/ws":
|
||||
return None
|
||||
return Response(
|
||||
HTTPStatus.OK,
|
||||
"OK",
|
||||
websockets.Headers({"Content-Type": "text/html; charset=utf-8"}),
|
||||
HTML_PAGE.encode(),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Main
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
|
||||
port = 8766
|
||||
async with websockets.serve(
|
||||
handle_ws,
|
||||
"0.0.0.0",
|
||||
port,
|
||||
process_request=process_request,
|
||||
):
|
||||
logger.info(f"Handoff demo at http://localhost:{port}")
|
||||
logger.info("Enter a research topic to start the pipeline.")
|
||||
await asyncio.Future()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,7 +15,7 @@ You cannot skip steps or bypass validation.
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -26,7 +26,7 @@ from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeSpec
|
||||
|
||||
|
||||
class BuildPhase(str, Enum):
|
||||
class BuildPhase(StrEnum):
|
||||
"""Current phase of the build process."""
|
||||
|
||||
INIT = "init" # Just started
|
||||
|
||||
@@ -64,6 +64,8 @@ class AdenCachedStorage(CredentialStorage):
|
||||
- **Reads**: Try local cache first, fallback to Aden if stale/missing
|
||||
- **Writes**: Always write to local cache
|
||||
- **Offline resilience**: Uses cached credentials when Aden is unreachable
|
||||
- **Provider-based lookup**: Match credentials by provider name (e.g., "hubspot")
|
||||
when direct ID lookup fails, since Aden uses hash-based IDs internally.
|
||||
|
||||
The cache TTL determines how long to trust local credentials before
|
||||
checking with the Aden server for updates. This balances:
|
||||
@@ -85,6 +87,7 @@ class AdenCachedStorage(CredentialStorage):
|
||||
|
||||
# First access fetches from Aden
|
||||
# Subsequent accesses use cache until TTL expires
|
||||
# Can look up by provider name OR credential ID
|
||||
token = store.get_key("hubspot", "access_token")
|
||||
"""
|
||||
|
||||
@@ -111,21 +114,24 @@ class AdenCachedStorage(CredentialStorage):
|
||||
self._cache_ttl = timedelta(seconds=cache_ttl_seconds)
|
||||
self._prefer_local = prefer_local
|
||||
self._cache_timestamps: dict[str, datetime] = {}
|
||||
# Index: provider name (e.g., "hubspot") -> credential hash ID
|
||||
self._provider_index: dict[str, str] = {}
|
||||
|
||||
def save(self, credential: CredentialObject) -> None:
|
||||
"""
|
||||
Save credential to local cache.
|
||||
Save credential to local cache and update provider index.
|
||||
|
||||
Args:
|
||||
credential: The credential to save.
|
||||
"""
|
||||
self._local.save(credential)
|
||||
self._cache_timestamps[credential.id] = datetime.now(UTC)
|
||||
self._index_provider(credential)
|
||||
logger.debug(f"Cached credential '{credential.id}'")
|
||||
|
||||
def load(self, credential_id: str) -> CredentialObject | None:
|
||||
"""
|
||||
Load credential from cache, with Aden fallback.
|
||||
Load credential from cache, with Aden fallback and provider-based lookup.
|
||||
|
||||
The loading strategy depends on the `prefer_local` setting:
|
||||
|
||||
@@ -141,8 +147,37 @@ class AdenCachedStorage(CredentialStorage):
|
||||
2. Update local cache with response
|
||||
3. Fall back to local cache only if Aden fails
|
||||
|
||||
Provider-based lookup:
|
||||
When a provider index mapping exists for the credential_id (e.g.,
|
||||
"hubspot" → hash ID), the Aden-synced credential is loaded first.
|
||||
This ensures fresh OAuth tokens from Aden take priority over stale
|
||||
local credentials (env vars, old encrypted files).
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier.
|
||||
credential_id: The credential identifier or provider name.
|
||||
|
||||
Returns:
|
||||
CredentialObject if found, None otherwise.
|
||||
"""
|
||||
# Check provider index first — Aden-synced credentials take priority
|
||||
resolved_id = self._provider_index.get(credential_id)
|
||||
if resolved_id and resolved_id != credential_id:
|
||||
result = self._load_by_id(resolved_id)
|
||||
if result is not None:
|
||||
logger.info(
|
||||
f"Loaded credential '{credential_id}' via provider index (id='{resolved_id}')"
|
||||
)
|
||||
return result
|
||||
|
||||
# Direct lookup (exact credential_id match)
|
||||
return self._load_by_id(credential_id)
|
||||
|
||||
def _load_by_id(self, credential_id: str) -> CredentialObject | None:
|
||||
"""
|
||||
Load credential by exact ID from cache, with Aden fallback.
|
||||
|
||||
Args:
|
||||
credential_id: The exact credential identifier.
|
||||
|
||||
Returns:
|
||||
CredentialObject if found, None otherwise.
|
||||
@@ -200,15 +235,21 @@ class AdenCachedStorage(CredentialStorage):
|
||||
|
||||
def exists(self, credential_id: str) -> bool:
|
||||
"""
|
||||
Check if credential exists in local cache.
|
||||
Check if credential exists in local cache (by ID or provider name).
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier.
|
||||
credential_id: The credential identifier or provider name.
|
||||
|
||||
Returns:
|
||||
True if credential exists locally.
|
||||
"""
|
||||
return self._local.exists(credential_id)
|
||||
if self._local.exists(credential_id):
|
||||
return True
|
||||
# Check provider index
|
||||
resolved_id = self._provider_index.get(credential_id)
|
||||
if resolved_id and resolved_id != credential_id:
|
||||
return self._local.exists(resolved_id)
|
||||
return False
|
||||
|
||||
def _is_cache_fresh(self, credential_id: str) -> bool:
|
||||
"""
|
||||
@@ -242,6 +283,47 @@ class AdenCachedStorage(CredentialStorage):
|
||||
self._cache_timestamps.clear()
|
||||
logger.debug("Invalidated all cache entries")
|
||||
|
||||
def _index_provider(self, credential: CredentialObject) -> None:
|
||||
"""
|
||||
Index a credential by its provider/integration type.
|
||||
|
||||
Aden credentials carry an ``_integration_type`` key whose value is
|
||||
the provider name (e.g., ``hubspot``). This method maps that
|
||||
provider name to the credential's hash ID so that subsequent
|
||||
``load("hubspot")`` calls resolve to the correct credential.
|
||||
|
||||
Args:
|
||||
credential: The credential to index.
|
||||
"""
|
||||
integration_type_key = credential.keys.get("_integration_type")
|
||||
if integration_type_key is None:
|
||||
return
|
||||
provider_name = integration_type_key.value.get_secret_value()
|
||||
if provider_name:
|
||||
self._provider_index[provider_name] = credential.id
|
||||
logger.debug(f"Indexed provider '{provider_name}' -> '{credential.id}'")
|
||||
|
||||
def rebuild_provider_index(self) -> int:
|
||||
"""
|
||||
Rebuild the provider index from all locally cached credentials.
|
||||
|
||||
Useful after loading from disk when the in-memory index is empty.
|
||||
|
||||
Returns:
|
||||
Number of provider mappings indexed.
|
||||
"""
|
||||
self._provider_index.clear()
|
||||
indexed = 0
|
||||
for cred_id in self._local.list_all():
|
||||
cred = self._local.load(cred_id)
|
||||
if cred:
|
||||
before = len(self._provider_index)
|
||||
self._index_provider(cred)
|
||||
if len(self._provider_index) > before:
|
||||
indexed += 1
|
||||
logger.debug(f"Rebuilt provider index with {indexed} mappings")
|
||||
return indexed
|
||||
|
||||
def sync_all_from_aden(self) -> int:
|
||||
"""
|
||||
Sync all credentials from Aden server to local cache.
|
||||
|
||||
@@ -589,6 +589,149 @@ class TestAdenCachedStorage:
|
||||
assert info["stale"]["is_fresh"] is False
|
||||
assert info["stale"]["ttl_remaining_seconds"] == 0
|
||||
|
||||
def test_save_indexes_provider(self, cached_storage):
|
||||
"""Test save builds the provider index from _integration_type key."""
|
||||
cred = CredentialObject(
|
||||
id="aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1",
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"access_token": CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr("token-value"),
|
||||
),
|
||||
"_integration_type": CredentialKey(
|
||||
name="_integration_type",
|
||||
value=SecretStr("hubspot"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cached_storage.save(cred)
|
||||
|
||||
assert cached_storage._provider_index["hubspot"] == "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
|
||||
|
||||
def test_load_by_provider_name(self, cached_storage):
|
||||
"""Test load resolves provider name to hash-based credential ID."""
|
||||
hash_id = "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
|
||||
cred = CredentialObject(
|
||||
id=hash_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"access_token": CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr("hubspot-token"),
|
||||
),
|
||||
"_integration_type": CredentialKey(
|
||||
name="_integration_type",
|
||||
value=SecretStr("hubspot"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
# Save builds the index
|
||||
cached_storage.save(cred)
|
||||
|
||||
# Load by provider name should resolve to the hash ID
|
||||
loaded = cached_storage.load("hubspot")
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.id == hash_id
|
||||
assert loaded.keys["access_token"].value.get_secret_value() == "hubspot-token"
|
||||
|
||||
def test_load_by_direct_id_still_works(self, cached_storage):
|
||||
"""Test load by direct hash ID still works as before."""
|
||||
hash_id = "aHVic3BvdDp0ZXN0OjEzNjExOjExNTI1"
|
||||
cred = CredentialObject(
|
||||
id=hash_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"access_token": CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr("token"),
|
||||
),
|
||||
"_integration_type": CredentialKey(
|
||||
name="_integration_type",
|
||||
value=SecretStr("hubspot"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cached_storage.save(cred)
|
||||
|
||||
# Direct ID lookup should still work
|
||||
loaded = cached_storage.load(hash_id)
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.id == hash_id
|
||||
|
||||
def test_exists_by_provider_name(self, cached_storage):
|
||||
"""Test exists resolves provider name to hash-based credential ID."""
|
||||
hash_id = "c2xhY2s6dGVzdDo5OTk="
|
||||
cred = CredentialObject(
|
||||
id=hash_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"access_token": CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr("slack-token"),
|
||||
),
|
||||
"_integration_type": CredentialKey(
|
||||
name="_integration_type",
|
||||
value=SecretStr("slack"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cached_storage.save(cred)
|
||||
|
||||
assert cached_storage.exists("slack") is True
|
||||
assert cached_storage.exists(hash_id) is True
|
||||
assert cached_storage.exists("nonexistent") is False
|
||||
|
||||
def test_rebuild_provider_index(self, cached_storage, local_storage):
|
||||
"""Test rebuild_provider_index reconstructs from local storage."""
|
||||
# Manually save credentials to local storage (bypassing cached_storage.save)
|
||||
for provider_name, hash_id in [("hubspot", "hash_hub"), ("slack", "hash_slack")]:
|
||||
cred = CredentialObject(
|
||||
id=hash_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"_integration_type": CredentialKey(
|
||||
name="_integration_type",
|
||||
value=SecretStr(provider_name),
|
||||
),
|
||||
},
|
||||
)
|
||||
local_storage.save(cred)
|
||||
|
||||
# Index should be empty (we bypassed save)
|
||||
assert len(cached_storage._provider_index) == 0
|
||||
|
||||
# Rebuild
|
||||
indexed = cached_storage.rebuild_provider_index()
|
||||
|
||||
assert indexed == 2
|
||||
assert cached_storage._provider_index["hubspot"] == "hash_hub"
|
||||
assert cached_storage._provider_index["slack"] == "hash_slack"
|
||||
|
||||
def test_save_without_integration_type_no_index(self, cached_storage):
|
||||
"""Test save does not index credentials without _integration_type key."""
|
||||
cred = CredentialObject(
|
||||
id="plain-cred",
|
||||
credential_type=CredentialType.API_KEY,
|
||||
keys={
|
||||
"api_key": CredentialKey(
|
||||
name="api_key",
|
||||
value=SecretStr("key-value"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cached_storage.save(cred)
|
||||
|
||||
assert "plain-cred" not in cached_storage._provider_index
|
||||
assert len(cached_storage._provider_index) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests
|
||||
|
||||
@@ -8,7 +8,7 @@ containing one or more keys (e.g., api_key, access_token, refresh_token).
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
@@ -19,7 +19,7 @@ def _utc_now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class CredentialType(str, Enum):
|
||||
class CredentialType(StrEnum):
|
||||
"""Types of credentials the store can manage."""
|
||||
|
||||
API_KEY = "api_key"
|
||||
|
||||
@@ -11,11 +11,11 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class TokenPlacement(str, Enum):
|
||||
class TokenPlacement(StrEnum):
|
||||
"""Where to place the access token in HTTP requests."""
|
||||
|
||||
HEADER_BEARER = "header_bearer"
|
||||
|
||||
@@ -1,7 +1,22 @@
|
||||
"""Graph structures: Goals, Nodes, Edges, and Flexible Execution."""
|
||||
|
||||
from framework.graph.client_io import (
|
||||
ActiveNodeClientIO,
|
||||
ClientIOGateway,
|
||||
InertNodeClientIO,
|
||||
NodeClientIO,
|
||||
)
|
||||
from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec
|
||||
from framework.graph.context_handoff import ContextHandoff, HandoffContext
|
||||
from framework.graph.conversation import ConversationStore, Message, NodeConversation
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.event_loop_node import (
|
||||
EventLoopNode,
|
||||
JudgeProtocol,
|
||||
JudgeVerdict,
|
||||
LoopConfig,
|
||||
OutputAccumulator,
|
||||
)
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.flexible_executor import ExecutorConfig, FlexibleGraphExecutor
|
||||
from framework.graph.goal import Constraint, Goal, GoalStatus, SuccessCriterion
|
||||
@@ -72,4 +87,22 @@ __all__ = [
|
||||
"CodeSandbox",
|
||||
"safe_exec",
|
||||
"safe_eval",
|
||||
# Conversation
|
||||
"NodeConversation",
|
||||
"ConversationStore",
|
||||
"Message",
|
||||
# Event Loop
|
||||
"EventLoopNode",
|
||||
"LoopConfig",
|
||||
"OutputAccumulator",
|
||||
"JudgeProtocol",
|
||||
"JudgeVerdict",
|
||||
# Context Handoff
|
||||
"ContextHandoff",
|
||||
"HandoffContext",
|
||||
# Client I/O
|
||||
"NodeClientIO",
|
||||
"ActiveNodeClientIO",
|
||||
"InertNodeClientIO",
|
||||
"ClientIOGateway",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Client I/O gateway for graph nodes.
|
||||
|
||||
Provides the bridge between node code and external clients:
|
||||
- ActiveNodeClientIO: for client_facing=True nodes (streams output, accepts input)
|
||||
- InertNodeClientIO: for client_facing=False nodes (logs internally, redirects input)
|
||||
- ClientIOGateway: factory that creates the right variant per node
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeClientIO(ABC):
|
||||
"""Abstract base for node client I/O."""
|
||||
|
||||
@abstractmethod
|
||||
async def emit_output(self, content: str, is_final: bool = False) -> None:
|
||||
"""Emit output content. If is_final=True, signal end of stream."""
|
||||
|
||||
@abstractmethod
|
||||
async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
|
||||
"""Request input. Behavior depends on whether the node is client-facing."""
|
||||
|
||||
|
||||
class ActiveNodeClientIO(NodeClientIO):
|
||||
"""
|
||||
Client I/O for client_facing=True nodes.
|
||||
|
||||
- emit_output() queues content and publishes CLIENT_OUTPUT_DELTA.
|
||||
- request_input() publishes CLIENT_INPUT_REQUESTED, then awaits provide_input().
|
||||
- output_stream() yields queued content until the final sentinel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
event_bus: EventBus | None = None,
|
||||
) -> None:
|
||||
self.node_id = node_id
|
||||
self._event_bus = event_bus
|
||||
|
||||
self._output_queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
self._output_snapshot = ""
|
||||
|
||||
self._input_event: asyncio.Event | None = None
|
||||
self._input_result: str | None = None
|
||||
|
||||
async def emit_output(self, content: str, is_final: bool = False) -> None:
|
||||
self._output_snapshot += content
|
||||
await self._output_queue.put(content)
|
||||
|
||||
if self._event_bus is not None:
|
||||
await self._event_bus.emit_client_output_delta(
|
||||
stream_id=self.node_id,
|
||||
node_id=self.node_id,
|
||||
content=content,
|
||||
snapshot=self._output_snapshot,
|
||||
)
|
||||
|
||||
if is_final:
|
||||
await self._output_queue.put(None)
|
||||
|
||||
async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
|
||||
if self._input_event is not None:
|
||||
raise RuntimeError("request_input already pending for this node")
|
||||
|
||||
self._input_event = asyncio.Event()
|
||||
self._input_result = None
|
||||
|
||||
if self._event_bus is not None:
|
||||
await self._event_bus.emit_client_input_requested(
|
||||
stream_id=self.node_id,
|
||||
node_id=self.node_id,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
try:
|
||||
if timeout is not None:
|
||||
await asyncio.wait_for(self._input_event.wait(), timeout=timeout)
|
||||
else:
|
||||
await self._input_event.wait()
|
||||
finally:
|
||||
self._input_event = None
|
||||
|
||||
if self._input_result is None:
|
||||
raise RuntimeError("input event was set but no input was provided")
|
||||
result = self._input_result
|
||||
self._input_result = None
|
||||
return result
|
||||
|
||||
async def provide_input(self, content: str) -> None:
|
||||
"""Called externally to fulfill a pending request_input()."""
|
||||
if self._input_event is None:
|
||||
raise RuntimeError("no pending request_input to fulfill")
|
||||
self._input_result = content
|
||||
self._input_event.set()
|
||||
|
||||
async def output_stream(self) -> AsyncIterator[str]:
|
||||
"""Async iterator that yields output chunks until the final sentinel."""
|
||||
while True:
|
||||
chunk = await self._output_queue.get()
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
|
||||
class InertNodeClientIO(NodeClientIO):
|
||||
"""
|
||||
Client I/O for client_facing=False nodes.
|
||||
|
||||
- emit_output() publishes NODE_INTERNAL_OUTPUT (content is not discarded).
|
||||
- request_input() publishes NODE_INPUT_BLOCKED and returns a redirect string.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
event_bus: EventBus | None = None,
|
||||
) -> None:
|
||||
self.node_id = node_id
|
||||
self._event_bus = event_bus
|
||||
|
||||
async def emit_output(self, content: str, is_final: bool = False) -> None:
|
||||
if self._event_bus is not None:
|
||||
await self._event_bus.emit_node_internal_output(
|
||||
stream_id=self.node_id,
|
||||
node_id=self.node_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
async def request_input(self, prompt: str = "", timeout: float | None = None) -> str:
|
||||
if self._event_bus is not None:
|
||||
await self._event_bus.emit_node_input_blocked(
|
||||
stream_id=self.node_id,
|
||||
node_id=self.node_id,
|
||||
prompt=prompt,
|
||||
)
|
||||
return (
|
||||
"You are an internal processing node. There is no user to interact with."
|
||||
" Work with the data provided in your inputs to complete your task."
|
||||
)
|
||||
|
||||
|
||||
class ClientIOGateway:
|
||||
"""Factory that creates the appropriate NodeClientIO for a node."""
|
||||
|
||||
def __init__(self, event_bus: EventBus | None = None) -> None:
|
||||
self._event_bus = event_bus
|
||||
|
||||
def create_io(self, node_id: str, client_facing: bool) -> NodeClientIO:
|
||||
if client_facing:
|
||||
return ActiveNodeClientIO(
|
||||
node_id=node_id,
|
||||
event_bus=self._event_bus,
|
||||
)
|
||||
return InertNodeClientIO(
|
||||
node_id=node_id,
|
||||
event_bus=self._event_bus,
|
||||
)
|
||||
@@ -0,0 +1,191 @@
|
||||
"""Context handoff: summarize a completed NodeConversation for the next graph node."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.graph.conversation import _try_extract_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.llm.provider import LLMProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TRUNCATE_CHARS = 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class HandoffContext:
|
||||
"""Structured summary of a completed node conversation."""
|
||||
|
||||
source_node_id: str
|
||||
summary: str
|
||||
key_outputs: dict[str, Any]
|
||||
turn_count: int
|
||||
total_tokens_used: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ContextHandoff
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ContextHandoff:
|
||||
"""Summarize a completed NodeConversation into a HandoffContext.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
llm : LLMProvider | None
|
||||
Optional LLM provider for abstractive summarization.
|
||||
When *None*, all summarization uses the extractive fallback.
|
||||
"""
|
||||
|
||||
def __init__(self, llm: LLMProvider | None = None) -> None:
|
||||
self.llm = llm
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def summarize_conversation(
|
||||
self,
|
||||
conversation: NodeConversation,
|
||||
node_id: str,
|
||||
output_keys: list[str] | None = None,
|
||||
) -> HandoffContext:
|
||||
"""Produce a HandoffContext from *conversation*.
|
||||
|
||||
1. Extracts turn_count & total_tokens_used (sync properties).
|
||||
2. Extracts key_outputs by scanning assistant messages most-recent-first.
|
||||
3. Builds a summary via the LLM (if available) or extractive fallback.
|
||||
"""
|
||||
turn_count = conversation.turn_count
|
||||
total_tokens_used = conversation.estimate_tokens()
|
||||
messages = conversation.messages # defensive copy
|
||||
|
||||
# --- key outputs ---------------------------------------------------
|
||||
key_outputs: dict[str, Any] = {}
|
||||
if output_keys:
|
||||
remaining = set(output_keys)
|
||||
for msg in reversed(messages):
|
||||
if msg.role != "assistant" or not remaining:
|
||||
continue
|
||||
for key in list(remaining):
|
||||
value = _try_extract_key(msg.content, key)
|
||||
if value is not None:
|
||||
key_outputs[key] = value
|
||||
remaining.discard(key)
|
||||
|
||||
# --- summary -------------------------------------------------------
|
||||
if self.llm is not None:
|
||||
try:
|
||||
summary = self._llm_summary(messages, output_keys or [])
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"LLM summarization failed; falling back to extractive.",
|
||||
exc_info=True,
|
||||
)
|
||||
summary = self._extractive_summary(messages)
|
||||
else:
|
||||
summary = self._extractive_summary(messages)
|
||||
|
||||
return HandoffContext(
|
||||
source_node_id=node_id,
|
||||
summary=summary,
|
||||
key_outputs=key_outputs,
|
||||
turn_count=turn_count,
|
||||
total_tokens_used=total_tokens_used,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_as_input(handoff: HandoffContext) -> str:
|
||||
"""Render *handoff* as structured plain text for the next node's input."""
|
||||
header = (
|
||||
f"--- CONTEXT FROM: {handoff.source_node_id} "
|
||||
f"({handoff.turn_count} turns, ~{handoff.total_tokens_used} tokens) ---"
|
||||
)
|
||||
|
||||
sections: list[str] = [header, ""]
|
||||
|
||||
if handoff.key_outputs:
|
||||
sections.append("KEY OUTPUTS:")
|
||||
for k, v in handoff.key_outputs.items():
|
||||
sections.append(f"- {k}: {v}")
|
||||
sections.append("")
|
||||
|
||||
summary_text = handoff.summary or "No summary available."
|
||||
sections.append("SUMMARY:")
|
||||
sections.append(summary_text)
|
||||
sections.append("")
|
||||
sections.append("--- END CONTEXT ---")
|
||||
|
||||
return "\n".join(sections)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _extractive_summary(messages: list) -> str:
|
||||
"""Build a summary from key assistant messages without an LLM.
|
||||
|
||||
Strategy:
|
||||
- Include the first assistant message (initial assessment).
|
||||
- Include the last assistant message (final conclusion).
|
||||
- Truncate each to ~500 chars.
|
||||
"""
|
||||
if not messages:
|
||||
return "Empty conversation."
|
||||
|
||||
assistant_msgs = [m for m in messages if m.role == "assistant"]
|
||||
if not assistant_msgs:
|
||||
return "No assistant responses."
|
||||
|
||||
parts: list[str] = []
|
||||
|
||||
first = assistant_msgs[0].content
|
||||
parts.append(first[:_TRUNCATE_CHARS])
|
||||
|
||||
if len(assistant_msgs) > 1:
|
||||
last = assistant_msgs[-1].content
|
||||
parts.append(last[:_TRUNCATE_CHARS])
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _llm_summary(self, messages: list, output_keys: list[str]) -> str:
|
||||
"""Produce a summary by calling the LLM provider."""
|
||||
if self.llm is None:
|
||||
raise ValueError("_llm_summary called without an LLM provider")
|
||||
|
||||
conversation_text = "\n".join(f"[{m.role}]: {m.content}" for m in messages)
|
||||
|
||||
key_hint = ""
|
||||
if output_keys:
|
||||
key_hint = (
|
||||
"\nThe following output keys are especially important: "
|
||||
+ ", ".join(output_keys)
|
||||
+ ".\n"
|
||||
)
|
||||
|
||||
system_prompt = (
|
||||
"You are a concise summarizer. Given the conversation below, "
|
||||
"produce a brief summary (at most ~500 tokens) that captures the "
|
||||
"key decisions, findings, and outcomes. Focus on what was concluded "
|
||||
"rather than the back-and-forth process." + key_hint
|
||||
)
|
||||
|
||||
response = self.llm.complete(
|
||||
messages=[{"role": "user", "content": conversation_text}],
|
||||
system=system_prompt,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
return response.content.strip()
|
||||
@@ -0,0 +1,600 @@
|
||||
"""NodeConversation: Message history management for graph nodes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""A single message in a conversation.
|
||||
|
||||
Attributes:
|
||||
seq: Monotonic sequence number.
|
||||
role: One of "user", "assistant", or "tool".
|
||||
content: Message text.
|
||||
tool_use_id: Internal tool-use identifier (output as ``tool_call_id`` in LLM dicts).
|
||||
tool_calls: OpenAI-format tool call list for assistant messages.
|
||||
is_error: When True and role is "tool", ``to_llm_dict`` prepends "ERROR: " to content.
|
||||
"""
|
||||
|
||||
seq: int
|
||||
role: Literal["user", "assistant", "tool"]
|
||||
content: str
|
||||
tool_use_id: str | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
is_error: bool = False
|
||||
|
||||
def to_llm_dict(self) -> dict[str, Any]:
|
||||
"""Convert to OpenAI-format message dict."""
|
||||
if self.role == "user":
|
||||
return {"role": "user", "content": self.content}
|
||||
|
||||
if self.role == "assistant":
|
||||
d: dict[str, Any] = {"role": "assistant", "content": self.content}
|
||||
if self.tool_calls:
|
||||
d["tool_calls"] = self.tool_calls
|
||||
return d
|
||||
|
||||
# role == "tool"
|
||||
content = f"ERROR: {self.content}" if self.is_error else self.content
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.tool_use_id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
def to_storage_dict(self) -> dict[str, Any]:
|
||||
"""Serialize all fields for persistence. Omits None/default-False fields."""
|
||||
d: dict[str, Any] = {
|
||||
"seq": self.seq,
|
||||
"role": self.role,
|
||||
"content": self.content,
|
||||
}
|
||||
if self.tool_use_id is not None:
|
||||
d["tool_use_id"] = self.tool_use_id
|
||||
if self.tool_calls is not None:
|
||||
d["tool_calls"] = self.tool_calls
|
||||
if self.is_error:
|
||||
d["is_error"] = self.is_error
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_storage_dict(cls, data: dict[str, Any]) -> Message:
|
||||
"""Deserialize from a storage dict."""
|
||||
return cls(
|
||||
seq=data["seq"],
|
||||
role=data["role"],
|
||||
content=data["content"],
|
||||
tool_use_id=data.get("tool_use_id"),
|
||||
tool_calls=data.get("tool_calls"),
|
||||
is_error=data.get("is_error", False),
|
||||
)
|
||||
|
||||
|
||||
def _extract_spillover_filename(content: str) -> str | None:
|
||||
"""Extract spillover filename from a truncated tool result.
|
||||
|
||||
Matches the pattern produced by EventLoopNode._truncate_tool_result():
|
||||
"saved to 'tool_github_list_stargazers_abc123.txt'"
|
||||
"""
|
||||
match = re.search(r"saved to '([^']+)'", content)
|
||||
return match.group(1) if match else None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationStore protocol (Phase 2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ConversationStore(Protocol):
|
||||
"""Protocol for conversation persistence backends."""
|
||||
|
||||
async def write_part(self, seq: int, data: dict[str, Any]) -> None: ...
|
||||
|
||||
async def read_parts(self) -> list[dict[str, Any]]: ...
|
||||
|
||||
async def write_meta(self, data: dict[str, Any]) -> None: ...
|
||||
|
||||
async def read_meta(self) -> dict[str, Any] | None: ...
|
||||
|
||||
async def write_cursor(self, data: dict[str, Any]) -> None: ...
|
||||
|
||||
async def read_cursor(self) -> dict[str, Any] | None: ...
|
||||
|
||||
async def delete_parts_before(self, seq: int) -> None: ...
|
||||
|
||||
async def close(self) -> None: ...
|
||||
|
||||
async def destroy(self) -> None: ...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NodeConversation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _try_extract_key(content: str, key: str) -> str | None:
|
||||
"""Try 4 strategies to extract a *key*'s value from message content.
|
||||
|
||||
Strategies (in order):
|
||||
1. Whole message is JSON — ``json.loads``, check for key.
|
||||
2. Embedded JSON via ``find_json_object`` helper.
|
||||
3. Colon format: ``key: value``.
|
||||
4. Equals format: ``key = value``.
|
||||
"""
|
||||
from framework.graph.node import find_json_object
|
||||
|
||||
# 1. Whole message is JSON
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict) and key in parsed:
|
||||
val = parsed[key]
|
||||
return json.dumps(val) if not isinstance(val, str) else val
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 2. Embedded JSON via find_json_object
|
||||
json_str = find_json_object(content)
|
||||
if json_str:
|
||||
try:
|
||||
parsed = json.loads(json_str)
|
||||
if isinstance(parsed, dict) and key in parsed:
|
||||
val = parsed[key]
|
||||
return json.dumps(val) if not isinstance(val, str) else val
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. Colon format: key: value
|
||||
match = re.search(rf"\b{re.escape(key)}\s*:\s*(.+)", content)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
# 4. Equals format: key = value
|
||||
match = re.search(rf"\b{re.escape(key)}\s*=\s*(.+)", content)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class NodeConversation:
|
||||
"""Message history for a graph node with optional write-through persistence.
|
||||
|
||||
When *store* is ``None`` the conversation works purely in-memory.
|
||||
When a :class:`ConversationStore` is supplied every mutation is
|
||||
persisted via write-through (meta is lazily written on the first
|
||||
``_persist`` call).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_prompt: str = "",
|
||||
max_history_tokens: int = 32000,
|
||||
compaction_threshold: float = 0.8,
|
||||
output_keys: list[str] | None = None,
|
||||
store: ConversationStore | None = None,
|
||||
) -> None:
|
||||
self._system_prompt = system_prompt
|
||||
self._max_history_tokens = max_history_tokens
|
||||
self._compaction_threshold = compaction_threshold
|
||||
self._output_keys = output_keys
|
||||
self._store = store
|
||||
self._messages: list[Message] = []
|
||||
self._next_seq: int = 0
|
||||
self._meta_persisted: bool = False
|
||||
self._last_api_input_tokens: int | None = None
|
||||
|
||||
# --- Properties --------------------------------------------------------
|
||||
|
||||
@property
|
||||
def system_prompt(self) -> str:
|
||||
return self._system_prompt
|
||||
|
||||
@property
|
||||
def messages(self) -> list[Message]:
|
||||
"""Return a defensive copy of the message list."""
|
||||
return list(self._messages)
|
||||
|
||||
@property
|
||||
def turn_count(self) -> int:
|
||||
"""Number of conversational turns (one turn = one user message)."""
|
||||
return sum(1 for m in self._messages if m.role == "user")
|
||||
|
||||
@property
|
||||
def message_count(self) -> int:
|
||||
"""Total number of messages (all roles)."""
|
||||
return len(self._messages)
|
||||
|
||||
@property
|
||||
def next_seq(self) -> int:
|
||||
return self._next_seq
|
||||
|
||||
# --- Add messages ------------------------------------------------------
|
||||
|
||||
async def add_user_message(self, content: str) -> Message:
|
||||
msg = Message(seq=self._next_seq, role="user", content=content)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
await self._persist(msg)
|
||||
return msg
|
||||
|
||||
async def add_assistant_message(
|
||||
self,
|
||||
content: str,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
) -> Message:
|
||||
msg = Message(
|
||||
seq=self._next_seq,
|
||||
role="assistant",
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
await self._persist(msg)
|
||||
return msg
|
||||
|
||||
async def add_tool_result(
|
||||
self,
|
||||
tool_use_id: str,
|
||||
content: str,
|
||||
is_error: bool = False,
|
||||
) -> Message:
|
||||
msg = Message(
|
||||
seq=self._next_seq,
|
||||
role="tool",
|
||||
content=content,
|
||||
tool_use_id=tool_use_id,
|
||||
is_error=is_error,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
await self._persist(msg)
|
||||
return msg
|
||||
|
||||
# --- Query -------------------------------------------------------------
|
||||
|
||||
def to_llm_messages(self) -> list[dict[str, Any]]:
|
||||
"""Return messages as OpenAI-format dicts (system prompt excluded).
|
||||
|
||||
Automatically repairs orphaned tool_use blocks (assistant messages
|
||||
with tool_calls that lack corresponding tool-result messages). This
|
||||
can happen when a loop is cancelled mid-tool-execution.
|
||||
"""
|
||||
msgs = [m.to_llm_dict() for m in self._messages]
|
||||
return self._repair_orphaned_tool_calls(msgs)
|
||||
|
||||
@staticmethod
|
||||
def _repair_orphaned_tool_calls(
|
||||
msgs: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Ensure every tool_call has a matching tool-result message."""
|
||||
repaired: list[dict[str, Any]] = []
|
||||
for i, m in enumerate(msgs):
|
||||
repaired.append(m)
|
||||
tool_calls = m.get("tool_calls")
|
||||
if m.get("role") != "assistant" or not tool_calls:
|
||||
continue
|
||||
# Collect IDs of tool results that follow this assistant message
|
||||
answered: set[str] = set()
|
||||
for j in range(i + 1, len(msgs)):
|
||||
if msgs[j].get("role") == "tool":
|
||||
tid = msgs[j].get("tool_call_id")
|
||||
if tid:
|
||||
answered.add(tid)
|
||||
else:
|
||||
break # stop at first non-tool message
|
||||
# Patch any missing results
|
||||
for tc in tool_calls:
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in answered:
|
||||
repaired.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc_id,
|
||||
"content": "ERROR: Tool execution was interrupted.",
|
||||
}
|
||||
)
|
||||
return repaired
|
||||
|
||||
def estimate_tokens(self) -> int:
|
||||
"""Best available token estimate.
|
||||
|
||||
Uses actual API input token count when available (set via
|
||||
:meth:`update_token_count`), otherwise falls back to the rough
|
||||
``total_chars / 4`` heuristic.
|
||||
"""
|
||||
if self._last_api_input_tokens is not None:
|
||||
return self._last_api_input_tokens
|
||||
total_chars = sum(len(m.content) for m in self._messages)
|
||||
return total_chars // 4
|
||||
|
||||
def update_token_count(self, actual_input_tokens: int) -> None:
|
||||
"""Store actual API input token count for more accurate compaction.
|
||||
|
||||
Called by EventLoopNode after each LLM call with the ``input_tokens``
|
||||
value from the API response. This value includes system prompt and
|
||||
tool definitions, so it may be higher than a message-only estimate.
|
||||
"""
|
||||
self._last_api_input_tokens = actual_input_tokens
|
||||
|
||||
def usage_ratio(self) -> float:
|
||||
"""Current token usage as a fraction of *max_history_tokens*.
|
||||
|
||||
Returns 0.0 when ``max_history_tokens`` is zero (unlimited).
|
||||
"""
|
||||
if self._max_history_tokens <= 0:
|
||||
return 0.0
|
||||
return self.estimate_tokens() / self._max_history_tokens
|
||||
|
||||
def needs_compaction(self) -> bool:
|
||||
return self.estimate_tokens() >= self._max_history_tokens * self._compaction_threshold
|
||||
|
||||
# --- Output-key extraction ---------------------------------------------
|
||||
|
||||
def _extract_protected_values(self, messages: list[Message]) -> dict[str, str]:
|
||||
"""Scan assistant messages for output_key values before compaction.
|
||||
|
||||
Iterates most-recent-first. Once a key is found, it's skipped for
|
||||
older messages (latest value wins).
|
||||
"""
|
||||
if not self._output_keys:
|
||||
return {}
|
||||
|
||||
found: dict[str, str] = {}
|
||||
remaining_keys = set(self._output_keys)
|
||||
|
||||
for msg in reversed(messages):
|
||||
if msg.role != "assistant" or not remaining_keys:
|
||||
continue
|
||||
|
||||
for key in list(remaining_keys):
|
||||
value = self._try_extract_key(msg.content, key)
|
||||
if value is not None:
|
||||
found[key] = value
|
||||
remaining_keys.discard(key)
|
||||
|
||||
return found
|
||||
|
||||
def _try_extract_key(self, content: str, key: str) -> str | None:
|
||||
"""Try 4 strategies to extract a key's value from message content."""
|
||||
return _try_extract_key(content, key)
|
||||
|
||||
# --- Lifecycle ---------------------------------------------------------
|
||||
|
||||
async def prune_old_tool_results(
|
||||
self,
|
||||
protect_tokens: int = 5000,
|
||||
min_prune_tokens: int = 2000,
|
||||
) -> int:
|
||||
"""Replace old tool result content with compact placeholders.
|
||||
|
||||
Walks backward through messages. Recent tool results (within
|
||||
*protect_tokens*) are kept intact. Older tool results have their
|
||||
content replaced with a ~100-char placeholder that preserves the
|
||||
spillover filename reference (if any). Message structure (role,
|
||||
seq, tool_use_id) stays valid for the LLM API.
|
||||
|
||||
Error tool results are never pruned — they prevent re-calling
|
||||
failing tools.
|
||||
|
||||
Returns the number of messages pruned (0 if nothing was pruned).
|
||||
"""
|
||||
if not self._messages:
|
||||
return 0
|
||||
|
||||
# Phase 1: Walk backward, classify tool results as protected vs pruneable
|
||||
protected_tokens = 0
|
||||
pruneable: list[int] = [] # indices into self._messages
|
||||
pruneable_tokens = 0
|
||||
|
||||
for i in range(len(self._messages) - 1, -1, -1):
|
||||
msg = self._messages[i]
|
||||
if msg.role != "tool":
|
||||
continue
|
||||
if msg.is_error:
|
||||
continue # never prune errors
|
||||
if msg.content.startswith("[Pruned tool result"):
|
||||
continue # already pruned
|
||||
|
||||
est = len(msg.content) // 4
|
||||
if protected_tokens < protect_tokens:
|
||||
protected_tokens += est
|
||||
else:
|
||||
pruneable.append(i)
|
||||
pruneable_tokens += est
|
||||
|
||||
# Phase 2: Only prune if enough to be worthwhile
|
||||
if pruneable_tokens < min_prune_tokens:
|
||||
return 0
|
||||
|
||||
# Phase 3: Replace content with compact placeholder
|
||||
count = 0
|
||||
for i in pruneable:
|
||||
msg = self._messages[i]
|
||||
orig_len = len(msg.content)
|
||||
spillover = _extract_spillover_filename(msg.content)
|
||||
|
||||
if spillover:
|
||||
placeholder = (
|
||||
f"[Pruned tool result: {orig_len} chars. "
|
||||
f"Full data in '{spillover}'. "
|
||||
f"Use load_data('{spillover}') to retrieve.]"
|
||||
)
|
||||
else:
|
||||
placeholder = f"[Pruned tool result: {orig_len} chars cleared from context.]"
|
||||
|
||||
self._messages[i] = Message(
|
||||
seq=msg.seq,
|
||||
role=msg.role,
|
||||
content=placeholder,
|
||||
tool_use_id=msg.tool_use_id,
|
||||
tool_calls=msg.tool_calls,
|
||||
is_error=msg.is_error,
|
||||
)
|
||||
count += 1
|
||||
|
||||
if self._store:
|
||||
await self._store.write_part(msg.seq, self._messages[i].to_storage_dict())
|
||||
|
||||
# Reset token estimate — content lengths changed
|
||||
self._last_api_input_tokens = None
|
||||
return count
|
||||
|
||||
async def compact(self, summary: str, keep_recent: int = 2) -> None:
|
||||
"""Replace old messages with a summary, optionally keeping recent ones.
|
||||
|
||||
Args:
|
||||
summary: Caller-provided summary text.
|
||||
keep_recent: Number of recent messages to preserve (default 2).
|
||||
Clamped to [0, len(messages) - 1].
|
||||
"""
|
||||
if not self._messages:
|
||||
return
|
||||
|
||||
# Clamp: must discard at least 1 message
|
||||
keep_recent = max(0, min(keep_recent, len(self._messages) - 1))
|
||||
|
||||
total = len(self._messages)
|
||||
split = total - keep_recent if keep_recent > 0 else total
|
||||
|
||||
# Advance split past orphaned tool results at the boundary.
|
||||
# Tool-role messages reference a tool_use from the preceding
|
||||
# assistant message; if that assistant message falls into the
|
||||
# compacted (old) portion the tool_result becomes invalid.
|
||||
while split < total and self._messages[split].role == "tool":
|
||||
split += 1
|
||||
|
||||
old_messages = list(self._messages[:split])
|
||||
recent_messages = list(self._messages[split:])
|
||||
|
||||
# Extract protected values from messages being discarded
|
||||
if self._output_keys:
|
||||
protected = self._extract_protected_values(old_messages)
|
||||
if protected:
|
||||
lines = ["PRESERVED VALUES (do not lose these):"]
|
||||
for k, v in protected.items():
|
||||
lines.append(f"- {k}: {v}")
|
||||
lines.append("")
|
||||
lines.append("CONVERSATION SUMMARY:")
|
||||
lines.append(summary)
|
||||
summary = "\n".join(lines)
|
||||
|
||||
# Determine summary seq
|
||||
if recent_messages:
|
||||
summary_seq = recent_messages[0].seq - 1
|
||||
else:
|
||||
summary_seq = self._next_seq
|
||||
self._next_seq += 1
|
||||
|
||||
summary_msg = Message(seq=summary_seq, role="user", content=summary)
|
||||
|
||||
# Persist
|
||||
if self._store:
|
||||
delete_before = recent_messages[0].seq if recent_messages else self._next_seq
|
||||
await self._store.delete_parts_before(delete_before)
|
||||
await self._store.write_part(summary_msg.seq, summary_msg.to_storage_dict())
|
||||
await self._store.write_cursor({"next_seq": self._next_seq})
|
||||
|
||||
self._messages = [summary_msg] + recent_messages
|
||||
self._last_api_input_tokens = None # reset; next LLM call will recalibrate
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Remove all messages, keep system prompt, preserve ``_next_seq``."""
|
||||
if self._store:
|
||||
await self._store.delete_parts_before(self._next_seq)
|
||||
await self._store.write_cursor({"next_seq": self._next_seq})
|
||||
self._messages.clear()
|
||||
self._last_api_input_tokens = None
|
||||
|
||||
def export_summary(self) -> str:
|
||||
"""Structured summary with [STATS], [CONFIG], [RECENT_MESSAGES] sections."""
|
||||
prompt_preview = (
|
||||
self._system_prompt[:80] + "..."
|
||||
if len(self._system_prompt) > 80
|
||||
else self._system_prompt
|
||||
)
|
||||
|
||||
lines = [
|
||||
"[STATS]",
|
||||
f"turns: {self.turn_count}",
|
||||
f"messages: {self.message_count}",
|
||||
f"estimated_tokens: {self.estimate_tokens()}",
|
||||
"",
|
||||
"[CONFIG]",
|
||||
f"system_prompt: {prompt_preview!r}",
|
||||
]
|
||||
|
||||
if self._output_keys:
|
||||
lines.append(f"output_keys: {', '.join(self._output_keys)}")
|
||||
|
||||
lines.append("")
|
||||
lines.append("[RECENT_MESSAGES]")
|
||||
for m in self._messages[-5:]:
|
||||
preview = m.content[:60] + "..." if len(m.content) > 60 else m.content
|
||||
lines.append(f" [{m.role}] {preview}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
# --- Persistence internals ---------------------------------------------
|
||||
|
||||
async def _persist(self, message: Message) -> None:
|
||||
"""Write-through a single message. No-op when store is None."""
|
||||
if self._store is None:
|
||||
return
|
||||
if not self._meta_persisted:
|
||||
await self._persist_meta()
|
||||
await self._store.write_part(message.seq, message.to_storage_dict())
|
||||
await self._store.write_cursor({"next_seq": self._next_seq})
|
||||
|
||||
async def _persist_meta(self) -> None:
|
||||
"""Lazily write conversation metadata to the store (called once)."""
|
||||
if self._store is None:
|
||||
return
|
||||
await self._store.write_meta(
|
||||
{
|
||||
"system_prompt": self._system_prompt,
|
||||
"max_history_tokens": self._max_history_tokens,
|
||||
"compaction_threshold": self._compaction_threshold,
|
||||
"output_keys": self._output_keys,
|
||||
}
|
||||
)
|
||||
self._meta_persisted = True
|
||||
|
||||
# --- Restore -----------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
async def restore(cls, store: ConversationStore) -> NodeConversation | None:
|
||||
"""Reconstruct a NodeConversation from a store.
|
||||
|
||||
Returns ``None`` if the store contains no metadata (i.e. the
|
||||
conversation was never persisted).
|
||||
"""
|
||||
meta = await store.read_meta()
|
||||
if meta is None:
|
||||
return None
|
||||
|
||||
conv = cls(
|
||||
system_prompt=meta.get("system_prompt", ""),
|
||||
max_history_tokens=meta.get("max_history_tokens", 32000),
|
||||
compaction_threshold=meta.get("compaction_threshold", 0.8),
|
||||
output_keys=meta.get("output_keys"),
|
||||
store=store,
|
||||
)
|
||||
conv._meta_persisted = True
|
||||
|
||||
parts = await store.read_parts()
|
||||
conv._messages = [Message.from_storage_dict(p) for p in parts]
|
||||
|
||||
cursor = await store.read_cursor()
|
||||
if cursor:
|
||||
conv._next_seq = cursor["next_seq"]
|
||||
elif conv._messages:
|
||||
conv._next_seq = conv._messages[-1].seq + 1
|
||||
|
||||
return conv
|
||||
@@ -11,7 +11,6 @@ our edges can be created dynamically by a Builder agent based on the goal.
|
||||
|
||||
Edge Types:
|
||||
- always: Always traverse after source completes
|
||||
- always: Always traverse after source completes
|
||||
- on_success: Traverse only if source succeeds
|
||||
- on_failure: Traverse only if source fails
|
||||
- conditional: Traverse based on expression evaluation (SAFE SUBSET ONLY)
|
||||
@@ -22,7 +21,7 @@ allowing the LLM to evaluate whether proceeding along an edge makes sense
|
||||
given the current goal, context, and execution state.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -30,7 +29,7 @@ from pydantic import BaseModel, Field
|
||||
from framework.graph.safe_eval import safe_eval
|
||||
|
||||
|
||||
class EdgeCondition(str, Enum):
|
||||
class EdgeCondition(StrEnum):
|
||||
"""When an edge should be traversed."""
|
||||
|
||||
ALWAYS = "always" # Always after source completes
|
||||
@@ -609,4 +608,40 @@ class GraphSpec(BaseModel):
|
||||
continue
|
||||
errors.append(f"Node '{node.id}' is unreachable from entry")
|
||||
|
||||
# Client-facing fan-out validation
|
||||
fan_outs = self.detect_fan_out_nodes()
|
||||
for source_id, targets in fan_outs.items():
|
||||
client_facing_targets = [
|
||||
t
|
||||
for t in targets
|
||||
if self.get_node(t) and getattr(self.get_node(t), "client_facing", False)
|
||||
]
|
||||
if len(client_facing_targets) > 1:
|
||||
errors.append(
|
||||
f"Fan-out from '{source_id}' has multiple client-facing nodes: "
|
||||
f"{client_facing_targets}. Only one branch may be client-facing."
|
||||
)
|
||||
|
||||
# Output key overlap on parallel event_loop nodes
|
||||
for source_id, targets in fan_outs.items():
|
||||
event_loop_targets = [
|
||||
t
|
||||
for t in targets
|
||||
if self.get_node(t) and getattr(self.get_node(t), "node_type", "") == "event_loop"
|
||||
]
|
||||
if len(event_loop_targets) > 1:
|
||||
seen_keys: dict[str, str] = {}
|
||||
for node_id in event_loop_targets:
|
||||
node = self.get_node(node_id)
|
||||
for key in getattr(node, "output_keys", []):
|
||||
if key in seen_keys:
|
||||
errors.append(
|
||||
f"Fan-out from '{source_id}': event_loop nodes "
|
||||
f"'{seen_keys[key]}' and '{node_id}' both write to "
|
||||
f"output_key '{key}'. Parallel event_loop nodes must "
|
||||
f"have disjoint output_keys to prevent last-wins data loss."
|
||||
)
|
||||
else:
|
||||
seen_keys[key] = node_id
|
||||
|
||||
return errors
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,11 +11,12 @@ The executor:
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.edge import EdgeSpec, GraphSpec
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import (
|
||||
FunctionNode,
|
||||
@@ -54,6 +55,9 @@ class ExecutionResult:
|
||||
had_partial_failures: bool = False # True if any node failed but eventually succeeded
|
||||
execution_quality: str = "clean" # "clean", "degraded", or "failed"
|
||||
|
||||
# Visit tracking (for feedback/callback edges)
|
||||
node_visit_counts: dict[str, int] = field(default_factory=dict) # {node_id: visit_count}
|
||||
|
||||
@property
|
||||
def is_clean_success(self) -> bool:
|
||||
"""True only if execution succeeded with no retries or failures."""
|
||||
@@ -250,6 +254,8 @@ class GraphExecutor:
|
||||
total_tokens = 0
|
||||
total_latency = 0
|
||||
node_retry_counts: dict[str, int] = {} # Track retries per node
|
||||
node_visit_counts: dict[str, int] = {} # Track visits for feedback loops
|
||||
_is_retry = False # True when looping back for a retry (not a new visit)
|
||||
|
||||
# Determine entry point (may differ if resuming)
|
||||
current_node_id = graph.get_entry_point(session_state)
|
||||
@@ -278,6 +284,34 @@ class GraphExecutor:
|
||||
if node_spec is None:
|
||||
raise RuntimeError(f"Node not found: {current_node_id}")
|
||||
|
||||
# Enforce max_node_visits (feedback/callback edge support)
|
||||
# Don't increment visit count on retries — retries are not new visits
|
||||
if not _is_retry:
|
||||
cnt = node_visit_counts.get(current_node_id, 0) + 1
|
||||
node_visit_counts[current_node_id] = cnt
|
||||
_is_retry = False
|
||||
max_visits = getattr(node_spec, "max_node_visits", 1)
|
||||
if max_visits > 0 and node_visit_counts[current_node_id] > max_visits:
|
||||
self.logger.warning(
|
||||
f" ⊘ Node '{node_spec.name}' visit limit reached "
|
||||
f"({node_visit_counts[current_node_id]}/{max_visits}), skipping"
|
||||
)
|
||||
# Skip execution — follow outgoing edges using current memory
|
||||
skip_result = NodeResult(success=True, output=memory.read_all())
|
||||
next_node = self._follow_edges(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
current_node_id=current_node_id,
|
||||
current_node_spec=node_spec,
|
||||
result=skip_result,
|
||||
memory=memory,
|
||||
)
|
||||
if next_node is None:
|
||||
self.logger.info(" → No more edges after visit limit, ending")
|
||||
break
|
||||
current_node_id = next_node
|
||||
continue
|
||||
|
||||
path.append(current_node_id)
|
||||
|
||||
# Check if pause (HITL) before execution
|
||||
@@ -380,6 +414,15 @@ class GraphExecutor:
|
||||
# [CORRECTED] Use node_spec.max_retries instead of hardcoded 3
|
||||
max_retries = getattr(node_spec, "max_retries", 3)
|
||||
|
||||
# Event loop nodes handle retry internally via judge —
|
||||
# executor retry is catastrophic (retry multiplication)
|
||||
if node_spec.node_type == "event_loop" and max_retries > 0:
|
||||
self.logger.warning(
|
||||
f"EventLoopNode '{node_spec.id}' has max_retries={max_retries}. "
|
||||
"Overriding to 0 — event loop nodes handle retry internally via judge."
|
||||
)
|
||||
max_retries = 0
|
||||
|
||||
if node_retry_counts[current_node_id] < max_retries:
|
||||
# Retry - don't increment steps for retries
|
||||
steps -= 1
|
||||
@@ -395,6 +438,7 @@ class GraphExecutor:
|
||||
self.logger.info(
|
||||
f" ↻ Retrying ({node_retry_counts[current_node_id]}/{max_retries})..."
|
||||
)
|
||||
_is_retry = True
|
||||
continue
|
||||
else:
|
||||
# Max retries exceeded - fail the execution
|
||||
@@ -437,6 +481,7 @@ class GraphExecutor:
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality="failed",
|
||||
node_visit_counts=dict(node_visit_counts),
|
||||
)
|
||||
|
||||
# Check if we just executed a pause node - if so, save state and return
|
||||
@@ -476,6 +521,7 @@ class GraphExecutor:
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality=exec_quality,
|
||||
node_visit_counts=dict(node_visit_counts),
|
||||
)
|
||||
|
||||
# Check if this is a terminal node - if so, we're done
|
||||
@@ -596,6 +642,7 @@ class GraphExecutor:
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality=exec_quality,
|
||||
node_visit_counts=dict(node_visit_counts),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -622,6 +669,7 @@ class GraphExecutor:
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality="failed",
|
||||
node_visit_counts=dict(node_visit_counts),
|
||||
)
|
||||
|
||||
def _build_context(
|
||||
@@ -658,7 +706,15 @@ class GraphExecutor:
|
||||
)
|
||||
|
||||
# Valid node types - no ambiguous "llm" type allowed
|
||||
VALID_NODE_TYPES = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
|
||||
VALID_NODE_TYPES = {
|
||||
"llm_tool_use",
|
||||
"llm_generate",
|
||||
"router",
|
||||
"function",
|
||||
"human_input",
|
||||
"event_loop",
|
||||
}
|
||||
DEPRECATED_NODE_TYPES = {"llm_tool_use": "event_loop", "llm_generate": "event_loop"}
|
||||
|
||||
def _get_node_implementation(
|
||||
self, node_spec: NodeSpec, cleanup_llm_model: str | None = None
|
||||
@@ -676,6 +732,17 @@ class GraphExecutor:
|
||||
f"Use 'llm_tool_use' for nodes that call tools, 'llm_generate' for text generation."
|
||||
)
|
||||
|
||||
# Warn on deprecated node types
|
||||
if node_spec.node_type in self.DEPRECATED_NODE_TYPES:
|
||||
replacement = self.DEPRECATED_NODE_TYPES[node_spec.node_type]
|
||||
warnings.warn(
|
||||
f"Node type '{node_spec.node_type}' is deprecated. "
|
||||
f"Use '{replacement}' instead. "
|
||||
f"Node: '{node_spec.id}'",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Create based on type
|
||||
if node_spec.node_type == "llm_tool_use":
|
||||
if not node_spec.tools:
|
||||
@@ -713,6 +780,13 @@ class GraphExecutor:
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
)
|
||||
|
||||
if node_spec.node_type == "event_loop":
|
||||
# Event loop nodes must be pre-registered (like function nodes)
|
||||
raise RuntimeError(
|
||||
f"EventLoopNode '{node_spec.id}' not found in registry. "
|
||||
"Register it with executor.register_node() before execution."
|
||||
)
|
||||
|
||||
# Should never reach here due to validation above
|
||||
raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")
|
||||
|
||||
@@ -823,6 +897,21 @@ class GraphExecutor:
|
||||
):
|
||||
traversable.append(edge)
|
||||
|
||||
# Priority filtering for CONDITIONAL edges:
|
||||
# When multiple CONDITIONAL edges match, keep only the highest-priority
|
||||
# group. This prevents mutually-exclusive conditional branches (e.g.
|
||||
# forward vs. feedback) from incorrectly triggering fan-out.
|
||||
# ON_SUCCESS / other edge types are unaffected.
|
||||
if len(traversable) > 1:
|
||||
conditionals = [e for e in traversable if e.condition == EdgeCondition.CONDITIONAL]
|
||||
if len(conditionals) > 1:
|
||||
max_prio = max(e.priority for e in conditionals)
|
||||
traversable = [
|
||||
e
|
||||
for e in traversable
|
||||
if e.condition != EdgeCondition.CONDITIONAL or e.priority == max_prio
|
||||
]
|
||||
|
||||
return traversable
|
||||
|
||||
def _find_convergence_node(
|
||||
@@ -909,6 +998,17 @@ class GraphExecutor:
|
||||
branch.status = "failed"
|
||||
branch.error = f"Node {branch.node_id} not found in graph"
|
||||
return branch, RuntimeError(branch.error)
|
||||
|
||||
effective_max_retries = node_spec.max_retries
|
||||
if node_spec.node_type == "event_loop":
|
||||
if effective_max_retries > 1:
|
||||
self.logger.warning(
|
||||
f"EventLoopNode '{node_spec.id}' has "
|
||||
f"max_retries={effective_max_retries}. Overriding "
|
||||
"to 1 — event loop nodes handle retry internally."
|
||||
)
|
||||
effective_max_retries = 1
|
||||
|
||||
branch.status = "running"
|
||||
|
||||
try:
|
||||
@@ -942,7 +1042,7 @@ class GraphExecutor:
|
||||
|
||||
# Execute with retries
|
||||
last_result = None
|
||||
for attempt in range(node_spec.max_retries):
|
||||
for attempt in range(effective_max_retries):
|
||||
branch.retry_count = attempt
|
||||
|
||||
# Build context for this branch
|
||||
@@ -970,7 +1070,7 @@ class GraphExecutor:
|
||||
|
||||
self.logger.warning(
|
||||
f" ↻ Branch {node_spec.name}: "
|
||||
f"retry {attempt + 1}/{node_spec.max_retries}"
|
||||
f"retry {attempt + 1}/{effective_max_retries}"
|
||||
)
|
||||
|
||||
# All retries exhausted
|
||||
@@ -979,7 +1079,7 @@ class GraphExecutor:
|
||||
branch.result = last_result
|
||||
self.logger.error(
|
||||
f" ✗ Branch {node_spec.name}: "
|
||||
f"failed after {node_spec.max_retries} attempts"
|
||||
f"failed after {effective_max_retries} attempts"
|
||||
)
|
||||
return branch, last_result
|
||||
|
||||
|
||||
@@ -12,13 +12,13 @@ Goals are:
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GoalStatus(str, Enum):
|
||||
class GoalStatus(StrEnum):
|
||||
"""Lifecycle status of a goal."""
|
||||
|
||||
DRAFT = "draft" # Being defined
|
||||
|
||||
@@ -6,11 +6,11 @@ where agents need to gather input from humans.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class HITLInputType(str, Enum):
|
||||
class HITLInputType(StrEnum):
|
||||
"""Type of input expected from human."""
|
||||
|
||||
FREE_TEXT = "free_text" # Open-ended text response
|
||||
|
||||
@@ -16,10 +16,12 @@ Protocol:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -153,7 +155,10 @@ class NodeSpec(BaseModel):
|
||||
# Node behavior type
|
||||
node_type: str = Field(
|
||||
default="llm_tool_use",
|
||||
description="Type: 'llm_tool_use', 'llm_generate', 'function', 'router', 'human_input'",
|
||||
description=(
|
||||
"Type: 'event_loop', 'function', 'router', 'human_input'. "
|
||||
"Deprecated: 'llm_tool_use', 'llm_generate' (use 'event_loop' instead)."
|
||||
),
|
||||
)
|
||||
|
||||
# Data flow
|
||||
@@ -205,6 +210,15 @@ class NodeSpec(BaseModel):
|
||||
max_retries: int = Field(default=3)
|
||||
retry_on: list[str] = Field(default_factory=list, description="Error types to retry on")
|
||||
|
||||
# Visit limits (for feedback/callback edges)
|
||||
max_node_visits: int = Field(
|
||||
default=1,
|
||||
description=(
|
||||
"Max times this node executes in one graph run. "
|
||||
"Set >1 for feedback loops. 0 = unlimited (max_steps guards)."
|
||||
),
|
||||
)
|
||||
|
||||
# Pydantic model for output validation
|
||||
output_model: type[BaseModel] | None = Field(
|
||||
default=None,
|
||||
@@ -218,6 +232,12 @@ class NodeSpec(BaseModel):
|
||||
description="Maximum retries when Pydantic validation fails (with feedback to LLM)",
|
||||
)
|
||||
|
||||
# Client-facing behavior
|
||||
client_facing: bool = Field(
|
||||
default=False,
|
||||
description="If True, this node streams output to the end user and can request input.",
|
||||
)
|
||||
|
||||
model_config = {"extra": "allow", "arbitrary_types_allowed": True}
|
||||
|
||||
|
||||
@@ -1348,7 +1368,9 @@ Expected output keys: {output_keys}
|
||||
LLM Response:
|
||||
{raw_response}
|
||||
|
||||
Output ONLY the JSON object, nothing else."""
|
||||
Output ONLY the JSON object, nothing else.
|
||||
If no valid JSON object exists in the response, output exactly: {{"error": "NO_JSON_FOUND"}}
|
||||
Do NOT fabricate data or return empty objects."""
|
||||
|
||||
try:
|
||||
result = cleaner_llm.complete(
|
||||
@@ -1395,6 +1417,14 @@ Output ONLY the JSON object, nothing else."""
|
||||
parsed = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
parsed = json.loads(_fix_unescaped_newlines_in_json(cleaned))
|
||||
|
||||
# Validate LLM didn't return empty or fabricated data
|
||||
if parsed.get("error") == "NO_JSON_FOUND":
|
||||
raise ValueError("Cannot parse JSON from response")
|
||||
if not parsed or parsed == {}:
|
||||
raise ValueError("Cannot parse JSON from response")
|
||||
if all(v is None for v in parsed.values()):
|
||||
raise ValueError("Cannot parse JSON from response")
|
||||
logger.info(" ✓ LLM cleaned JSON output")
|
||||
return parsed
|
||||
|
||||
@@ -1504,6 +1534,8 @@ Output ONLY the JSON object, nothing else."""
|
||||
|
||||
def _build_system_prompt(self, ctx: NodeContext) -> str:
|
||||
"""Build the system prompt."""
|
||||
from datetime import datetime
|
||||
|
||||
parts = []
|
||||
|
||||
if ctx.node_spec.system_prompt:
|
||||
@@ -1526,6 +1558,15 @@ Output ONLY the JSON object, nothing else."""
|
||||
|
||||
parts.append(prompt)
|
||||
|
||||
# Inject current datetime so LLM knows "now"
|
||||
utc_dt = datetime.now(UTC)
|
||||
local_dt = datetime.now().astimezone()
|
||||
local_tz_name = local_dt.tzname() or "Unknown"
|
||||
parts.append("\n## Runtime Context")
|
||||
parts.append(f"- Current Date/Time (UTC): {utc_dt.isoformat()}")
|
||||
parts.append(f"- Local Timezone: {local_tz_name}")
|
||||
parts.append(f"- Current Date/Time (Local): {local_dt.isoformat()}")
|
||||
|
||||
if ctx.goal_context:
|
||||
parts.append("\n# Goal Context")
|
||||
parts.append(ctx.goal_context)
|
||||
@@ -1727,8 +1768,19 @@ class FunctionNode(NodeProtocol):
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Call the function
|
||||
result = self.func(**ctx.input_data)
|
||||
# Filter input_data to only declared input_keys to prevent
|
||||
# leaking extra memory keys from upstream nodes.
|
||||
if ctx.node_spec.input_keys:
|
||||
filtered = {
|
||||
k: v for k, v in ctx.input_data.items() if k in ctx.node_spec.input_keys
|
||||
}
|
||||
else:
|
||||
filtered = ctx.input_data
|
||||
|
||||
# Call the function (supports both sync and async)
|
||||
result = self.func(**filtered)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
|
||||
latency_ms = int((time.time() - start) * 1000)
|
||||
|
||||
|
||||
@@ -11,13 +11,13 @@ The Plan is the contract between the external planner and the executor:
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ActionType(str, Enum):
|
||||
class ActionType(StrEnum):
|
||||
"""Types of actions a PlanStep can perform."""
|
||||
|
||||
LLM_CALL = "llm_call" # Call LLM for generation
|
||||
@@ -27,7 +27,7 @@ class ActionType(str, Enum):
|
||||
CODE_EXECUTION = "code_execution" # Execute dynamic code (sandboxed)
|
||||
|
||||
|
||||
class StepStatus(str, Enum):
|
||||
class StepStatus(StrEnum):
|
||||
"""Status of a plan step."""
|
||||
|
||||
PENDING = "pending"
|
||||
@@ -56,7 +56,7 @@ class StepStatus(str, Enum):
|
||||
return self == StepStatus.COMPLETED
|
||||
|
||||
|
||||
class ApprovalDecision(str, Enum):
|
||||
class ApprovalDecision(StrEnum):
|
||||
"""Human decision on a step requiring approval."""
|
||||
|
||||
APPROVE = "approve" # Execute as planned
|
||||
@@ -91,7 +91,7 @@ class ApprovalResult(BaseModel):
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class JudgmentAction(str, Enum):
|
||||
class JudgmentAction(StrEnum):
|
||||
"""Actions the judge can take after evaluating a step."""
|
||||
|
||||
ACCEPT = "accept" # Step completed successfully, continue
|
||||
@@ -423,7 +423,7 @@ class Plan(BaseModel):
|
||||
}
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
|
||||
class ExecutionStatus(StrEnum):
|
||||
"""Status of plan execution."""
|
||||
|
||||
COMPLETED = "completed"
|
||||
|
||||
@@ -75,16 +75,6 @@ class SafeEvalVisitor(ast.NodeVisitor):
|
||||
def visit_Constant(self, node: ast.Constant) -> Any:
|
||||
return node.value
|
||||
|
||||
# --- Number/String/Bytes/NameConstant (Python < 3.8 compat if needed) ---
|
||||
def visit_Num(self, node: ast.Num) -> Any:
|
||||
return node.n
|
||||
|
||||
def visit_Str(self, node: ast.Str) -> Any:
|
||||
return node.s
|
||||
|
||||
def visit_NameConstant(self, node: ast.NameConstant) -> Any:
|
||||
return node.value
|
||||
|
||||
# --- Data Structures ---
|
||||
def visit_List(self, node: ast.List) -> list:
|
||||
return [self.visit(elt) for elt in node.elts]
|
||||
|
||||
@@ -126,14 +126,16 @@ class OutputValidator:
|
||||
|
||||
for key in expected_keys:
|
||||
if key not in output:
|
||||
errors.append(f"Missing required output key: '{key}'")
|
||||
if key not in nullable_keys:
|
||||
errors.append(f"Missing required output key: '{key}'")
|
||||
elif not allow_empty:
|
||||
value = output[key]
|
||||
if value is None:
|
||||
if key not in nullable_keys:
|
||||
errors.append(f"Output key '{key}' is None")
|
||||
elif isinstance(value, str) and len(value.strip()) == 0:
|
||||
errors.append(f"Output key '{key}' is empty string")
|
||||
if key not in nullable_keys:
|
||||
errors.append(f"Output key '{key}' is empty string")
|
||||
|
||||
return ValidationResult(success=len(errors) == 0, errors=errors)
|
||||
|
||||
|
||||
@@ -1,8 +1,31 @@
|
||||
"""LLM provider abstraction."""
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
ReasoningDeltaEvent,
|
||||
ReasoningStartEvent,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse"]
|
||||
__all__ = [
|
||||
"LLMProvider",
|
||||
"LLMResponse",
|
||||
"StreamEvent",
|
||||
"TextDeltaEvent",
|
||||
"TextEndEvent",
|
||||
"ToolCallEvent",
|
||||
"ToolResultEvent",
|
||||
"ReasoningStartEvent",
|
||||
"ReasoningDeltaEvent",
|
||||
"FinishEvent",
|
||||
"StreamErrorEvent",
|
||||
]
|
||||
|
||||
try:
|
||||
from framework.llm.anthropic import AnthropicProvider # noqa: F401
|
||||
|
||||
@@ -18,7 +18,7 @@ def _get_api_key_from_credential_store() -> str | None:
|
||||
try:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
creds = CredentialStoreAdapter.with_env_storage()
|
||||
creds = CredentialStoreAdapter.default()
|
||||
if creds.is_available("anthropic"):
|
||||
return creds.get("anthropic")
|
||||
except ImportError:
|
||||
|
||||
@@ -7,10 +7,11 @@ Groq, and local models.
|
||||
See: https://docs.litellm.ai/docs/providers
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -23,6 +24,7 @@ except ImportError:
|
||||
RateLimitError = Exception # type: ignore[assignment, misc]
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import StreamEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -161,11 +163,24 @@ class LiteLLMProvider(LLMProvider):
|
||||
content = response.choices[0].message.content if response.choices else None
|
||||
has_tool_calls = bool(response.choices and response.choices[0].message.tool_calls)
|
||||
if not content and not has_tool_calls:
|
||||
# If the conversation ends with an assistant message,
|
||||
# an empty response is expected — don't retry.
|
||||
messages = kwargs.get("messages", [])
|
||||
last_role = next(
|
||||
(m["role"] for m in reversed(messages) if m.get("role") != "system"),
|
||||
None,
|
||||
)
|
||||
if last_role == "assistant":
|
||||
logger.debug(
|
||||
"[retry] Empty response after assistant message — "
|
||||
"expected, not retrying."
|
||||
)
|
||||
return response
|
||||
|
||||
finish_reason = (
|
||||
response.choices[0].finish_reason if response.choices else "unknown"
|
||||
)
|
||||
# Dump full request to file for debugging
|
||||
messages = kwargs.get("messages", [])
|
||||
token_count, token_method = _estimate_tokens(model, messages)
|
||||
dump_path = _dump_failed_request(
|
||||
model=model,
|
||||
@@ -378,11 +393,18 @@ class LiteLLMProvider(LLMProvider):
|
||||
|
||||
# Execute tools and add results.
|
||||
for tool_call in message.tool_calls:
|
||||
# Parse arguments
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
# Surface error to LLM and skip tool execution
|
||||
current_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "Invalid JSON arguments provided to tool.",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
tool_use = ToolUse(
|
||||
id=tool_call.id,
|
||||
@@ -425,3 +447,185 @@ class LiteLLMProvider(LLMProvider):
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream a completion via litellm.acompletion(stream=True).
|
||||
|
||||
Yields StreamEvent objects as chunks arrive from the provider.
|
||||
Tool call arguments are accumulated across chunks and yielded as
|
||||
a single ToolCallEvent with fully parsed JSON when complete.
|
||||
|
||||
Empty responses (e.g. Gemini stealth rate-limits that return 200
|
||||
with no content) are retried with exponential backoff, mirroring
|
||||
the retry behaviour of ``_completion_with_rate_limit_retry``.
|
||||
"""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if system:
|
||||
full_messages.append({"role": "system", "content": system})
|
||||
full_messages.extend(messages)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": full_messages,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if tools:
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
|
||||
for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
|
||||
# Post-stream events (ToolCall, TextEnd, Finish) are buffered
|
||||
# because they depend on the full stream. TextDeltaEvents are
|
||||
# yielded immediately so callers see tokens in real time.
|
||||
tail_events: list[StreamEvent] = []
|
||||
accumulated_text = ""
|
||||
tool_calls_acc: dict[int, dict[str, str]] = {}
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
try:
|
||||
response = await litellm.acompletion(**kwargs) # type: ignore[union-attr]
|
||||
|
||||
async for chunk in response:
|
||||
choice = chunk.choices[0] if chunk.choices else None
|
||||
if not choice:
|
||||
continue
|
||||
|
||||
delta = choice.delta
|
||||
|
||||
# --- Text content — yield immediately for real-time streaming ---
|
||||
if delta and delta.content:
|
||||
accumulated_text += delta.content
|
||||
yield TextDeltaEvent(
|
||||
content=delta.content,
|
||||
snapshot=accumulated_text,
|
||||
)
|
||||
|
||||
# --- Tool calls (accumulate across chunks) ---
|
||||
if delta and delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
|
||||
if idx not in tool_calls_acc:
|
||||
tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
|
||||
if tc.id:
|
||||
tool_calls_acc[idx]["id"] = tc.id
|
||||
if tc.function:
|
||||
if tc.function.name:
|
||||
tool_calls_acc[idx]["name"] = tc.function.name
|
||||
if tc.function.arguments:
|
||||
tool_calls_acc[idx]["arguments"] += tc.function.arguments
|
||||
|
||||
# --- Finish ---
|
||||
if choice.finish_reason:
|
||||
for _idx, tc_data in sorted(tool_calls_acc.items()):
|
||||
try:
|
||||
parsed_args = json.loads(tc_data["arguments"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
parsed_args = {"_raw": tc_data.get("arguments", "")}
|
||||
tail_events.append(
|
||||
ToolCallEvent(
|
||||
tool_use_id=tc_data["id"],
|
||||
tool_name=tc_data["name"],
|
||||
tool_input=parsed_args,
|
||||
)
|
||||
)
|
||||
|
||||
if accumulated_text:
|
||||
tail_events.append(TextEndEvent(full_text=accumulated_text))
|
||||
|
||||
usage = getattr(chunk, "usage", None)
|
||||
if usage:
|
||||
input_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
output_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
|
||||
tail_events.append(
|
||||
FinishEvent(
|
||||
stop_reason=choice.finish_reason,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
model=self.model,
|
||||
)
|
||||
)
|
||||
|
||||
# Check whether the stream produced any real content.
|
||||
# (If text deltas were yielded above, has_content is True
|
||||
# and we skip the retry path — nothing was yielded in vain.)
|
||||
has_content = accumulated_text or tool_calls_acc
|
||||
if not has_content and attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
# If the conversation ends with an assistant message,
|
||||
# an empty stream is expected (nothing new to say).
|
||||
# Don't retry — just flush whatever we have.
|
||||
last_role = next(
|
||||
(m["role"] for m in reversed(full_messages) if m.get("role") != "system"),
|
||||
None,
|
||||
)
|
||||
if last_role == "assistant":
|
||||
logger.debug(
|
||||
"[stream] Empty response after assistant message — "
|
||||
"expected, not retrying."
|
||||
)
|
||||
for event in tail_events:
|
||||
yield event
|
||||
return
|
||||
wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
|
||||
token_count, token_method = _estimate_tokens(
|
||||
self.model,
|
||||
full_messages,
|
||||
)
|
||||
dump_path = _dump_failed_request(
|
||||
model=self.model,
|
||||
kwargs=kwargs,
|
||||
error_type="empty_stream",
|
||||
attempt=attempt,
|
||||
)
|
||||
logger.warning(
|
||||
f"[stream-retry] {self.model} returned empty stream — "
|
||||
f"~{token_count} tokens ({token_method}). "
|
||||
f"Request dumped to: {dump_path}. "
|
||||
f"Retrying in {wait}s "
|
||||
f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
|
||||
# Success (or final attempt) — flush remaining events.
|
||||
for event in tail_events:
|
||||
yield event
|
||||
return
|
||||
|
||||
except RateLimitError as e:
|
||||
if attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
|
||||
logger.warning(
|
||||
f"[stream-retry] {self.model} rate limited (429): {e!s}. "
|
||||
f"Retrying in {wait}s "
|
||||
f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
return
|
||||
|
||||
@@ -2,10 +2,16 @@
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
)
|
||||
|
||||
|
||||
class MockLLMProvider(LLMProvider):
|
||||
@@ -175,3 +181,28 @@ class MockLLMProvider(LLMProvider):
|
||||
output_tokens=0,
|
||||
stop_reason="mock_complete",
|
||||
)
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream a mock completion as word-level TextDeltaEvents.
|
||||
|
||||
Splits the mock response into words and yields each as a separate
|
||||
TextDeltaEvent with an accumulating snapshot, exercising the full
|
||||
streaming pipeline without any API calls.
|
||||
"""
|
||||
content = self._generate_mock_response(system=system, json_mode=False)
|
||||
words = content.split(" ")
|
||||
accumulated = ""
|
||||
|
||||
for i, word in enumerate(words):
|
||||
chunk = word if i == 0 else " " + word
|
||||
accumulated += chunk
|
||||
yield TextDeltaEvent(content=chunk, snapshot=accumulated)
|
||||
|
||||
yield TextEndEvent(full_text=accumulated)
|
||||
yield FinishEvent(stop_reason="mock_complete", model=self.model)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""LLM Provider abstraction for pluggable LLM backends."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -108,3 +108,45 @@ class LLMProvider(ABC):
|
||||
Final LLMResponse after tool use completes
|
||||
"""
|
||||
pass
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator["StreamEvent"]:
|
||||
"""
|
||||
Stream a completion as an async iterator of StreamEvents.
|
||||
|
||||
Default implementation wraps complete() with synthetic events.
|
||||
Subclasses SHOULD override for true streaming.
|
||||
|
||||
Tool orchestration is the CALLER's responsibility:
|
||||
- Caller detects ToolCallEvent, executes tool, adds result
|
||||
to messages, calls stream() again.
|
||||
"""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
)
|
||||
|
||||
response = self.complete(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
yield TextDeltaEvent(content=response.content, snapshot=response.content)
|
||||
yield TextEndEvent(full_text=response.content)
|
||||
yield FinishEvent(
|
||||
stop_reason=response.stop_reason,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
|
||||
# Deferred import target for type annotation
|
||||
from framework.llm.stream_events import StreamEvent as StreamEvent # noqa: E402, F401
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Stream event types for LLM streaming responses.
|
||||
|
||||
Defines a discriminated union of frozen dataclasses representing every event
|
||||
a streaming LLM call can produce. These types form the contract between the
|
||||
LLM provider layer, EventLoopNode, event bus, persistence, and monitoring.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextDeltaEvent:
|
||||
"""A chunk of text produced by the LLM."""
|
||||
|
||||
type: Literal["text_delta"] = "text_delta"
|
||||
content: str = "" # this chunk's text
|
||||
snapshot: str = "" # accumulated text so far
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextEndEvent:
|
||||
"""Signals that text generation is complete."""
|
||||
|
||||
type: Literal["text_end"] = "text_end"
|
||||
full_text: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolCallEvent:
|
||||
"""The LLM has requested a tool call."""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
tool_use_id: str = ""
|
||||
tool_name: str = ""
|
||||
tool_input: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolResultEvent:
|
||||
"""Result of executing a tool call."""
|
||||
|
||||
type: Literal["tool_result"] = "tool_result"
|
||||
tool_use_id: str = ""
|
||||
content: str = ""
|
||||
is_error: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReasoningStartEvent:
|
||||
"""The LLM has started a reasoning/thinking block."""
|
||||
|
||||
type: Literal["reasoning_start"] = "reasoning_start"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReasoningDeltaEvent:
|
||||
"""A chunk of reasoning/thinking content."""
|
||||
|
||||
type: Literal["reasoning_delta"] = "reasoning_delta"
|
||||
content: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FinishEvent:
|
||||
"""The LLM has finished generating."""
|
||||
|
||||
type: Literal["finish"] = "finish"
|
||||
stop_reason: str = ""
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
model: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StreamErrorEvent:
|
||||
"""An error occurred during streaming."""
|
||||
|
||||
type: Literal["error"] = "error"
|
||||
error: str = ""
|
||||
recoverable: bool = False
|
||||
|
||||
|
||||
# Discriminated union of all stream event types
|
||||
StreamEvent = (
|
||||
TextDeltaEvent
|
||||
| TextEndEvent
|
||||
| ToolCallEvent
|
||||
| ToolResultEvent
|
||||
| ReasoningStartEvent
|
||||
| ReasoningDeltaEvent
|
||||
| FinishEvent
|
||||
| StreamErrorEvent
|
||||
)
|
||||
@@ -22,6 +22,7 @@ from framework.graph.plan import Plan
|
||||
from framework.testing.prompts import (
|
||||
PYTEST_TEST_FILE_HEADER,
|
||||
)
|
||||
from framework.utils.io import atomic_write
|
||||
|
||||
# Initialize MCP server
|
||||
mcp = FastMCP("agent-builder")
|
||||
@@ -122,11 +123,11 @@ def _save_session(session: BuildSession):
|
||||
|
||||
# Save session file
|
||||
session_file = SESSIONS_DIR / f"{session.id}.json"
|
||||
with open(session_file, "w") as f:
|
||||
with atomic_write(session_file) as f:
|
||||
json.dump(session.to_dict(), f, indent=2, default=str)
|
||||
|
||||
# Update active session pointer
|
||||
with open(ACTIVE_SESSION_FILE, "w") as f:
|
||||
with atomic_write(ACTIVE_SESSION_FILE) as f:
|
||||
f.write(session.id)
|
||||
|
||||
|
||||
@@ -246,7 +247,7 @@ def load_session_by_id(session_id: Annotated[str, "ID of the session to load"])
|
||||
_session = _load_session(session_id)
|
||||
|
||||
# Update active session pointer
|
||||
with open(ACTIVE_SESSION_FILE, "w") as f:
|
||||
with atomic_write(ACTIVE_SESSION_FILE) as f:
|
||||
f.write(session_id)
|
||||
|
||||
return json.dumps(
|
||||
@@ -515,6 +516,36 @@ def _validate_tool_credentials(tools_list: list[str]) -> dict | None:
|
||||
return None
|
||||
|
||||
|
||||
def _validate_agent_path(agent_path: str) -> tuple[Path | None, str | None]:
|
||||
"""
|
||||
Validate and normalize agent_path.
|
||||
|
||||
Returns:
|
||||
(Path, None) if valid
|
||||
(None, error_json) if invalid
|
||||
"""
|
||||
if not agent_path:
|
||||
return None, json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "agent_path is required (e.g., 'exports/my_agent')",
|
||||
}
|
||||
)
|
||||
|
||||
path = Path(agent_path)
|
||||
|
||||
if not path.exists():
|
||||
return None, json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Agent path not found: {path}",
|
||||
"hint": "Run export_graph to create an agent in exports/ first",
|
||||
}
|
||||
)
|
||||
|
||||
return path, None
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def add_node(
|
||||
node_id: Annotated[str, "Unique identifier for the node"],
|
||||
@@ -1488,13 +1519,13 @@ def export_graph() -> str:
|
||||
|
||||
# Write agent.json
|
||||
agent_json_path = exports_dir / "agent.json"
|
||||
with open(agent_json_path, "w") as f:
|
||||
with atomic_write(agent_json_path) as f:
|
||||
json.dump(export_data, f, indent=2, default=str)
|
||||
|
||||
# Generate README.md
|
||||
readme_content = _generate_readme(session, export_data, all_tools)
|
||||
readme_path = exports_dir / "README.md"
|
||||
with open(readme_path, "w") as f:
|
||||
with atomic_write(readme_path) as f:
|
||||
f.write(readme_content)
|
||||
|
||||
# Write mcp_servers.json if MCP servers are configured
|
||||
@@ -1503,8 +1534,9 @@ def export_graph() -> str:
|
||||
if session.mcp_servers:
|
||||
mcp_config = {"servers": session.mcp_servers}
|
||||
mcp_servers_path = exports_dir / "mcp_servers.json"
|
||||
with open(mcp_servers_path, "w") as f:
|
||||
with atomic_write(mcp_servers_path) as f:
|
||||
json.dump(mcp_config, f, indent=2)
|
||||
|
||||
mcp_servers_size = mcp_servers_path.stat().st_size
|
||||
|
||||
# Get file sizes
|
||||
@@ -2595,10 +2627,11 @@ def generate_constraint_tests(
|
||||
if not agent_path and _session:
|
||||
agent_path = f"exports/{_session.name}"
|
||||
|
||||
if not agent_path:
|
||||
return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
|
||||
path, err = _validate_agent_path(agent_path)
|
||||
if err:
|
||||
return err
|
||||
|
||||
agent_module = _get_agent_module_from_path(agent_path)
|
||||
agent_module = _get_agent_module_from_path(path)
|
||||
|
||||
# Format constraints for display
|
||||
constraints_formatted = (
|
||||
@@ -2617,9 +2650,9 @@ def generate_constraint_tests(
|
||||
return json.dumps(
|
||||
{
|
||||
"goal_id": goal_id,
|
||||
"agent_path": agent_path,
|
||||
"agent_path": str(path),
|
||||
"agent_module": agent_module,
|
||||
"output_file": f"{agent_path}/tests/test_constraints.py",
|
||||
"output_file": f"{str(path)}/tests/test_constraints.py",
|
||||
"constraints": [c.model_dump() for c in goal.constraints] if goal.constraints else [],
|
||||
"constraints_formatted": constraints_formatted,
|
||||
"test_guidelines": {
|
||||
@@ -2675,10 +2708,11 @@ def generate_success_tests(
|
||||
if not agent_path and _session:
|
||||
agent_path = f"exports/{_session.name}"
|
||||
|
||||
if not agent_path:
|
||||
return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
|
||||
path, err = _validate_agent_path(agent_path)
|
||||
if err:
|
||||
return err
|
||||
|
||||
agent_module = _get_agent_module_from_path(agent_path)
|
||||
agent_module = _get_agent_module_from_path(path)
|
||||
|
||||
# Parse node/tool names for context
|
||||
nodes = [n.strip() for n in node_names.split(",") if n.strip()]
|
||||
@@ -2703,9 +2737,9 @@ def generate_success_tests(
|
||||
return json.dumps(
|
||||
{
|
||||
"goal_id": goal_id,
|
||||
"agent_path": agent_path,
|
||||
"agent_path": str(path),
|
||||
"agent_module": agent_module,
|
||||
"output_file": f"{agent_path}/tests/test_success_criteria.py",
|
||||
"output_file": f"{str(path)}/tests/test_success_criteria.py",
|
||||
"success_criteria": [c.model_dump() for c in goal.success_criteria]
|
||||
if goal.success_criteria
|
||||
else [],
|
||||
@@ -2764,7 +2798,11 @@ def run_tests(
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
tests_dir = Path(agent_path) / "tests"
|
||||
path, err = _validate_agent_path(agent_path)
|
||||
if err:
|
||||
return err
|
||||
|
||||
tests_dir = path / "tests"
|
||||
|
||||
if not tests_dir.exists():
|
||||
return json.dumps(
|
||||
@@ -2955,10 +2993,11 @@ def debug_test(
|
||||
if not agent_path and _session:
|
||||
agent_path = f"exports/{_session.name}"
|
||||
|
||||
if not agent_path:
|
||||
return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
|
||||
path, err = _validate_agent_path(agent_path)
|
||||
if err:
|
||||
return err
|
||||
|
||||
tests_dir = Path(agent_path) / "tests"
|
||||
tests_dir = path / "tests"
|
||||
|
||||
if not tests_dir.exists():
|
||||
return json.dumps(
|
||||
@@ -3099,10 +3138,11 @@ def list_tests(
|
||||
if not agent_path and _session:
|
||||
agent_path = f"exports/{_session.name}"
|
||||
|
||||
if not agent_path:
|
||||
return json.dumps({"error": "agent_path required (e.g., 'exports/my_agent')"})
|
||||
path, err = _validate_agent_path(agent_path)
|
||||
if err:
|
||||
return err
|
||||
|
||||
tests_dir = Path(agent_path) / "tests"
|
||||
tests_dir = path / "tests"
|
||||
|
||||
if not tests_dir.exists():
|
||||
return json.dumps(
|
||||
@@ -3377,7 +3417,7 @@ def store_credential(
|
||||
display_name: Annotated[str, "Human-readable name (e.g., 'HubSpot Access Token')"] = "",
|
||||
) -> str:
|
||||
"""
|
||||
Store a credential securely in the encrypted credential store at ~/.hive/credentials.
|
||||
Store a credential securely in the local encrypted store at ~/.hive/credentials.
|
||||
|
||||
Uses Fernet encryption (AES-128-CBC + HMAC). Requires HIVE_CREDENTIAL_KEY env var.
|
||||
"""
|
||||
@@ -3419,7 +3459,7 @@ def store_credential(
|
||||
@mcp.tool()
|
||||
def list_stored_credentials() -> str:
|
||||
"""
|
||||
List all credentials currently stored in the encrypted credential store.
|
||||
List all credentials currently stored in the local encrypted store.
|
||||
|
||||
Returns credential IDs and metadata (never returns secret values).
|
||||
"""
|
||||
@@ -3459,7 +3499,7 @@ def delete_stored_credential(
|
||||
credential_name: Annotated[str, "Logical credential name to delete (e.g., 'hubspot')"],
|
||||
) -> str:
|
||||
"""
|
||||
Delete a credential from the encrypted credential store.
|
||||
Delete a credential from the local encrypted store.
|
||||
"""
|
||||
try:
|
||||
store = _get_credential_store()
|
||||
|
||||
@@ -411,25 +411,8 @@ class AgentRunner:
|
||||
return self._tool_registry.register_mcp_server(server_config)
|
||||
|
||||
def _load_mcp_servers_from_config(self, config_path: Path) -> None:
|
||||
"""
|
||||
Load and register MCP servers from a configuration file.
|
||||
|
||||
Args:
|
||||
config_path: Path to mcp_servers.json file
|
||||
"""
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
servers = config.get("servers", [])
|
||||
for server_config in servers:
|
||||
try:
|
||||
self._tool_registry.register_mcp_server(server_config)
|
||||
except Exception as e:
|
||||
server_name = server_config.get("name", "unknown")
|
||||
logger.warning(f"Failed to register MCP server '{server_name}': {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load MCP servers config from {config_path}: {e}")
|
||||
"""Load and register MCP servers from a configuration file."""
|
||||
self._tool_registry.load_mcp_config(config_path)
|
||||
|
||||
def set_approval_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
|
||||
@@ -257,6 +257,34 @@ class ToolRegistry:
|
||||
"""
|
||||
self._session_context.update(context)
|
||||
|
||||
def load_mcp_config(self, config_path: Path) -> None:
|
||||
"""
|
||||
Load and register MCP servers from a config file.
|
||||
|
||||
Resolves relative ``cwd`` paths against the config file's parent
|
||||
directory so callers never need to handle path resolution themselves.
|
||||
|
||||
Args:
|
||||
config_path: Path to an ``mcp_servers.json`` file.
|
||||
"""
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load MCP config from {config_path}: {e}")
|
||||
return
|
||||
|
||||
base_dir = config_path.parent
|
||||
for server_config in config.get("servers", []):
|
||||
cwd = server_config.get("cwd")
|
||||
if cwd and not Path(cwd).is_absolute():
|
||||
server_config["cwd"] = str((base_dir / cwd).resolve())
|
||||
try:
|
||||
self.register_mcp_server(server_config)
|
||||
except Exception as e:
|
||||
name = server_config.get("name", "unknown")
|
||||
logger.warning(f"Failed to register MCP server '{name}': {e}")
|
||||
|
||||
def register_mcp_server(
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
@@ -309,11 +337,21 @@ class ToolRegistry:
|
||||
tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
|
||||
|
||||
# Create executor that calls the MCP server
|
||||
def make_mcp_executor(client_ref: MCPClient, tool_name: str, registry_ref):
|
||||
def make_mcp_executor(
|
||||
client_ref: MCPClient,
|
||||
tool_name: str,
|
||||
registry_ref,
|
||||
tool_params: set[str],
|
||||
):
|
||||
def executor(inputs: dict) -> Any:
|
||||
try:
|
||||
# Inject session context for tools that need it
|
||||
merged_inputs = {**registry_ref._session_context, **inputs}
|
||||
# Only inject session context params the tool accepts
|
||||
filtered_context = {
|
||||
k: v
|
||||
for k, v in registry_ref._session_context.items()
|
||||
if k in tool_params
|
||||
}
|
||||
merged_inputs = {**filtered_context, **inputs}
|
||||
result = client_ref.call_tool(tool_name, merged_inputs)
|
||||
# MCP tools return content array, extract the result
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
@@ -327,10 +365,11 @@ class ToolRegistry:
|
||||
|
||||
return executor
|
||||
|
||||
tool_params = set(mcp_tool.input_schema.get("properties", {}).keys())
|
||||
self.register(
|
||||
mcp_tool.name,
|
||||
tool,
|
||||
make_mcp_executor(client, mcp_tool.name, self),
|
||||
make_mcp_executor(client, mcp_tool.name, self, tool_params),
|
||||
)
|
||||
count += 1
|
||||
|
||||
|
||||
@@ -12,13 +12,13 @@ import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventType(str, Enum):
|
||||
class EventType(StrEnum):
|
||||
"""Types of events that can be published."""
|
||||
|
||||
# Execution lifecycle
|
||||
@@ -41,6 +41,28 @@ class EventType(str, Enum):
|
||||
STREAM_STARTED = "stream_started"
|
||||
STREAM_STOPPED = "stream_stopped"
|
||||
|
||||
# Node event-loop lifecycle
|
||||
NODE_LOOP_STARTED = "node_loop_started"
|
||||
NODE_LOOP_ITERATION = "node_loop_iteration"
|
||||
NODE_LOOP_COMPLETED = "node_loop_completed"
|
||||
|
||||
# LLM streaming observability
|
||||
LLM_TEXT_DELTA = "llm_text_delta"
|
||||
LLM_REASONING_DELTA = "llm_reasoning_delta"
|
||||
|
||||
# Tool lifecycle
|
||||
TOOL_CALL_STARTED = "tool_call_started"
|
||||
TOOL_CALL_COMPLETED = "tool_call_completed"
|
||||
|
||||
# Client I/O (client_facing=True nodes only)
|
||||
CLIENT_OUTPUT_DELTA = "client_output_delta"
|
||||
CLIENT_INPUT_REQUESTED = "client_input_requested"
|
||||
|
||||
# Internal node observability (client_facing=False nodes)
|
||||
NODE_INTERNAL_OUTPUT = "node_internal_output"
|
||||
NODE_INPUT_BLOCKED = "node_input_blocked"
|
||||
NODE_STALLED = "node_stalled"
|
||||
|
||||
# Custom events
|
||||
CUSTOM = "custom"
|
||||
|
||||
@@ -51,6 +73,7 @@ class AgentEvent:
|
||||
|
||||
type: EventType
|
||||
stream_id: str
|
||||
node_id: str | None = None # Which node emitted this event
|
||||
execution_id: str | None = None
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
@@ -61,6 +84,7 @@ class AgentEvent:
|
||||
return {
|
||||
"type": self.type.value,
|
||||
"stream_id": self.stream_id,
|
||||
"node_id": self.node_id,
|
||||
"execution_id": self.execution_id,
|
||||
"data": self.data,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
@@ -80,6 +104,7 @@ class Subscription:
|
||||
event_types: set[EventType]
|
||||
handler: EventHandler
|
||||
filter_stream: str | None = None # Only receive events from this stream
|
||||
filter_node: str | None = None # Only receive events from this node
|
||||
filter_execution: str | None = None # Only receive events from this execution
|
||||
|
||||
|
||||
@@ -138,6 +163,7 @@ class EventBus:
|
||||
event_types: list[EventType],
|
||||
handler: EventHandler,
|
||||
filter_stream: str | None = None,
|
||||
filter_node: str | None = None,
|
||||
filter_execution: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -147,6 +173,7 @@ class EventBus:
|
||||
event_types: Types of events to receive
|
||||
handler: Async function to call when event occurs
|
||||
filter_stream: Only receive events from this stream
|
||||
filter_node: Only receive events from this node
|
||||
filter_execution: Only receive events from this execution
|
||||
|
||||
Returns:
|
||||
@@ -160,6 +187,7 @@ class EventBus:
|
||||
event_types=set(event_types),
|
||||
handler=handler,
|
||||
filter_stream=filter_stream,
|
||||
filter_node=filter_node,
|
||||
filter_execution=filter_execution,
|
||||
)
|
||||
|
||||
@@ -218,6 +246,10 @@ class EventBus:
|
||||
if subscription.filter_stream and subscription.filter_stream != event.stream_id:
|
||||
return False
|
||||
|
||||
# Check node filter
|
||||
if subscription.filter_node and subscription.filter_node != event.node_id:
|
||||
return False
|
||||
|
||||
# Check execution filter
|
||||
if subscription.filter_execution and subscription.filter_execution != event.execution_id:
|
||||
return False
|
||||
@@ -359,6 +391,248 @@ class EventBus:
|
||||
)
|
||||
)
|
||||
|
||||
# === NODE EVENT-LOOP PUBLISHERS ===
|
||||
|
||||
async def emit_node_loop_started(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
execution_id: str | None = None,
|
||||
max_iterations: int | None = None,
|
||||
) -> None:
|
||||
"""Emit node loop started event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_STARTED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"max_iterations": max_iterations},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_loop_iteration(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
iteration: int,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node loop iteration event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_ITERATION,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"iteration": iteration},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_loop_completed(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
iterations: int,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node loop completed event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_COMPLETED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"iterations": iterations},
|
||||
)
|
||||
)
|
||||
|
||||
# === LLM STREAMING PUBLISHERS ===
|
||||
|
||||
async def emit_llm_text_delta(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
snapshot: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM text delta event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content, "snapshot": snapshot},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_llm_reasoning_delta(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM reasoning delta event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_REASONING_DELTA,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content},
|
||||
)
|
||||
)
|
||||
|
||||
# === TOOL LIFECYCLE PUBLISHERS ===
|
||||
|
||||
async def emit_tool_call_started(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any] | None = None,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit tool call started event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TOOL_CALL_STARTED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"tool_use_id": tool_use_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_input": tool_input or {},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_tool_call_completed(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
result: str = "",
|
||||
is_error: bool = False,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit tool call completed event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TOOL_CALL_COMPLETED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"tool_use_id": tool_use_id,
|
||||
"tool_name": tool_name,
|
||||
"result": result,
|
||||
"is_error": is_error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# === CLIENT I/O PUBLISHERS ===
|
||||
|
||||
async def emit_client_output_delta(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
snapshot: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit client output delta event (client_facing=True nodes)."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CLIENT_OUTPUT_DELTA,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content, "snapshot": snapshot},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_client_input_requested(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
prompt: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit client input requested event (client_facing=True nodes)."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CLIENT_INPUT_REQUESTED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"prompt": prompt},
|
||||
)
|
||||
)
|
||||
|
||||
# === INTERNAL NODE PUBLISHERS ===
|
||||
|
||||
async def emit_node_internal_output(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node internal output event (client_facing=False nodes)."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_INTERNAL_OUTPUT,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_stalled(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
reason: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node stalled event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_STALLED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"reason": reason},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_node_input_blocked(
|
||||
self,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
prompt: str = "",
|
||||
execution_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit node input blocked event."""
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_INPUT_BLOCKED,
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"prompt": prompt},
|
||||
)
|
||||
)
|
||||
|
||||
# === QUERY OPERATIONS ===
|
||||
|
||||
def get_history(
|
||||
@@ -410,6 +684,7 @@ class EventBus:
|
||||
self,
|
||||
event_type: EventType,
|
||||
stream_id: str | None = None,
|
||||
node_id: str | None = None,
|
||||
execution_id: str | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> AgentEvent | None:
|
||||
@@ -419,6 +694,7 @@ class EventBus:
|
||||
Args:
|
||||
event_type: Type of event to wait for
|
||||
stream_id: Filter by stream
|
||||
node_id: Filter by node
|
||||
execution_id: Filter by execution
|
||||
timeout: Maximum time to wait (seconds)
|
||||
|
||||
@@ -438,6 +714,7 @@ class EventBus:
|
||||
event_types=[event_type],
|
||||
handler=handler,
|
||||
filter_stream=stream_id,
|
||||
filter_node=node_id,
|
||||
filter_execution=execution_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,13 +11,13 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IsolationLevel(str, Enum):
|
||||
class IsolationLevel(StrEnum):
|
||||
"""State isolation level for concurrent executions."""
|
||||
|
||||
ISOLATED = "isolated" # Private state per execution
|
||||
@@ -25,7 +25,7 @@ class IsolationLevel(str, Enum):
|
||||
SYNCHRONIZED = "synchronized" # Shared with write locks (strong consistency)
|
||||
|
||||
|
||||
class StateScope(str, Enum):
|
||||
class StateScope(StrEnum):
|
||||
"""Scope for state operations."""
|
||||
|
||||
EXECUTION = "execution" # Local to a single execution
|
||||
|
||||
@@ -10,13 +10,13 @@ This is MORE important than actions because:
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
|
||||
class DecisionType(str, Enum):
|
||||
class DecisionType(StrEnum):
|
||||
"""Types of decisions an agent can make."""
|
||||
|
||||
TOOL_SELECTION = "tool_selection" # Which tool to use
|
||||
|
||||
@@ -6,7 +6,7 @@ summaries and metrics that Builder needs to understand what happened.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
@@ -14,7 +14,7 @@ from pydantic import BaseModel, Field, computed_field
|
||||
from framework.schemas.decision import Decision, Outcome
|
||||
|
||||
|
||||
class RunStatus(str, Enum):
|
||||
class RunStatus(StrEnum):
|
||||
"""Status of a run."""
|
||||
|
||||
RUNNING = "running"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Storage backends for runtime data."""
|
||||
|
||||
from framework.storage.backend import FileStorage
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
__all__ = ["FileStorage"]
|
||||
__all__ = ["FileStorage", "FileConversationStore"]
|
||||
|
||||
@@ -9,6 +9,7 @@ import json
|
||||
from pathlib import Path
|
||||
|
||||
from framework.schemas.run import Run, RunStatus, RunSummary
|
||||
from framework.utils.io import atomic_write
|
||||
|
||||
|
||||
class FileStorage:
|
||||
@@ -86,13 +87,13 @@ class FileStorage:
|
||||
"""Save a run to storage."""
|
||||
# Save full run using Pydantic's model_dump_json
|
||||
run_path = self.base_path / "runs" / f"{run.id}.json"
|
||||
with open(run_path, "w", encoding="utf-8") as f:
|
||||
with atomic_write(run_path) as f:
|
||||
f.write(run.model_dump_json(indent=2))
|
||||
|
||||
# Save summary
|
||||
summary = RunSummary.from_run(run)
|
||||
summary_path = self.base_path / "summaries" / f"{run.id}.json"
|
||||
with open(summary_path, "w", encoding="utf-8") as f:
|
||||
with atomic_write(summary_path) as f:
|
||||
f.write(summary.model_dump_json(indent=2))
|
||||
|
||||
# Update indexes
|
||||
@@ -188,8 +189,8 @@ class FileStorage:
|
||||
values = self._get_index(index_type, key) # Already validated in _get_index
|
||||
if value not in values:
|
||||
values.append(value)
|
||||
with open(index_path, "w", encoding="utf-8") as f:
|
||||
json.dump(values, f)
|
||||
with atomic_write(index_path) as f:
|
||||
json.dump(values, f, indent=2)
|
||||
|
||||
def _remove_from_index(self, index_type: str, key: str, value: str) -> None:
|
||||
"""Remove a value from an index."""
|
||||
@@ -198,8 +199,8 @@ class FileStorage:
|
||||
values = self._get_index(index_type, key) # Already validated in _get_index
|
||||
if value in values:
|
||||
values.remove(value)
|
||||
with open(index_path, "w", encoding="utf-8") as f:
|
||||
json.dump(values, f)
|
||||
with atomic_write(index_path) as f:
|
||||
json.dump(values, f, indent=2)
|
||||
|
||||
# === UTILITY ===
|
||||
|
||||
|
||||
@@ -167,14 +167,18 @@ class ConcurrentStorage:
|
||||
run: Run to save
|
||||
immediate: If True, save immediately (bypasses batching)
|
||||
"""
|
||||
# Invalidate summary cache since the run data is changing
|
||||
# This ensures load_summary() fetches fresh data after the save
|
||||
self._cache.pop(f"summary:{run.id}", None)
|
||||
|
||||
if immediate or not self._running:
|
||||
await self._save_run_locked(run)
|
||||
# Update cache only after successful immediate write
|
||||
self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())
|
||||
else:
|
||||
# For batched writes, cache will be updated in _flush_batch after successful write
|
||||
await self._write_queue.put(("run", run))
|
||||
|
||||
# Update cache
|
||||
self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())
|
||||
|
||||
async def _save_run_locked(self, run: Run) -> None:
|
||||
"""Save a run with file locking, including index locks."""
|
||||
lock_key = f"run:{run.id}"
|
||||
@@ -363,8 +367,12 @@ class ConcurrentStorage:
|
||||
try:
|
||||
if item_type == "run":
|
||||
await self._save_run_locked(item)
|
||||
# Update cache only after successful batched write
|
||||
# This fixes the race condition where cache was updated before write completed
|
||||
self._cache[f"run:{item.id}"] = CacheEntry(item, time.time())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save {item_type}: {e}")
|
||||
# Cache is NOT updated on failure - prevents stale/inconsistent cache state
|
||||
|
||||
async def _flush_pending(self) -> None:
|
||||
"""Flush all pending writes."""
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
"""File-per-part ConversationStore implementation.
|
||||
|
||||
Each conversation part is stored as a separate JSON file under a
|
||||
``parts/`` subdirectory. Meta and cursor are stored as ``meta.json``
|
||||
and ``cursor.json`` in the base directory.
|
||||
|
||||
Directory layout::
|
||||
|
||||
{base_path}/
|
||||
meta.json
|
||||
cursor.json
|
||||
parts/
|
||||
0000000000.json
|
||||
0000000001.json
|
||||
...
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class FileConversationStore:
|
||||
"""File-per-part ConversationStore.
|
||||
|
||||
Uses one JSON file per message part, with ``pathlib.Path`` for
|
||||
cross-platform path handling and ``asyncio.to_thread`` for
|
||||
non-blocking I/O.
|
||||
"""
|
||||
|
||||
def __init__(self, base_path: str | Path) -> None:
|
||||
self._base = Path(base_path)
|
||||
self._parts_dir = self._base / "parts"
|
||||
|
||||
# --- sync helpers --------------------------------------------------------
|
||||
|
||||
def _write_json(self, path: Path, data: dict) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
def _read_json(self, path: Path) -> dict | None:
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return None
|
||||
|
||||
# --- async wrapper -------------------------------------------------------
|
||||
|
||||
async def _run(self, fn, *args):
|
||||
return await asyncio.to_thread(fn, *args)
|
||||
|
||||
# --- ConversationStore interface -----------------------------------------
|
||||
|
||||
async def write_part(self, seq: int, data: dict[str, Any]) -> None:
|
||||
path = self._parts_dir / f"{seq:010d}.json"
|
||||
await self._run(self._write_json, path, data)
|
||||
|
||||
async def read_parts(self) -> list[dict[str, Any]]:
|
||||
def _read_all() -> list[dict[str, Any]]:
|
||||
if not self._parts_dir.exists():
|
||||
return []
|
||||
files = sorted(self._parts_dir.glob("*.json"))
|
||||
parts = []
|
||||
for f in files:
|
||||
data = self._read_json(f)
|
||||
if data is not None:
|
||||
parts.append(data)
|
||||
return parts
|
||||
|
||||
return await self._run(_read_all)
|
||||
|
||||
async def write_meta(self, data: dict[str, Any]) -> None:
|
||||
await self._run(self._write_json, self._base / "meta.json", data)
|
||||
|
||||
async def read_meta(self) -> dict[str, Any] | None:
|
||||
return await self._run(self._read_json, self._base / "meta.json")
|
||||
|
||||
async def write_cursor(self, data: dict[str, Any]) -> None:
|
||||
await self._run(self._write_json, self._base / "cursor.json", data)
|
||||
|
||||
async def read_cursor(self) -> dict[str, Any] | None:
|
||||
return await self._run(self._read_json, self._base / "cursor.json")
|
||||
|
||||
async def delete_parts_before(self, seq: int) -> None:
|
||||
def _delete() -> None:
|
||||
if not self._parts_dir.exists():
|
||||
return
|
||||
for f in self._parts_dir.glob("*.json"):
|
||||
file_seq = int(f.stem)
|
||||
if file_seq < seq:
|
||||
f.unlink()
|
||||
|
||||
await self._run(_delete)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""No-op — no persistent handles for file-per-part storage."""
|
||||
pass
|
||||
|
||||
async def destroy(self) -> None:
|
||||
"""Delete the entire base directory and all persisted data."""
|
||||
|
||||
def _destroy() -> None:
|
||||
if self._base.exists():
|
||||
shutil.rmtree(self._base)
|
||||
|
||||
await self._run(_destroy)
|
||||
@@ -6,13 +6,13 @@ programmatic/MCP-based approval.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ApprovalAction(str, Enum):
|
||||
class ApprovalAction(StrEnum):
|
||||
"""Actions a user can take on a generated test."""
|
||||
|
||||
APPROVE = "approve" # Accept as-is
|
||||
|
||||
@@ -24,7 +24,7 @@ def _get_api_key():
|
||||
# 1. Try CredentialStoreAdapter for Anthropic
|
||||
try:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
creds = CredentialStoreAdapter.with_env_storage()
|
||||
creds = CredentialStoreAdapter.default()
|
||||
if creds.is_available("anthropic"):
|
||||
return creds.get("anthropic")
|
||||
except (ImportError, KeyError):
|
||||
@@ -57,7 +57,7 @@ def _get_api_key():
|
||||
"""Get API key from CredentialStoreAdapter or environment."""
|
||||
try:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
creds = CredentialStoreAdapter.with_env_storage()
|
||||
creds = CredentialStoreAdapter.default()
|
||||
if creds.is_available("anthropic"):
|
||||
return creds.get("anthropic")
|
||||
except (ImportError, KeyError):
|
||||
|
||||
@@ -6,13 +6,13 @@ but require mandatory user approval before being stored.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ApprovalStatus(str, Enum):
|
||||
class ApprovalStatus(StrEnum):
|
||||
"""Status of user approval for a generated test."""
|
||||
|
||||
PENDING = "pending" # Awaiting user review
|
||||
@@ -21,7 +21,7 @@ class ApprovalStatus(str, Enum):
|
||||
REJECTED = "rejected" # User declined (with reason)
|
||||
|
||||
|
||||
class TestType(str, Enum):
|
||||
class TestType(StrEnum):
|
||||
"""Type of test based on what it validates."""
|
||||
|
||||
__test__ = False # Not a pytest test class
|
||||
|
||||
@@ -6,13 +6,13 @@ categorization for guiding iteration strategy.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ErrorCategory(str, Enum):
|
||||
class ErrorCategory(StrEnum):
|
||||
"""
|
||||
Category of test failure for guiding iteration.
|
||||
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@contextmanager
|
||||
def atomic_write(path: Path, mode: str = "w", encoding: str = "utf-8"):
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
try:
|
||||
with open(tmp_path, mode, encoding=encoding) as f:
|
||||
yield f
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
tmp_path.replace(path)
|
||||
except BaseException:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
raise
|
||||
@@ -14,6 +14,7 @@ dependencies = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.23",
|
||||
"pytest-xdist>=3.0",
|
||||
"tools",
|
||||
]
|
||||
|
||||
# [project.optional-dependencies]
|
||||
@@ -21,6 +22,9 @@ dependencies = [
|
||||
[project.scripts]
|
||||
hive = "framework.cli:main"
|
||||
|
||||
[tool.uv.sources]
|
||||
tools = { workspace = true }
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
@@ -43,6 +47,7 @@ lint.select = [
|
||||
"W", # pycodestyle warnings
|
||||
]
|
||||
|
||||
lint.per-file-ignores."demos/*" = ["E501"]
|
||||
lint.isort.combine-as-imports = true
|
||||
lint.isort.known-first-party = ["framework"]
|
||||
lint.isort.section-order = [
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
# Development dependencies
|
||||
-r requirements.txt
|
||||
|
||||
# Testing
|
||||
pytest>=8.0
|
||||
pytest-asyncio>=0.23
|
||||
|
||||
# Linting & type checking
|
||||
ruff>=0.1.0
|
||||
mypy>=1.0
|
||||
@@ -1,14 +0,0 @@
|
||||
# Core dependencies
|
||||
pydantic>=2.0
|
||||
anthropic>=0.40.0
|
||||
httpx>=0.27.0
|
||||
litellm>=1.81.0
|
||||
|
||||
# MCP server dependencies
|
||||
mcp
|
||||
fastmcp
|
||||
|
||||
# Testing (required for test framework)
|
||||
pytest>=8.0
|
||||
pytest-asyncio>=0.23
|
||||
pytest-xdist>=3.0
|
||||
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
Tests for client-facing fan-out and event_loop output_key overlap validation.
|
||||
|
||||
Validates two rules added to GraphSpec.validate():
|
||||
1. Fan-out must not have multiple client_facing=True targets.
|
||||
2. Parallel event_loop nodes must have disjoint output_keys.
|
||||
"""
|
||||
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.node import NodeSpec
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rule 1: client_facing fan-out
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClientFacingFanOut:
|
||||
"""Fan-out to multiple client_facing=True targets must be rejected."""
|
||||
|
||||
def test_fan_out_two_client_facing_fails(self):
|
||||
"""Two client-facing targets on the same fan-out -> error."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(id="a", name="a", description="Node a", client_facing=True),
|
||||
NodeSpec(id="b", name="b", description="Node b", client_facing=True),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
cf_errors = [e for e in errors if "multiple client-facing" in e]
|
||||
assert len(cf_errors) == 1
|
||||
assert "'src'" in cf_errors[0]
|
||||
|
||||
def test_fan_out_one_client_facing_passes(self):
|
||||
"""Only one client-facing target -> no error."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(id="a", name="a", description="Node a", client_facing=True),
|
||||
NodeSpec(id="b", name="b", description="Node b", client_facing=False),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
cf_errors = [e for e in errors if "multiple client-facing" in e]
|
||||
assert len(cf_errors) == 0
|
||||
|
||||
def test_fan_out_zero_client_facing_passes(self):
|
||||
"""No client-facing targets at all -> no error."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(id="a", name="a", description="Node a"),
|
||||
NodeSpec(id="b", name="b", description="Node b"),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
cf_errors = [e for e in errors if "multiple client-facing" in e]
|
||||
assert len(cf_errors) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rule 2: event_loop output_key overlap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEventLoopOutputKeyOverlap:
|
||||
"""Parallel event_loop nodes with overlapping output_keys must be rejected."""
|
||||
|
||||
def test_overlapping_output_keys_event_loop_fails(self):
|
||||
"""Two event_loop nodes sharing an output_key -> error."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(
|
||||
id="a",
|
||||
name="a",
|
||||
description="Node a",
|
||||
node_type="event_loop",
|
||||
output_keys=["status", "shared"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="b",
|
||||
name="b",
|
||||
description="Node b",
|
||||
node_type="event_loop",
|
||||
output_keys=["result", "shared"],
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
key_errors = [e for e in errors if "output_key" in e]
|
||||
assert len(key_errors) == 1
|
||||
assert "'shared'" in key_errors[0]
|
||||
|
||||
def test_disjoint_output_keys_event_loop_passes(self):
|
||||
"""Two event_loop nodes with disjoint output_keys -> no error."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(
|
||||
id="a",
|
||||
name="a",
|
||||
description="Node a",
|
||||
node_type="event_loop",
|
||||
output_keys=["status"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="b",
|
||||
name="b",
|
||||
description="Node b",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
key_errors = [e for e in errors if "output_key" in e]
|
||||
assert len(key_errors) == 0
|
||||
|
||||
def test_overlapping_keys_non_event_loop_no_error(self):
|
||||
"""Non-event_loop nodes with overlapping keys -> no error (last-wins OK)."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="src",
|
||||
nodes=[
|
||||
NodeSpec(id="src", name="src", description="Source node"),
|
||||
NodeSpec(
|
||||
id="a",
|
||||
name="a",
|
||||
description="Node a",
|
||||
node_type="llm_generate",
|
||||
output_keys=["shared"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="b",
|
||||
name="b",
|
||||
description="Node b",
|
||||
node_type="llm_generate",
|
||||
output_keys=["shared"],
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="src->a", source="src", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="src->b", source="src", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
key_errors = [e for e in errors if "output_key" in e]
|
||||
assert len(key_errors) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Baseline: no fan-out -> no errors from these rules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoFanOutUnaffected:
|
||||
"""Linear graphs should not trigger either validation rule."""
|
||||
|
||||
def test_no_fan_out_unaffected(self):
|
||||
"""Linear chain with client_facing and event_loop nodes -> no errors."""
|
||||
graph = GraphSpec(
|
||||
id="g1",
|
||||
goal_id="goal1",
|
||||
entry_node="a",
|
||||
terminal_nodes=["c"],
|
||||
nodes=[
|
||||
NodeSpec(id="a", name="a", description="Node a", client_facing=True),
|
||||
NodeSpec(
|
||||
id="b",
|
||||
name="b",
|
||||
description="Node b",
|
||||
node_type="event_loop",
|
||||
output_keys=["x"],
|
||||
),
|
||||
NodeSpec(
|
||||
id="c",
|
||||
name="c",
|
||||
description="Node c",
|
||||
client_facing=True,
|
||||
node_type="event_loop",
|
||||
output_keys=["x"],
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
EdgeSpec(id="a->b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="b->c", source="b", target="c", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
)
|
||||
|
||||
errors = graph.validate()
|
||||
cf_errors = [e for e in errors if "multiple client-facing" in e]
|
||||
key_errors = [e for e in errors if "output_key" in e]
|
||||
assert len(cf_errors) == 0
|
||||
assert len(key_errors) == 0
|
||||
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Tests for ClientIO gateway (WP-9).
|
||||
|
||||
Covers:
|
||||
- ActiveNodeClientIO: emit_output → output_stream round-trip, request_input, timeout
|
||||
- InertNodeClientIO: emit_output publishes NODE_INTERNAL_OUTPUT, request_input returns redirect
|
||||
- ClientIOGateway: factory creates correct variant
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.client_io import (
|
||||
ActiveNodeClientIO,
|
||||
ClientIOGateway,
|
||||
InertNodeClientIO,
|
||||
NodeClientIO,
|
||||
)
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
_AGENT_EVENT_FIELDS = {"stream_id", "node_id", "execution_id", "correlation_id"}
|
||||
|
||||
|
||||
class MockEventBus:
|
||||
"""Lightweight stand-in for EventBus that records published events."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.events: list[AgentEvent] = []
|
||||
|
||||
async def _record(self, event_type: EventType, **kwargs) -> None:
|
||||
agent_kwargs = {k: v for k, v in kwargs.items() if k in _AGENT_EVENT_FIELDS}
|
||||
data = {k: v for k, v in kwargs.items() if k not in _AGENT_EVENT_FIELDS}
|
||||
self.events.append(AgentEvent(type=event_type, **agent_kwargs, data=data))
|
||||
|
||||
async def emit_client_output_delta(self, **kwargs) -> None:
|
||||
await self._record(EventType.CLIENT_OUTPUT_DELTA, **kwargs)
|
||||
|
||||
async def emit_client_input_requested(self, **kwargs) -> None:
|
||||
await self._record(EventType.CLIENT_INPUT_REQUESTED, **kwargs)
|
||||
|
||||
async def emit_node_internal_output(self, **kwargs) -> None:
|
||||
await self._record(EventType.NODE_INTERNAL_OUTPUT, **kwargs)
|
||||
|
||||
async def emit_node_input_blocked(self, **kwargs) -> None:
|
||||
await self._record(EventType.NODE_INPUT_BLOCKED, **kwargs)
|
||||
|
||||
|
||||
# --- ActiveNodeClientIO tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_emit_and_consume():
|
||||
"""emit_output → output_stream round-trip works correctly."""
|
||||
bus = MockEventBus()
|
||||
io = ActiveNodeClientIO(node_id="n1", event_bus=bus)
|
||||
|
||||
await io.emit_output("Hello ")
|
||||
await io.emit_output("World", is_final=True)
|
||||
|
||||
chunks = []
|
||||
async for chunk in io.output_stream():
|
||||
chunks.append(chunk)
|
||||
|
||||
assert chunks == ["Hello ", "World"]
|
||||
assert len(bus.events) == 2
|
||||
assert all(e.type == EventType.CLIENT_OUTPUT_DELTA for e in bus.events)
|
||||
# Verify snapshot accumulates
|
||||
assert bus.events[0].data["snapshot"] == "Hello "
|
||||
assert bus.events[1].data["snapshot"] == "Hello World"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_request_input():
|
||||
"""request_input blocks until provide_input is called."""
|
||||
bus = MockEventBus()
|
||||
io = ActiveNodeClientIO(node_id="n1", event_bus=bus)
|
||||
|
||||
async def fulfill_later():
|
||||
await asyncio.sleep(0.01)
|
||||
await io.provide_input("user says hi")
|
||||
|
||||
task = asyncio.create_task(fulfill_later())
|
||||
result = await io.request_input(prompt="What?")
|
||||
await task
|
||||
|
||||
assert result == "user says hi"
|
||||
assert len(bus.events) == 1
|
||||
assert bus.events[0].type == EventType.CLIENT_INPUT_REQUESTED
|
||||
assert bus.events[0].data["prompt"] == "What?"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_request_input_timeout():
|
||||
"""request_input raises TimeoutError when timeout expires."""
|
||||
io = ActiveNodeClientIO(node_id="n1")
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
await io.request_input(prompt="waiting", timeout=0.01)
|
||||
|
||||
|
||||
# --- InertNodeClientIO tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inert_emit_publishes_internal():
|
||||
"""InertNodeClientIO.emit_output publishes NODE_INTERNAL_OUTPUT."""
|
||||
bus = MockEventBus()
|
||||
io = InertNodeClientIO(node_id="n2", event_bus=bus)
|
||||
|
||||
await io.emit_output("internal log")
|
||||
|
||||
assert len(bus.events) == 1
|
||||
assert bus.events[0].type == EventType.NODE_INTERNAL_OUTPUT
|
||||
assert bus.events[0].data["content"] == "internal log"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inert_request_input_returns_redirect():
|
||||
"""request_input returns a redirect string and publishes NODE_INPUT_BLOCKED."""
|
||||
bus = MockEventBus()
|
||||
io = InertNodeClientIO(node_id="n2", event_bus=bus)
|
||||
|
||||
result = await io.request_input(prompt="need data")
|
||||
|
||||
assert "internal processing node" in result
|
||||
assert len(bus.events) == 1
|
||||
assert bus.events[0].type == EventType.NODE_INPUT_BLOCKED
|
||||
assert bus.events[0].data["prompt"] == "need data"
|
||||
|
||||
|
||||
# --- ClientIOGateway tests ---
|
||||
|
||||
|
||||
def test_gateway_creates_active_for_client_facing():
|
||||
"""ClientIOGateway.create_io returns ActiveNodeClientIO when client_facing=True."""
|
||||
gateway = ClientIOGateway()
|
||||
io = gateway.create_io(node_id="n1", client_facing=True)
|
||||
|
||||
assert isinstance(io, ActiveNodeClientIO)
|
||||
assert isinstance(io, NodeClientIO)
|
||||
|
||||
|
||||
def test_gateway_creates_inert_for_internal():
|
||||
"""ClientIOGateway.create_io returns InertNodeClientIO when client_facing=False."""
|
||||
gateway = ClientIOGateway()
|
||||
io = gateway.create_io(node_id="n2", client_facing=False)
|
||||
|
||||
assert isinstance(io, InertNodeClientIO)
|
||||
assert isinstance(io, NodeClientIO)
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Tests for ConcurrentStorage race condition and cache invalidation fixes."""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.schemas.run import Run, RunMetrics, RunStatus
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
|
||||
|
||||
def create_test_run(
|
||||
run_id: str, goal_id: str = "test-goal", status: RunStatus = RunStatus.RUNNING
|
||||
) -> Run:
|
||||
"""Create a minimal test Run object."""
|
||||
return Run(
|
||||
id=run_id,
|
||||
goal_id=goal_id,
|
||||
status=status,
|
||||
narrative="Test run",
|
||||
metrics=RunMetrics(
|
||||
nodes_executed=[],
|
||||
),
|
||||
decisions=[],
|
||||
problems=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_invalidation_on_save(tmp_path: Path):
|
||||
"""Test that summary cache is invalidated when a run is saved.
|
||||
|
||||
This tests the fix for the cache invalidation bug where load_summary()
|
||||
would return stale data after a run was updated.
|
||||
"""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
run_id = "test-run-1"
|
||||
|
||||
# Create and save initial run
|
||||
run = create_test_run(run_id, status=RunStatus.RUNNING)
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
# Load summary to populate the cache
|
||||
summary = await storage.load_summary(run_id)
|
||||
assert summary is not None
|
||||
assert summary.status == RunStatus.RUNNING
|
||||
|
||||
# Update run with new status
|
||||
run.status = RunStatus.COMPLETED
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
# Load summary again - should get fresh data, not cached stale data
|
||||
summary = await storage.load_summary(run_id)
|
||||
assert summary is not None
|
||||
assert summary.status == RunStatus.COMPLETED, (
|
||||
"Summary cache should be invalidated on save - got stale data"
|
||||
)
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batched_write_cache_consistency(tmp_path: Path):
|
||||
"""Test that cache is only updated after successful batched write.
|
||||
|
||||
This tests the fix for the race condition where cache was updated
|
||||
before the batched write completed.
|
||||
"""
|
||||
storage = ConcurrentStorage(tmp_path, batch_interval=0.05)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
run_id = "test-run-2"
|
||||
|
||||
# Save via batching (immediate=False)
|
||||
run = create_test_run(run_id, status=RunStatus.RUNNING)
|
||||
await storage.save_run(run, immediate=False)
|
||||
|
||||
# Before batch flush, cache should NOT contain the run
|
||||
# (This is the fix - previously cache was updated immediately)
|
||||
cache_key = f"run:{run_id}"
|
||||
assert cache_key not in storage._cache, (
|
||||
"Cache should not be updated before batch is flushed"
|
||||
)
|
||||
|
||||
# Wait for batch to flush
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# After batch flush, cache should contain the run
|
||||
assert cache_key in storage._cache, "Cache should be updated after batch flush"
|
||||
|
||||
# Verify data on disk matches cache
|
||||
loaded_run = await storage.load_run(run_id, use_cache=False)
|
||||
assert loaded_run is not None
|
||||
assert loaded_run.id == run_id
|
||||
assert loaded_run.status == RunStatus.RUNNING
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_immediate_write_updates_cache(tmp_path: Path):
|
||||
"""Test that immediate writes still update cache correctly."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
run_id = "test-run-3"
|
||||
|
||||
# Save with immediate=True
|
||||
run = create_test_run(run_id, status=RunStatus.COMPLETED)
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
# Cache should be updated immediately for immediate writes
|
||||
cache_key = f"run:{run_id}"
|
||||
assert cache_key in storage._cache, "Cache should be updated after immediate write"
|
||||
|
||||
# Verify cached value is correct
|
||||
cached_run = storage._cache[cache_key].value
|
||||
assert cached_run.id == run_id
|
||||
assert cached_run.status == RunStatus.COMPLETED
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_cache_invalidated_on_multiple_saves(tmp_path: Path):
|
||||
"""Test that summary cache is invalidated on each save, not just the first."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
run_id = "test-run-4"
|
||||
|
||||
# First save
|
||||
run = create_test_run(run_id, status=RunStatus.RUNNING)
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
# Load summary to cache it
|
||||
summary1 = await storage.load_summary(run_id)
|
||||
assert summary1.status == RunStatus.RUNNING
|
||||
|
||||
# Second save with new status
|
||||
run.status = RunStatus.RUNNING
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
# Load summary - should be fresh
|
||||
summary2 = await storage.load_summary(run_id)
|
||||
assert summary2.status == RunStatus.RUNNING
|
||||
|
||||
# Third save with final status
|
||||
run.status = RunStatus.COMPLETED
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
# Load summary - should be fresh again
|
||||
summary3 = await storage.load_summary(run_id)
|
||||
assert summary3.status == RunStatus.COMPLETED
|
||||
finally:
|
||||
await storage.stop()
|
||||
@@ -0,0 +1,326 @@
|
||||
"""Tests for ContextHandoff and HandoffContext."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.context_handoff import ContextHandoff, HandoffContext
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
from framework.llm.provider import LLMProvider, LLMResponse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SpyLLMProvider(MockLLMProvider):
|
||||
"""MockLLMProvider that records whether complete() was called."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.complete_called = False
|
||||
self.complete_call_args: dict[str, Any] | None = None
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
|
||||
self.complete_called = True
|
||||
self.complete_call_args = {"messages": messages, **kwargs}
|
||||
return super().complete(messages, **kwargs)
|
||||
|
||||
|
||||
class FailingLLMProvider(LLMProvider):
|
||||
"""LLM provider that always raises."""
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
|
||||
raise RuntimeError("LLM unavailable")
|
||||
|
||||
def complete_with_tools(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list,
|
||||
tool_executor: Any,
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
raise RuntimeError("LLM unavailable")
|
||||
|
||||
|
||||
async def _build_conversation(*pairs: tuple[str, str]) -> NodeConversation:
|
||||
"""Build a NodeConversation from (user, assistant) message pairs."""
|
||||
conv = NodeConversation()
|
||||
for user_msg, assistant_msg in pairs:
|
||||
await conv.add_user_message(user_msg)
|
||||
await conv.add_assistant_message(assistant_msg)
|
||||
return conv
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestHandoffContext
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandoffContext:
|
||||
def test_instantiation(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="node_A",
|
||||
summary="Summary text",
|
||||
key_outputs={"result": "42"},
|
||||
turn_count=3,
|
||||
total_tokens_used=1200,
|
||||
)
|
||||
assert hc.source_node_id == "node_A"
|
||||
assert hc.summary == "Summary text"
|
||||
assert hc.key_outputs == {"result": "42"}
|
||||
assert hc.turn_count == 3
|
||||
assert hc.total_tokens_used == 1200
|
||||
|
||||
def test_field_access(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="n1",
|
||||
summary="s",
|
||||
key_outputs={},
|
||||
turn_count=0,
|
||||
total_tokens_used=0,
|
||||
)
|
||||
assert hc.key_outputs == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestExtractiveSummary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractiveSummary:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_summary_includes_first_last(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("hello", "First response here."),
|
||||
("continue", "Middle response."),
|
||||
("finish", "Final conclusion."),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="test_node")
|
||||
|
||||
assert "First response here." in hc.summary
|
||||
assert "Final conclusion." in hc.summary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_summary_metadata(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("hi", "hello"),
|
||||
("bye", "goodbye"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="node_42")
|
||||
|
||||
assert hc.source_node_id == "node_42"
|
||||
assert hc.turn_count == 2
|
||||
assert hc.total_tokens_used > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_with_output_keys_colon(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("what is the answer?", "answer: 42"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["answer"])
|
||||
|
||||
assert hc.key_outputs["answer"] == "42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_with_output_keys_equals(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("compute", "result = success"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["result"])
|
||||
|
||||
assert hc.key_outputs["result"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_json_output_keys(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("give me json", '{"score": 95, "grade": "A"}'),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"])
|
||||
|
||||
assert hc.key_outputs["score"] == "95"
|
||||
assert hc.key_outputs["grade"] == "A"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_empty_conversation(self) -> None:
|
||||
conv = NodeConversation()
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="empty")
|
||||
|
||||
assert hc.summary == "Empty conversation."
|
||||
assert hc.turn_count == 0
|
||||
assert hc.key_outputs == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_no_assistant_messages(self) -> None:
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("hello?")
|
||||
await conv.add_user_message("anyone there?")
|
||||
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="silent")
|
||||
|
||||
assert hc.summary == "No assistant responses."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_most_recent_wins(self) -> None:
|
||||
conv = await _build_conversation(
|
||||
("first", "status: old_value"),
|
||||
("second", "status: new_value"),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n", output_keys=["status"])
|
||||
|
||||
assert hc.key_outputs["status"] == "new_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extractive_truncation(self) -> None:
|
||||
long_text = "x" * 1000
|
||||
conv = await _build_conversation(
|
||||
("go", long_text),
|
||||
)
|
||||
ch = ContextHandoff()
|
||||
hc = ch.summarize_conversation(conv, node_id="n")
|
||||
|
||||
# Summary should be truncated to ~500 chars
|
||||
assert len(hc.summary) <= 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestLLMSummary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMSummary:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_summary_calls_provider(self) -> None:
|
||||
llm = SpyLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("hi", "hello back"),
|
||||
("what now?", "we are done"),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
hc = ch.summarize_conversation(conv, node_id="llm_node")
|
||||
|
||||
assert llm.complete_called, "LLM complete() was never invoked"
|
||||
assert hc.summary == "This is a mock response for testing purposes."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_summary_includes_output_key_hint(self) -> None:
|
||||
llm = SpyLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("compute", '{"score": 95}'),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
ch.summarize_conversation(conv, node_id="n", output_keys=["score", "grade"])
|
||||
|
||||
assert llm.complete_call_args is not None
|
||||
system = llm.complete_call_args.get("system", "")
|
||||
assert "score" in system
|
||||
assert "grade" in system
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_fallback_on_error(self) -> None:
|
||||
llm = FailingLLMProvider()
|
||||
conv = await _build_conversation(
|
||||
("start", "First assistant message."),
|
||||
("end", "Last assistant message."),
|
||||
)
|
||||
ch = ContextHandoff(llm=llm)
|
||||
hc = ch.summarize_conversation(conv, node_id="fallback_node")
|
||||
|
||||
# Should fall back to extractive (first + last assistant messages)
|
||||
assert "First assistant message." in hc.summary
|
||||
assert "Last assistant message." in hc.summary
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestFormatAsInput
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatAsInput:
|
||||
def test_format_structure(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="analyzer",
|
||||
summary="Analysis complete.",
|
||||
key_outputs={"score": "95"},
|
||||
turn_count=5,
|
||||
total_tokens_used=2000,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "--- CONTEXT FROM: analyzer" in output
|
||||
assert "KEY OUTPUTS:" in output
|
||||
assert "SUMMARY:" in output
|
||||
assert "--- END CONTEXT ---" in output
|
||||
|
||||
def test_format_no_key_outputs(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="simple",
|
||||
summary="Done.",
|
||||
key_outputs={},
|
||||
turn_count=1,
|
||||
total_tokens_used=100,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "KEY OUTPUTS:" not in output
|
||||
assert "SUMMARY:" in output
|
||||
|
||||
def test_format_content_values(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="node_X",
|
||||
summary="Found 3 bugs.",
|
||||
key_outputs={"bugs": "3", "severity": "high"},
|
||||
turn_count=7,
|
||||
total_tokens_used=5000,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "node_X" in output
|
||||
assert "7 turns" in output
|
||||
assert "~5000 tokens" in output
|
||||
assert "- bugs: 3" in output
|
||||
assert "- severity: high" in output
|
||||
assert "Found 3 bugs." in output
|
||||
|
||||
def test_format_empty_summary(self) -> None:
|
||||
hc = HandoffContext(
|
||||
source_node_id="n",
|
||||
summary="",
|
||||
key_outputs={},
|
||||
turn_count=0,
|
||||
total_tokens_used=0,
|
||||
)
|
||||
output = ContextHandoff.format_as_input(hc)
|
||||
|
||||
assert "No summary available." in output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_as_input_usable_as_message(self) -> None:
|
||||
"""Formatted output can be fed into a NodeConversation as a user message."""
|
||||
hc = HandoffContext(
|
||||
source_node_id="prev_node",
|
||||
summary="Completed analysis.",
|
||||
key_outputs={"result": "42"},
|
||||
turn_count=3,
|
||||
total_tokens_used=900,
|
||||
)
|
||||
text = ContextHandoff.format_as_input(hc)
|
||||
|
||||
conv = NodeConversation()
|
||||
msg = await conv.add_user_message(text)
|
||||
|
||||
assert msg.role == "user"
|
||||
assert "CONTEXT FROM: prev_node" in msg.content
|
||||
assert conv.turn_count == 1
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,906 @@
|
||||
"""WP-8: Tests for EventLoopNode, OutputAccumulator, LoopConfig, JudgeProtocol.
|
||||
|
||||
Uses real FileConversationStore (no mocks for storage) and a MockStreamingLLM
|
||||
that yields pre-programmed StreamEvents to control the loop deterministically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.event_loop_node import (
|
||||
EventLoopNode,
|
||||
JudgeProtocol,
|
||||
JudgeVerdict,
|
||||
LoopConfig,
|
||||
OutputAccumulator,
|
||||
)
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeSpec, SharedMemory
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.runtime.event_bus import EventBus, EventType
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock LLM that yields pre-programmed stream events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockStreamingLLM(LLMProvider):
|
||||
"""Mock LLM that yields pre-programmed StreamEvent sequences.
|
||||
|
||||
Each call to stream() consumes the next scenario from the list.
|
||||
Cycles back to the beginning if more calls are made than scenarios.
|
||||
"""
|
||||
|
||||
def __init__(self, scenarios: list[list] | None = None):
|
||||
self.scenarios = scenarios or []
|
||||
self._call_index = 0
|
||||
self.stream_calls: list[dict] = []
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator:
|
||||
self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
|
||||
if not self.scenarios:
|
||||
return
|
||||
events = self.scenarios[self._call_index % len(self.scenarios)]
|
||||
self._call_index += 1
|
||||
for event in events:
|
||||
yield event
|
||||
|
||||
def complete(self, messages, system="", **kwargs) -> LLMResponse:
|
||||
return LLMResponse(content="Summary of conversation.", model="mock", stop_reason="stop")
|
||||
|
||||
def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
|
||||
return LLMResponse(content="", model="mock", stop_reason="stop")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: build a simple text-only scenario
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def text_scenario(text: str, input_tokens: int = 10, output_tokens: int = 5) -> list:
|
||||
"""Build a stream scenario that produces text and finishes."""
|
||||
return [
|
||||
TextDeltaEvent(content=text, snapshot=text),
|
||||
FinishEvent(
|
||||
stop_reason="stop", input_tokens=input_tokens, output_tokens=output_tokens, model="mock"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def tool_call_scenario(
|
||||
tool_name: str,
|
||||
tool_input: dict,
|
||||
tool_use_id: str = "call_1",
|
||||
text: str = "",
|
||||
) -> list:
|
||||
"""Build a stream scenario that produces a tool call."""
|
||||
events = []
|
||||
if text:
|
||||
events.append(TextDeltaEvent(content=text, snapshot=text))
|
||||
events.append(
|
||||
ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input)
|
||||
)
|
||||
events.append(
|
||||
FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock")
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runtime():
|
||||
rt = MagicMock(spec=Runtime)
|
||||
rt.start_run = MagicMock(return_value="run_1")
|
||||
rt.decide = MagicMock(return_value="dec_1")
|
||||
rt.record_outcome = MagicMock()
|
||||
rt.end_run = MagicMock()
|
||||
rt.report_problem = MagicMock()
|
||||
rt.set_node = MagicMock()
|
||||
return rt
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node_spec():
|
||||
return NodeSpec(
|
||||
id="test_loop",
|
||||
name="Test Loop",
|
||||
description="A test event loop node",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
system_prompt="You are a test assistant.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory():
|
||||
return SharedMemory()
|
||||
|
||||
|
||||
def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal_context=""):
|
||||
"""Build a NodeContext for testing."""
|
||||
return NodeContext(
|
||||
runtime=runtime,
|
||||
node_id=node_spec.id,
|
||||
node_spec=node_spec,
|
||||
memory=memory,
|
||||
input_data=input_data or {},
|
||||
llm=llm,
|
||||
available_tools=tools or [],
|
||||
goal_context=goal_context,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# NodeProtocol conformance
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestNodeProtocolConformance:
|
||||
def test_subclasses_node_protocol(self):
|
||||
"""EventLoopNode must be a subclass of NodeProtocol."""
|
||||
assert issubclass(EventLoopNode, NodeProtocol)
|
||||
|
||||
def test_has_execute_method(self):
|
||||
node = EventLoopNode()
|
||||
assert hasattr(node, "execute")
|
||||
assert asyncio.iscoroutinefunction(node.execute)
|
||||
|
||||
def test_has_validate_input(self):
|
||||
node = EventLoopNode()
|
||||
assert hasattr(node, "validate_input")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Basic loop execution
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestBasicLoop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_text_only_implicit_accept(self, runtime, node_spec, memory):
|
||||
"""No tools, no judge. LLM produces text, implicit accept on stop."""
|
||||
# Override to no output_keys so implicit judge accepts immediately
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("Hello world")])
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.tokens_used > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_llm_returns_failure(self, runtime, node_spec, memory):
|
||||
"""ctx.llm=None should return failure immediately."""
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm=None)
|
||||
|
||||
node = EventLoopNode()
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is False
|
||||
assert "LLM" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iterations_failure(self, runtime, node_spec, memory):
|
||||
"""When max_iterations is reached without acceptance, should fail."""
|
||||
# LLM always produces text but never calls set_output, so implicit
|
||||
# judge retries asking for missing keys
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=2))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is False
|
||||
assert "Max iterations" in result.error
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Judge integration
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestJudgeIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_accept(self, runtime, node_spec, memory):
|
||||
"""Mock judge ACCEPT -> success."""
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("Done!")])
|
||||
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
judge.evaluate.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_escalate(self, runtime, node_spec, memory):
|
||||
"""Mock judge ESCALATE -> failure."""
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("Attempt")])
|
||||
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
judge.evaluate = AsyncMock(
|
||||
return_value=JudgeVerdict(action="ESCALATE", feedback="Tone violation")
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is False
|
||||
assert "escalated" in result.error.lower()
|
||||
assert "Tone violation" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_retry_then_accept(self, runtime, node_spec, memory):
|
||||
"""RETRY twice, then ACCEPT. Should run 3 iterations."""
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
text_scenario("attempt 1"),
|
||||
text_scenario("attempt 2"),
|
||||
text_scenario("attempt 3"),
|
||||
]
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def evaluate_fn(context):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
return JudgeVerdict(action="RETRY", feedback="Try harder")
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
judge.evaluate = AsyncMock(side_effect=evaluate_fn)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=10))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert call_count == 3
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# set_output tool
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSetOutput:
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_output_accumulates(self, runtime, node_spec, memory):
|
||||
"""LLM calls set_output -> values appear in NodeResult.output."""
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
# Turn 1: call set_output
|
||||
tool_call_scenario("set_output", {"key": "result", "value": "42"}),
|
||||
# Turn 2: text response (triggers implicit judge)
|
||||
text_scenario("Done, result is 42"),
|
||||
]
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.output["result"] == "42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_output_rejects_invalid_key(self, runtime, node_spec, memory):
|
||||
"""set_output with key not in output_keys -> is_error=True."""
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
# Turn 1: call set_output with bad key
|
||||
tool_call_scenario("set_output", {"key": "bad_key", "value": "x"}),
|
||||
# Turn 2: call set_output with good key
|
||||
tool_call_scenario("set_output", {"key": "result", "value": "ok"}),
|
||||
# Turn 3: text done
|
||||
text_scenario("Done"),
|
||||
]
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.output["result"] == "ok"
|
||||
assert "bad_key" not in result.output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_keys_triggers_retry(self, runtime, node_spec, memory):
|
||||
"""Judge accepts but output keys are missing -> retry with hint."""
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))
|
||||
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
# Turn 1: text without set_output -> judge accepts but keys missing -> retry
|
||||
text_scenario("I'll get to it"),
|
||||
# Turn 2: set_output
|
||||
tool_call_scenario("set_output", {"key": "result", "value": "done"}),
|
||||
# Turn 3: text -> judge accepts, keys present -> success
|
||||
text_scenario("All done"),
|
||||
]
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.output["result"] == "done"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Stall detection
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStallDetection:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stall_detection(self, runtime, node_spec, memory):
|
||||
"""3 identical responses should trigger stall detection."""
|
||||
node_spec.output_keys = [] # so implicit judge would accept
|
||||
# But we need the judge to RETRY so we actually get 3 identical responses
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY"))
|
||||
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("same answer")])
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
judge=judge,
|
||||
config=LoopConfig(max_iterations=10, stall_detection_threshold=3),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is False
|
||||
assert "stalled" in result.error.lower()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# EventBus lifecycle events
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestEventBusLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_events_published(self, runtime, node_spec, memory):
|
||||
"""NODE_LOOP_STARTED, NODE_LOOP_ITERATION, NODE_LOOP_COMPLETED should be published."""
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("ok")])
|
||||
bus = EventBus()
|
||||
|
||||
received_events = []
|
||||
bus.subscribe(
|
||||
event_types=[
|
||||
EventType.NODE_LOOP_STARTED,
|
||||
EventType.NODE_LOOP_ITERATION,
|
||||
EventType.NODE_LOOP_COMPLETED,
|
||||
],
|
||||
handler=lambda e: received_events.append(e.type),
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert EventType.NODE_LOOP_STARTED in received_events
|
||||
assert EventType.NODE_LOOP_ITERATION in received_events
|
||||
assert EventType.NODE_LOOP_COMPLETED in received_events
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_facing_uses_client_output_delta(self, runtime, memory):
|
||||
"""client_facing=True should emit CLIENT_OUTPUT_DELTA instead of LLM_TEXT_DELTA."""
|
||||
spec = NodeSpec(
|
||||
id="ui_node",
|
||||
name="UI Node",
|
||||
description="Streams to user",
|
||||
node_type="event_loop",
|
||||
output_keys=[],
|
||||
client_facing=True,
|
||||
)
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("visible to user")])
|
||||
bus = EventBus()
|
||||
|
||||
received_types = []
|
||||
bus.subscribe(
|
||||
event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.LLM_TEXT_DELTA],
|
||||
handler=lambda e: received_types.append(e.type),
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, spec, memory, llm)
|
||||
node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
|
||||
|
||||
# client_facing + text-only blocks for user input; use shutdown to unblock
|
||||
async def auto_shutdown():
|
||||
await asyncio.sleep(0.05)
|
||||
node.signal_shutdown()
|
||||
|
||||
task = asyncio.create_task(auto_shutdown())
|
||||
await node.execute(ctx)
|
||||
await task
|
||||
|
||||
assert EventType.CLIENT_OUTPUT_DELTA in received_types
|
||||
assert EventType.LLM_TEXT_DELTA not in received_types
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Client-facing blocking
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestClientFacingBlocking:
|
||||
"""Tests for native client_facing input blocking in EventLoopNode."""
|
||||
|
||||
@pytest.fixture
|
||||
def client_spec(self):
|
||||
return NodeSpec(
|
||||
id="chat",
|
||||
name="Chat",
|
||||
description="chat node",
|
||||
node_type="event_loop",
|
||||
output_keys=[],
|
||||
client_facing=True,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_facing_blocks_on_text(self, runtime, memory, client_spec):
|
||||
"""client_facing + text-only response blocks until inject_event."""
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
text_scenario("Hello!"),
|
||||
text_scenario("Got your message."),
|
||||
]
|
||||
)
|
||||
bus = EventBus()
|
||||
node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
|
||||
ctx = build_ctx(runtime, client_spec, memory, llm)
|
||||
|
||||
async def user_responds():
|
||||
await asyncio.sleep(0.05)
|
||||
await node.inject_event("I need help")
|
||||
await asyncio.sleep(0.05)
|
||||
node.signal_shutdown()
|
||||
|
||||
user_task = asyncio.create_task(user_responds())
|
||||
result = await node.execute(ctx)
|
||||
await user_task
|
||||
|
||||
assert result.success is True
|
||||
# LLM should have been called at least twice (first response + after inject)
|
||||
assert llm._call_index >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_facing_does_not_block_on_tools(self, runtime, memory):
|
||||
"""client_facing + tool calls should NOT block — judge evaluates normally."""
|
||||
spec = NodeSpec(
|
||||
id="chat",
|
||||
name="Chat",
|
||||
description="chat node",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
client_facing=True,
|
||||
)
|
||||
# Scenario 1: LLM calls set_output (tool call present → no blocking, judge RETRYs)
|
||||
# Scenario 2: LLM produces text (implicit judge sees output key set → ACCEPT)
|
||||
# But scenario 2 is text-only on client_facing → would block.
|
||||
# So we need shutdown to handle that case.
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
tool_call_scenario("set_output", {"key": "result", "value": "done"}),
|
||||
text_scenario("All set!"),
|
||||
]
|
||||
)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
ctx = build_ctx(runtime, spec, memory, llm)
|
||||
|
||||
# After set_output, implicit judge RETRYs (tool calls present).
|
||||
# Next turn: text-only on client_facing → blocks.
|
||||
# But implicit judge should ACCEPT first (output key is set, no tools).
|
||||
# Actually, client_facing check happens BEFORE judge, so it blocks.
|
||||
# Use shutdown as safety net.
|
||||
async def auto_shutdown():
|
||||
await asyncio.sleep(0.1)
|
||||
node.signal_shutdown()
|
||||
|
||||
task = asyncio.create_task(auto_shutdown())
|
||||
result = await node.execute(ctx)
|
||||
await task
|
||||
|
||||
assert result.success is True
|
||||
assert result.output["result"] == "done"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_client_facing_unchanged(self, runtime, memory):
|
||||
"""client_facing=False should not block — existing behavior."""
|
||||
spec = NodeSpec(
|
||||
id="internal",
|
||||
name="Internal",
|
||||
description="internal node",
|
||||
node_type="event_loop",
|
||||
output_keys=[],
|
||||
)
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=2))
|
||||
ctx = build_ctx(runtime, spec, memory, llm)
|
||||
|
||||
# Should complete without blocking (implicit judge ACCEPTs on no tools + no keys)
|
||||
result = await node.execute(ctx)
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_shutdown_unblocks(self, runtime, memory, client_spec):
|
||||
"""signal_shutdown should unblock a waiting client_facing node."""
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("Waiting...")])
|
||||
bus = EventBus()
|
||||
node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=10))
|
||||
ctx = build_ctx(runtime, client_spec, memory, llm)
|
||||
|
||||
async def shutdown_after_delay():
|
||||
await asyncio.sleep(0.05)
|
||||
node.signal_shutdown()
|
||||
|
||||
task = asyncio.create_task(shutdown_after_delay())
|
||||
result = await node.execute(ctx)
|
||||
await task
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_input_requested_event_published(self, runtime, memory, client_spec):
|
||||
"""CLIENT_INPUT_REQUESTED should be published when blocking."""
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("Hello!")])
|
||||
bus = EventBus()
|
||||
received = []
|
||||
|
||||
async def capture(e):
|
||||
received.append(e)
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.CLIENT_INPUT_REQUESTED],
|
||||
handler=capture,
|
||||
)
|
||||
|
||||
node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
|
||||
ctx = build_ctx(runtime, client_spec, memory, llm)
|
||||
|
||||
async def shutdown():
|
||||
await asyncio.sleep(0.05)
|
||||
node.signal_shutdown()
|
||||
|
||||
task = asyncio.create_task(shutdown())
|
||||
await node.execute(ctx)
|
||||
await task
|
||||
|
||||
assert len(received) >= 1
|
||||
assert received[0].type == EventType.CLIENT_INPUT_REQUESTED
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tool execution
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestToolExecution:
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_feedback(self, runtime, node_spec, memory):
|
||||
"""Tool call -> result fed back to conversation via stream loop."""
|
||||
node_spec.output_keys = []
|
||||
|
||||
def my_tool_executor(tool_use: ToolUse) -> ToolResult:
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
content=f"Result for {tool_use.name}",
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
# Turn 1: call a tool
|
||||
tool_call_scenario("search", {"query": "test"}, tool_use_id="call_search"),
|
||||
# Turn 2: text response after seeing tool result
|
||||
text_scenario("Found the answer"),
|
||||
]
|
||||
)
|
||||
|
||||
ctx = build_ctx(
|
||||
runtime,
|
||||
node_spec,
|
||||
memory,
|
||||
llm,
|
||||
tools=[Tool(name="search", description="Search", parameters={})],
|
||||
)
|
||||
node = EventLoopNode(
|
||||
tool_executor=my_tool_executor,
|
||||
config=LoopConfig(max_iterations=5),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
# stream() should have been called twice (tool call turn + final text turn)
|
||||
assert llm._call_index >= 2
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Write-through persistence with real FileConversationStore
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestWriteThroughPersistence:
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_written_to_store(self, tmp_path, runtime, node_spec, memory):
|
||||
"""Messages should be persisted immediately via write-through."""
|
||||
store = FileConversationStore(tmp_path / "conv")
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("Hello")])
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
conversation_store=store,
|
||||
config=LoopConfig(max_iterations=5),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
# Verify parts were written to disk
|
||||
parts = await store.read_parts()
|
||||
assert len(parts) >= 2 # at least initial user msg + assistant msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_accumulator_write_through(self, tmp_path, runtime, node_spec, memory):
|
||||
"""set_output values should be persisted in cursor immediately."""
|
||||
store = FileConversationStore(tmp_path / "conv")
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
tool_call_scenario("set_output", {"key": "result", "value": "persisted_value"}),
|
||||
text_scenario("Done"),
|
||||
]
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
conversation_store=store,
|
||||
config=LoopConfig(max_iterations=5),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
assert result.output["result"] == "persisted_value"
|
||||
|
||||
# Verify output was written to cursor on disk
|
||||
cursor = await store.read_cursor()
|
||||
assert cursor is not None
|
||||
assert cursor["outputs"]["result"] == "persisted_value"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Crash recovery (restore from real FileConversationStore)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestCrashRecovery:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_from_checkpoint(self, tmp_path, runtime, node_spec, memory):
|
||||
"""Populate a store with state, then verify EventLoopNode restores from it."""
|
||||
store = FileConversationStore(tmp_path / "conv")
|
||||
|
||||
# Simulate a previous run that wrote conversation + cursor
|
||||
conv = NodeConversation(
|
||||
system_prompt="You are a test assistant.",
|
||||
output_keys=["result"],
|
||||
store=store,
|
||||
)
|
||||
await conv.add_user_message("Initial input")
|
||||
await conv.add_assistant_message("Working on it...")
|
||||
|
||||
# Write cursor with iteration and outputs
|
||||
await store.write_cursor(
|
||||
{
|
||||
"iteration": 1,
|
||||
"next_seq": conv.next_seq,
|
||||
"outputs": {"result": "partial_value"},
|
||||
}
|
||||
)
|
||||
|
||||
# Now create a new EventLoopNode and execute -- it should restore
|
||||
node_spec.output_keys = [] # no required keys so implicit accept works
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("Continuing...")])
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
conversation_store=store,
|
||||
config=LoopConfig(max_iterations=5),
|
||||
)
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is True
|
||||
# Should have the restored output
|
||||
assert result.output.get("result") == "partial_value"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# External event injection
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestEventInjection:
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_event(self, runtime, node_spec, memory):
|
||||
"""inject_event() content should appear as user message in next iteration."""
|
||||
node_spec.output_keys = []
|
||||
|
||||
judge_calls = []
|
||||
|
||||
async def evaluate_fn(context):
|
||||
judge_calls.append(context)
|
||||
if len(judge_calls) >= 2:
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
return JudgeVerdict(action="RETRY")
|
||||
|
||||
judge = AsyncMock(spec=JudgeProtocol)
|
||||
judge.evaluate = AsyncMock(side_effect=evaluate_fn)
|
||||
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
text_scenario("iteration 1"),
|
||||
text_scenario("iteration 2"),
|
||||
]
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
judge=judge,
|
||||
config=LoopConfig(max_iterations=5),
|
||||
)
|
||||
|
||||
# Pre-inject an event before execute runs
|
||||
await node.inject_event("Priority: CEO wants meeting rescheduled")
|
||||
|
||||
result = await node.execute(ctx)
|
||||
assert result.success is True
|
||||
|
||||
# Verify the injected content made it into the LLM messages
|
||||
all_messages = []
|
||||
for call in llm.stream_calls:
|
||||
all_messages.extend(call["messages"])
|
||||
injected_found = any("[External event]" in str(m.get("content", "")) for m in all_messages)
|
||||
assert injected_found
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Pause/resume
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestPauseResume:
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_returns_early(self, runtime, node_spec, memory):
|
||||
"""pause_requested in input_data should trigger early return."""
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(scenarios=[text_scenario("should not run")])
|
||||
|
||||
ctx = build_ctx(
|
||||
runtime,
|
||||
node_spec,
|
||||
memory,
|
||||
llm,
|
||||
input_data={"pause_requested": True},
|
||||
)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=10))
|
||||
result = await node.execute(ctx)
|
||||
|
||||
# Should return success (paused, not failed)
|
||||
assert result.success is True
|
||||
# LLM should not have been called (paused before first turn)
|
||||
assert llm._call_index == 0
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Stream errors
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStreamErrors:
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_recoverable_stream_error_raises(self, runtime, node_spec, memory):
|
||||
"""Non-recoverable StreamErrorEvent should raise RuntimeError."""
|
||||
node_spec.output_keys = []
|
||||
llm = MockStreamingLLM(
|
||||
scenarios=[
|
||||
[StreamErrorEvent(error="Connection lost", recoverable=False)],
|
||||
]
|
||||
)
|
||||
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
|
||||
with pytest.raises(RuntimeError, match="Stream error"):
|
||||
await node.execute(ctx)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# OutputAccumulator unit tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOutputAccumulator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_get(self):
|
||||
acc = OutputAccumulator()
|
||||
await acc.set("key1", "value1")
|
||||
assert acc.get("key1") == "value1"
|
||||
assert acc.get("nonexistent") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_dict(self):
|
||||
acc = OutputAccumulator()
|
||||
await acc.set("a", 1)
|
||||
await acc.set("b", 2)
|
||||
assert acc.to_dict() == {"a": 1, "b": 2}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_all_keys(self):
|
||||
acc = OutputAccumulator()
|
||||
assert acc.has_all_keys([]) is True
|
||||
assert acc.has_all_keys(["x"]) is False
|
||||
await acc.set("x", "val")
|
||||
assert acc.has_all_keys(["x"]) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_through_to_real_store(self, tmp_path):
|
||||
"""OutputAccumulator should write through to FileConversationStore cursor."""
|
||||
store = FileConversationStore(tmp_path / "acc_test")
|
||||
acc = OutputAccumulator(store=store)
|
||||
|
||||
await acc.set("result", "hello")
|
||||
|
||||
cursor = await store.read_cursor()
|
||||
assert cursor["outputs"]["result"] == "hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_from_real_store(self, tmp_path):
|
||||
"""OutputAccumulator.restore() should rebuild from FileConversationStore."""
|
||||
store = FileConversationStore(tmp_path / "acc_restore")
|
||||
await store.write_cursor({"outputs": {"key1": "val1", "key2": "val2"}})
|
||||
|
||||
acc = await OutputAccumulator.restore(store)
|
||||
assert acc.get("key1") == "val1"
|
||||
assert acc.get("key2") == "val2"
|
||||
assert acc.has_all_keys(["key1", "key2"]) is True
|
||||
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Tests for event_loop node type wiring (Issue #2513).
|
||||
|
||||
Covers:
|
||||
- NodeSpec.client_facing field
|
||||
- event_loop in VALID_NODE_TYPES
|
||||
- _get_node_implementation() event_loop branch
|
||||
- no-retry enforcement in serial execution path
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
|
||||
class AlwaysFailsNode(NodeProtocol):
|
||||
"""A test node that always fails."""
|
||||
|
||||
def __init__(self):
|
||||
self.attempt_count = 0
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
self.attempt_count += 1
|
||||
return NodeResult(success=False, error=f"Permanent error (attempt {self.attempt_count})")
|
||||
|
||||
|
||||
class SucceedsOnceNode(NodeProtocol):
|
||||
"""A test node that always succeeds."""
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
return NodeResult(success=True, output={"result": "ok"})
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fast_sleep(monkeypatch):
|
||||
"""Mock asyncio.sleep to avoid real delays from exponential backoff."""
|
||||
monkeypatch.setattr("asyncio.sleep", AsyncMock())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runtime():
|
||||
"""Create a mock Runtime for testing."""
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime.start_run = MagicMock(return_value="test_run_id")
|
||||
runtime.decide = MagicMock(return_value="test_decision_id")
|
||||
runtime.record_outcome = MagicMock()
|
||||
runtime.end_run = MagicMock()
|
||||
runtime.report_problem = MagicMock()
|
||||
runtime.set_node = MagicMock()
|
||||
return runtime
|
||||
|
||||
|
||||
# --- NodeSpec.client_facing tests ---
|
||||
|
||||
|
||||
def test_client_facing_defaults_false():
|
||||
"""NodeSpec without client_facing should default to False."""
|
||||
spec = NodeSpec(
|
||||
id="n1",
|
||||
name="Node 1",
|
||||
description="test",
|
||||
node_type="llm_generate",
|
||||
)
|
||||
assert spec.client_facing is False
|
||||
|
||||
|
||||
def test_client_facing_explicit_true():
|
||||
"""NodeSpec with client_facing=True should retain the value."""
|
||||
spec = NodeSpec(
|
||||
id="n1",
|
||||
name="Node 1",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
)
|
||||
assert spec.client_facing is True
|
||||
|
||||
|
||||
# --- VALID_NODE_TYPES tests ---
|
||||
|
||||
|
||||
def test_event_loop_in_valid_node_types():
|
||||
"""'event_loop' must be in GraphExecutor.VALID_NODE_TYPES."""
|
||||
assert "event_loop" in GraphExecutor.VALID_NODE_TYPES
|
||||
|
||||
|
||||
def test_event_loop_node_spec_accepted():
|
||||
"""Creating a NodeSpec with node_type='event_loop' should not raise."""
|
||||
spec = NodeSpec(
|
||||
id="el1",
|
||||
name="Event Loop",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
)
|
||||
assert spec.node_type == "event_loop"
|
||||
|
||||
|
||||
# --- _get_node_implementation() tests ---
|
||||
|
||||
|
||||
def test_unregistered_event_loop_raises(runtime):
|
||||
"""An event_loop node not in the registry should raise RuntimeError."""
|
||||
spec = NodeSpec(
|
||||
id="el1",
|
||||
name="Event Loop",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
)
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
|
||||
with pytest.raises(RuntimeError, match="not found in registry"):
|
||||
executor._get_node_implementation(spec)
|
||||
|
||||
|
||||
def test_registered_event_loop_returns_impl(runtime):
|
||||
"""A registered event_loop node should be returned from the registry."""
|
||||
spec = NodeSpec(
|
||||
id="el1",
|
||||
name="Event Loop",
|
||||
description="test",
|
||||
node_type="event_loop",
|
||||
)
|
||||
impl = SucceedsOnceNode()
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
executor.register_node("el1", impl)
|
||||
|
||||
result = executor._get_node_implementation(spec)
|
||||
assert result is impl
|
||||
|
||||
|
||||
# --- No-retry enforcement (serial path) ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_loop_max_retries_forced_zero(runtime):
|
||||
"""An event_loop node with max_retries=3 should only execute once (no retry)."""
|
||||
node_spec = NodeSpec(
|
||||
id="el_fail",
|
||||
name="Failing Event Loop",
|
||||
description="event loop that fails",
|
||||
node_type="event_loop",
|
||||
max_retries=3,
|
||||
output_keys=["result"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="test_graph",
|
||||
goal_id="test_goal",
|
||||
name="Test Graph",
|
||||
entry_node="el_fail",
|
||||
nodes=[node_spec],
|
||||
edges=[],
|
||||
terminal_nodes=["el_fail"],
|
||||
)
|
||||
|
||||
goal = Goal(id="test_goal", name="Test", description="test")
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
failing_node = AlwaysFailsNode()
|
||||
executor.register_node("el_fail", failing_node)
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
# Event loop nodes get max_retries overridden to 0, meaning execute once then fail
|
||||
assert not result.success
|
||||
assert failing_node.attempt_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_loop_max_retries_zero_no_warning(runtime, caplog):
|
||||
"""An event_loop node with max_retries=0 should not log a warning."""
|
||||
node_spec = NodeSpec(
|
||||
id="el_zero",
|
||||
name="Zero Retry Event Loop",
|
||||
description="event loop with 0 retries",
|
||||
node_type="event_loop",
|
||||
max_retries=0,
|
||||
output_keys=["result"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="test_graph",
|
||||
goal_id="test_goal",
|
||||
name="Test Graph",
|
||||
entry_node="el_zero",
|
||||
nodes=[node_spec],
|
||||
edges=[],
|
||||
terminal_nodes=["el_zero"],
|
||||
)
|
||||
|
||||
goal = Goal(id="test_goal", name="Test", description="test")
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
failing_node = AlwaysFailsNode()
|
||||
executor.register_node("el_zero", failing_node)
|
||||
|
||||
import logging
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
await executor.execute(graph, goal, {})
|
||||
|
||||
# max_retries=0 should not trigger the override warning
|
||||
assert "Overriding to 0" not in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_loop_max_retries_positive_logs_warning(runtime, caplog):
|
||||
"""An event_loop node with max_retries=3 should log a warning about override."""
|
||||
node_spec = NodeSpec(
|
||||
id="el_warn",
|
||||
name="Warning Event Loop",
|
||||
description="event loop with retries",
|
||||
node_type="event_loop",
|
||||
max_retries=3,
|
||||
output_keys=["result"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="test_graph",
|
||||
goal_id="test_goal",
|
||||
name="Test Graph",
|
||||
entry_node="el_warn",
|
||||
nodes=[node_spec],
|
||||
edges=[],
|
||||
terminal_nodes=["el_warn"],
|
||||
)
|
||||
|
||||
goal = Goal(id="test_goal", name="Test", description="test")
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
failing_node = AlwaysFailsNode()
|
||||
executor.register_node("el_warn", failing_node)
|
||||
|
||||
import logging
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
await executor.execute(graph, goal, {})
|
||||
|
||||
assert "Overriding to 0" in caplog.text
|
||||
assert "el_warn" in caplog.text
|
||||
|
||||
|
||||
# --- Existing node types unaffected ---
|
||||
|
||||
|
||||
def test_existing_node_types_unchanged():
|
||||
"""All pre-existing node types must still be in VALID_NODE_TYPES with defaults preserved."""
|
||||
expected = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
|
||||
assert expected.issubset(GraphExecutor.VALID_NODE_TYPES)
|
||||
|
||||
# Default node_type is still llm_tool_use
|
||||
spec = NodeSpec(id="x", name="X", description="x")
|
||||
assert spec.node_type == "llm_tool_use"
|
||||
|
||||
# Default max_retries is still 3
|
||||
assert spec.max_retries == 3
|
||||
|
||||
# Default client_facing is False
|
||||
assert spec.client_facing is False
|
||||
@@ -0,0 +1,978 @@
|
||||
"""Tests for extending the stream event type system.
|
||||
|
||||
Validates that the StreamEvent discriminated union pattern supports:
|
||||
- Type-based dispatch (matching on event.type)
|
||||
- Pattern matching / isinstance branching
|
||||
- Custom event subclasses following the same frozen-dataclass convention
|
||||
- Serialization of mixed event sequences
|
||||
|
||||
WP-2 tests validate EventType enum extension and node-level event routing:
|
||||
- All 12 new EventType enum members with correct string values
|
||||
- node_id routing on AgentEvent
|
||||
- filter_node on Subscription
|
||||
- Backward compatibility with existing enum members
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import FrozenInstanceError, asdict, dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
ReasoningDeltaEvent,
|
||||
ReasoningStartEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
from framework.runtime.event_bus import AgentEvent, EventBus, EventType, Subscription
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers: type-based dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
def dispatch_event(event) -> str:
|
||||
"""Dispatch an event by its type field, returning a label."""
|
||||
handlers = {
|
||||
"text_delta": lambda e: f"text:{e.content}",
|
||||
"text_end": lambda e: f"end:{len(e.full_text)}chars",
|
||||
"tool_call": lambda e: f"call:{e.tool_name}",
|
||||
"tool_result": lambda e: f"result:{e.tool_use_id}",
|
||||
"reasoning_start": lambda _: "reasoning:start",
|
||||
"reasoning_delta": lambda e: f"reasoning:{e.content[:20]}",
|
||||
"finish": lambda e: f"finish:{e.stop_reason}",
|
||||
"error": lambda e: f"error:{e.error}",
|
||||
}
|
||||
handler = handlers.get(event.type)
|
||||
if handler is None:
|
||||
return f"unknown:{event.type}"
|
||||
return handler(event)
|
||||
|
||||
|
||||
def collect_text(events: list) -> str:
|
||||
"""Accumulate full text from a stream of events."""
|
||||
for event in reversed(events):
|
||||
if isinstance(event, TextEndEvent):
|
||||
return event.full_text
|
||||
if isinstance(event, TextDeltaEvent):
|
||||
return event.snapshot
|
||||
return ""
|
||||
|
||||
|
||||
def extract_tool_calls(events: list) -> list[dict[str, Any]]:
|
||||
"""Extract tool call info from a stream of events."""
|
||||
return [
|
||||
{"id": e.tool_use_id, "name": e.tool_name, "input": e.tool_input}
|
||||
for e in events
|
||||
if isinstance(e, ToolCallEvent)
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type-based dispatch tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTypeDispatch:
|
||||
"""Dispatch on event.type string for handler routing."""
|
||||
|
||||
def test_dispatch_text_delta(self):
|
||||
e = TextDeltaEvent(content="hello")
|
||||
assert dispatch_event(e) == "text:hello"
|
||||
|
||||
def test_dispatch_text_end(self):
|
||||
e = TextEndEvent(full_text="hello world")
|
||||
assert dispatch_event(e) == "end:11chars"
|
||||
|
||||
def test_dispatch_tool_call(self):
|
||||
e = ToolCallEvent(tool_name="web_search")
|
||||
assert dispatch_event(e) == "call:web_search"
|
||||
|
||||
def test_dispatch_tool_result(self):
|
||||
e = ToolResultEvent(tool_use_id="abc")
|
||||
assert dispatch_event(e) == "result:abc"
|
||||
|
||||
def test_dispatch_reasoning_start(self):
|
||||
e = ReasoningStartEvent()
|
||||
assert dispatch_event(e) == "reasoning:start"
|
||||
|
||||
def test_dispatch_reasoning_delta(self):
|
||||
e = ReasoningDeltaEvent(content="Let me think step by step")
|
||||
assert dispatch_event(e) == "reasoning:Let me think step by"
|
||||
|
||||
def test_dispatch_finish(self):
|
||||
e = FinishEvent(stop_reason="end_turn")
|
||||
assert dispatch_event(e) == "finish:end_turn"
|
||||
|
||||
def test_dispatch_error(self):
|
||||
e = StreamErrorEvent(error="timeout")
|
||||
assert dispatch_event(e) == "error:timeout"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# isinstance-based filtering
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestInstanceFiltering:
|
||||
"""Filter event streams using isinstance for each event type."""
|
||||
|
||||
@pytest.fixture
|
||||
def text_stream(self) -> list:
|
||||
"""Simulate a text-only stream."""
|
||||
return [
|
||||
TextDeltaEvent(content="Hello", snapshot="Hello"),
|
||||
TextDeltaEvent(content=" world", snapshot="Hello world"),
|
||||
TextDeltaEvent(content="!", snapshot="Hello world!"),
|
||||
TextEndEvent(full_text="Hello world!"),
|
||||
FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=3, model="test"),
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def tool_stream(self) -> list:
|
||||
"""Simulate a tool call stream."""
|
||||
return [
|
||||
ToolCallEvent(
|
||||
tool_use_id="call_1",
|
||||
tool_name="get_weather",
|
||||
tool_input={"city": "London"},
|
||||
),
|
||||
ToolCallEvent(
|
||||
tool_use_id="call_2",
|
||||
tool_name="calculator",
|
||||
tool_input={"expression": "2+2"},
|
||||
),
|
||||
FinishEvent(stop_reason="tool_calls"),
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def reasoning_stream(self) -> list:
|
||||
"""Simulate a stream with reasoning blocks."""
|
||||
return [
|
||||
ReasoningStartEvent(),
|
||||
ReasoningDeltaEvent(content="Let me analyze this..."),
|
||||
ReasoningDeltaEvent(content="The answer is 42."),
|
||||
TextDeltaEvent(content="The answer is 42.", snapshot="The answer is 42."),
|
||||
TextEndEvent(full_text="The answer is 42."),
|
||||
FinishEvent(stop_reason="end_turn"),
|
||||
]
|
||||
|
||||
def test_collect_text(self, text_stream):
|
||||
assert collect_text(text_stream) == "Hello world!"
|
||||
|
||||
def test_collect_text_from_tool_stream(self, tool_stream):
|
||||
assert collect_text(tool_stream) == ""
|
||||
|
||||
def test_extract_tool_calls(self, tool_stream):
|
||||
calls = extract_tool_calls(tool_stream)
|
||||
assert len(calls) == 2
|
||||
assert calls[0]["name"] == "get_weather"
|
||||
assert calls[1]["name"] == "calculator"
|
||||
|
||||
def test_extract_tool_calls_from_text_stream(self, text_stream):
|
||||
assert extract_tool_calls(text_stream) == []
|
||||
|
||||
def test_filter_text_deltas(self, text_stream):
|
||||
deltas = [e for e in text_stream if isinstance(e, TextDeltaEvent)]
|
||||
assert len(deltas) == 3
|
||||
|
||||
def test_filter_finish(self, text_stream):
|
||||
finishes = [e for e in text_stream if isinstance(e, FinishEvent)]
|
||||
assert len(finishes) == 1
|
||||
assert finishes[0].stop_reason == "stop"
|
||||
|
||||
def test_reasoning_then_text(self, reasoning_stream):
|
||||
reasoning = [e for e in reasoning_stream if isinstance(e, ReasoningDeltaEvent)]
|
||||
text = collect_text(reasoning_stream)
|
||||
assert len(reasoning) == 2
|
||||
assert text == "The answer is 42."
|
||||
|
||||
def test_mixed_stream_type_counts(self, reasoning_stream):
|
||||
type_counts = {}
|
||||
for e in reasoning_stream:
|
||||
type_counts[e.type] = type_counts.get(e.type, 0) + 1
|
||||
assert type_counts == {
|
||||
"reasoning_start": 1,
|
||||
"reasoning_delta": 2,
|
||||
"text_delta": 1,
|
||||
"text_end": 1,
|
||||
"finish": 1,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom event extension pattern
|
||||
# ---------------------------------------------------------------------------
|
||||
@dataclass(frozen=True)
|
||||
class CustomMetricsEvent:
|
||||
"""Example custom event following the same pattern."""
|
||||
|
||||
type: Literal["custom_metrics"] = "custom_metrics"
|
||||
latency_ms: float = 0.0
|
||||
tokens_per_second: float = 0.0
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CustomCitationEvent:
|
||||
"""Example citation event extending the pattern."""
|
||||
|
||||
type: Literal["citation"] = "citation"
|
||||
source_url: str = ""
|
||||
quote: str = ""
|
||||
confidence: float = 0.0
|
||||
|
||||
|
||||
class TestCustomEventExtension:
|
||||
"""Custom events should follow the same frozen-dataclass convention."""
|
||||
|
||||
def test_custom_event_construction(self):
|
||||
e = CustomMetricsEvent(latency_ms=150.5, tokens_per_second=42.3)
|
||||
assert e.type == "custom_metrics"
|
||||
assert e.latency_ms == 150.5
|
||||
|
||||
def test_custom_event_frozen(self):
|
||||
e = CustomMetricsEvent()
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
e.type = "modified"
|
||||
|
||||
def test_custom_event_serialization(self):
|
||||
e = CustomMetricsEvent(
|
||||
latency_ms=100.0,
|
||||
tokens_per_second=50.0,
|
||||
metadata={"provider": "anthropic"},
|
||||
)
|
||||
d = asdict(e)
|
||||
assert d["type"] == "custom_metrics"
|
||||
assert d["metadata"] == {"provider": "anthropic"}
|
||||
|
||||
def test_custom_event_dispatch(self):
|
||||
"""Custom events can extend the dispatch map."""
|
||||
e = CustomMetricsEvent(latency_ms=200.0)
|
||||
# Falls through to "unknown" in our dispatch_event
|
||||
assert dispatch_event(e) == "unknown:custom_metrics"
|
||||
|
||||
def test_custom_event_in_mixed_stream(self):
|
||||
"""Custom events can coexist with standard events in a list."""
|
||||
stream = [
|
||||
TextDeltaEvent(content="hi", snapshot="hi"),
|
||||
CustomMetricsEvent(latency_ms=50.0),
|
||||
TextEndEvent(full_text="hi"),
|
||||
CustomCitationEvent(source_url="https://example.com", quote="hi"),
|
||||
FinishEvent(stop_reason="stop"),
|
||||
]
|
||||
standard = [
|
||||
e
|
||||
for e in stream
|
||||
if hasattr(e, "type")
|
||||
and e.type
|
||||
in {
|
||||
"text_delta",
|
||||
"text_end",
|
||||
"tool_call",
|
||||
"tool_result",
|
||||
"reasoning_start",
|
||||
"reasoning_delta",
|
||||
"finish",
|
||||
"error",
|
||||
}
|
||||
]
|
||||
custom = [
|
||||
e
|
||||
for e in stream
|
||||
if e.type
|
||||
not in {
|
||||
"text_delta",
|
||||
"text_end",
|
||||
"tool_call",
|
||||
"tool_result",
|
||||
"reasoning_start",
|
||||
"reasoning_delta",
|
||||
"finish",
|
||||
"error",
|
||||
}
|
||||
]
|
||||
assert len(standard) == 3
|
||||
assert len(custom) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization of full event sequences
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSequenceSerialization:
|
||||
"""Serialize entire event sequences, as done by the dump tests."""
|
||||
|
||||
def test_serialize_text_sequence(self):
|
||||
events = [
|
||||
TextDeltaEvent(content="Hello", snapshot="Hello"),
|
||||
TextDeltaEvent(content=" world", snapshot="Hello world"),
|
||||
TextEndEvent(full_text="Hello world"),
|
||||
FinishEvent(stop_reason="stop", model="test-model"),
|
||||
]
|
||||
serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
|
||||
assert len(serialized) == 4
|
||||
assert serialized[0]["index"] == 0
|
||||
assert serialized[0]["type"] == "text_delta"
|
||||
assert serialized[-1]["type"] == "finish"
|
||||
assert serialized[-1]["model"] == "test-model"
|
||||
|
||||
def test_serialize_tool_sequence(self):
|
||||
events = [
|
||||
ToolCallEvent(
|
||||
tool_use_id="call_1",
|
||||
tool_name="search",
|
||||
tool_input={"query": "test"},
|
||||
),
|
||||
FinishEvent(stop_reason="tool_calls"),
|
||||
]
|
||||
serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
|
||||
assert serialized[0]["tool_input"] == {"query": "test"}
|
||||
assert serialized[1]["stop_reason"] == "tool_calls"
|
||||
|
||||
def test_serialize_error_sequence(self):
|
||||
events = [
|
||||
TextDeltaEvent(content="partial"),
|
||||
StreamErrorEvent(error="connection reset", recoverable=True),
|
||||
FinishEvent(stop_reason="error"),
|
||||
]
|
||||
serialized = [{"index": i, **asdict(e)} for i, e in enumerate(events)]
|
||||
assert serialized[1]["type"] == "error"
|
||||
assert serialized[1]["recoverable"] is True
|
||||
|
||||
def test_roundtrip_snapshot_accumulation(self):
|
||||
"""Verify snapshot grows monotonically through serialization."""
|
||||
chunks = ["Hello", " beautiful", " world", "!"]
|
||||
events = []
|
||||
snapshot = ""
|
||||
for chunk in chunks:
|
||||
snapshot += chunk
|
||||
events.append(TextDeltaEvent(content=chunk, snapshot=snapshot))
|
||||
|
||||
serialized = [asdict(e) for e in events]
|
||||
for i in range(1, len(serialized)):
|
||||
assert len(serialized[i]["snapshot"]) > len(serialized[i - 1]["snapshot"])
|
||||
assert serialized[-1]["snapshot"] == "Hello beautiful world!"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# WP-2: EventType Enum Extension + Node-Level Event Routing
|
||||
# ===========================================================================
|
||||
|
||||
# The 12 new EventType members added by WP-2
|
||||
WP2_EVENT_TYPES = {
|
||||
# Node event-loop lifecycle
|
||||
EventType.NODE_LOOP_STARTED: "node_loop_started",
|
||||
EventType.NODE_LOOP_ITERATION: "node_loop_iteration",
|
||||
EventType.NODE_LOOP_COMPLETED: "node_loop_completed",
|
||||
# LLM streaming observability
|
||||
EventType.LLM_TEXT_DELTA: "llm_text_delta",
|
||||
EventType.LLM_REASONING_DELTA: "llm_reasoning_delta",
|
||||
# Tool lifecycle
|
||||
EventType.TOOL_CALL_STARTED: "tool_call_started",
|
||||
EventType.TOOL_CALL_COMPLETED: "tool_call_completed",
|
||||
# Client I/O
|
||||
EventType.CLIENT_OUTPUT_DELTA: "client_output_delta",
|
||||
EventType.CLIENT_INPUT_REQUESTED: "client_input_requested",
|
||||
# Internal node observability
|
||||
EventType.NODE_INTERNAL_OUTPUT: "node_internal_output",
|
||||
EventType.NODE_INPUT_BLOCKED: "node_input_blocked",
|
||||
EventType.NODE_STALLED: "node_stalled",
|
||||
}
|
||||
|
||||
# Pre-existing enum members that must remain unchanged
|
||||
ORIGINAL_EVENT_TYPES = {
|
||||
EventType.EXECUTION_STARTED: "execution_started",
|
||||
EventType.EXECUTION_COMPLETED: "execution_completed",
|
||||
EventType.EXECUTION_FAILED: "execution_failed",
|
||||
EventType.EXECUTION_PAUSED: "execution_paused",
|
||||
EventType.EXECUTION_RESUMED: "execution_resumed",
|
||||
EventType.STATE_CHANGED: "state_changed",
|
||||
EventType.STATE_CONFLICT: "state_conflict",
|
||||
EventType.GOAL_PROGRESS: "goal_progress",
|
||||
EventType.GOAL_ACHIEVED: "goal_achieved",
|
||||
EventType.CONSTRAINT_VIOLATION: "constraint_violation",
|
||||
EventType.STREAM_STARTED: "stream_started",
|
||||
EventType.STREAM_STOPPED: "stream_stopped",
|
||||
EventType.CUSTOM: "custom",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2 Part A: EventType enum members
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2EventTypeEnumMembers:
|
||||
"""All 12 new EventType members exist with correct string values."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"member,expected_value",
|
||||
WP2_EVENT_TYPES.items(),
|
||||
ids=lambda x: x.name if isinstance(x, EventType) else x,
|
||||
)
|
||||
def test_new_member_value(self, member, expected_value):
|
||||
assert member.value == expected_value
|
||||
|
||||
def test_all_12_new_members_exist(self):
|
||||
assert len(WP2_EVENT_TYPES) == 12
|
||||
|
||||
def test_new_member_string_values_are_unique(self):
|
||||
values = list(WP2_EVENT_TYPES.values())
|
||||
assert len(values) == len(set(values))
|
||||
|
||||
def test_no_collision_with_original_members(self):
|
||||
new_values = set(WP2_EVENT_TYPES.values())
|
||||
old_values = set(ORIGINAL_EVENT_TYPES.values())
|
||||
overlap = new_values & old_values
|
||||
assert overlap == set(), f"Colliding values: {overlap}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"member,expected_value",
|
||||
ORIGINAL_EVENT_TYPES.items(),
|
||||
ids=lambda x: x.name if isinstance(x, EventType) else x,
|
||||
)
|
||||
def test_original_members_unchanged(self, member, expected_value):
|
||||
assert member.value == expected_value
|
||||
|
||||
def test_event_type_is_str_enum(self):
|
||||
"""EventType members compare equal to their string values."""
|
||||
assert EventType.NODE_LOOP_STARTED == "node_loop_started"
|
||||
assert EventType.LLM_TEXT_DELTA == "llm_text_delta"
|
||||
assert EventType.LLM_TEXT_DELTA.value == "llm_text_delta"
|
||||
|
||||
def test_event_type_accessible_by_name(self):
|
||||
assert EventType["NODE_LOOP_STARTED"] is EventType.NODE_LOOP_STARTED
|
||||
assert EventType["TOOL_CALL_COMPLETED"] is EventType.TOOL_CALL_COMPLETED
|
||||
|
||||
def test_event_type_accessible_by_value(self):
|
||||
assert EventType("node_loop_started") is EventType.NODE_LOOP_STARTED
|
||||
assert EventType("tool_call_completed") is EventType.TOOL_CALL_COMPLETED
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2 Part B: AgentEvent.node_id and Subscription.filter_node
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2AgentEventNodeId:
|
||||
"""AgentEvent supports node_id as a first-class field."""
|
||||
|
||||
def test_node_id_defaults_to_none(self):
|
||||
event = AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id="stream-1",
|
||||
)
|
||||
assert event.node_id is None
|
||||
|
||||
def test_node_id_can_be_set(self):
|
||||
event = AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id="stream-1",
|
||||
node_id="email_composer",
|
||||
)
|
||||
assert event.node_id == "email_composer"
|
||||
|
||||
def test_node_id_in_to_dict(self):
|
||||
event = AgentEvent(
|
||||
type=EventType.TOOL_CALL_STARTED,
|
||||
stream_id="stream-1",
|
||||
node_id="search_node",
|
||||
)
|
||||
d = event.to_dict()
|
||||
assert d["node_id"] == "search_node"
|
||||
|
||||
def test_node_id_none_in_to_dict(self):
|
||||
event = AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id="stream-1",
|
||||
)
|
||||
d = event.to_dict()
|
||||
assert "node_id" in d
|
||||
assert d["node_id"] is None
|
||||
|
||||
|
||||
class TestWP2SubscriptionFilterNode:
|
||||
"""Subscription supports filter_node for node-level routing."""
|
||||
|
||||
@staticmethod
|
||||
async def _noop_handler(event: AgentEvent) -> None:
|
||||
pass
|
||||
|
||||
def test_filter_node_defaults_to_none(self):
|
||||
sub = Subscription(
|
||||
id="sub_1",
|
||||
event_types={EventType.LLM_TEXT_DELTA},
|
||||
handler=self._noop_handler,
|
||||
)
|
||||
assert sub.filter_node is None
|
||||
|
||||
def test_filter_node_can_be_set(self):
|
||||
sub = Subscription(
|
||||
id="sub_1",
|
||||
event_types={EventType.LLM_TEXT_DELTA},
|
||||
handler=self._noop_handler,
|
||||
filter_node="email_composer",
|
||||
)
|
||||
assert sub.filter_node == "email_composer"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2 Part B: Node-level event routing integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2NodeLevelRouting:
|
||||
"""EventBus routes events by node_id using filter_node."""
|
||||
|
||||
@pytest.fixture
|
||||
def bus(self):
|
||||
return EventBus()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_node_receives_matching_events(self, bus):
|
||||
"""Subscriber with filter_node='node-A' receives events from node-A."""
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.LLM_TEXT_DELTA],
|
||||
handler=handler,
|
||||
filter_node="node-A",
|
||||
)
|
||||
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id="stream-1",
|
||||
node_id="node-A",
|
||||
data={"content": "hello"},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].node_id == "node-A"
|
||||
assert received[0].data["content"] == "hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_node_rejects_non_matching_events(self, bus):
|
||||
"""Subscriber with filter_node='node-B' does NOT receive node-A events."""
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.LLM_TEXT_DELTA],
|
||||
handler=handler,
|
||||
filter_node="node-B",
|
||||
)
|
||||
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id="stream-1",
|
||||
node_id="node-A",
|
||||
data={"content": "hello"},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(received) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_filter_node_receives_all_events(self, bus):
|
||||
"""Subscriber with no filter_node receives events from all nodes."""
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.LLM_TEXT_DELTA],
|
||||
handler=handler,
|
||||
)
|
||||
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id="stream-1",
|
||||
node_id="node-A",
|
||||
)
|
||||
)
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id="stream-1",
|
||||
node_id="node-B",
|
||||
)
|
||||
)
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id="stream-1",
|
||||
node_id=None,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(received) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interleaved_nodes_separated_by_filter(self, bus):
|
||||
"""Two subscribers on different nodes get only their node's events."""
|
||||
node_a_events = []
|
||||
node_b_events = []
|
||||
|
||||
async def handler_a(event):
|
||||
node_a_events.append(event)
|
||||
|
||||
async def handler_b(event):
|
||||
node_b_events.append(event)
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.LLM_TEXT_DELTA],
|
||||
handler=handler_a,
|
||||
filter_node="email_sender",
|
||||
)
|
||||
bus.subscribe(
|
||||
event_types=[EventType.LLM_TEXT_DELTA],
|
||||
handler=handler_b,
|
||||
filter_node="inbox_scanner",
|
||||
)
|
||||
|
||||
# Interleaved events from both nodes
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id="webhook",
|
||||
node_id="email_sender",
|
||||
data={"content": "Dear Jo"},
|
||||
)
|
||||
)
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id="webhook",
|
||||
node_id="inbox_scanner",
|
||||
data={"content": "RE: Meeting conf"},
|
||||
)
|
||||
)
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id="webhook",
|
||||
node_id="email_sender",
|
||||
data={"content": "hn, Thank you for"},
|
||||
)
|
||||
)
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.LLM_TEXT_DELTA,
|
||||
stream_id="webhook",
|
||||
node_id="inbox_scanner",
|
||||
data={"content": "irmed for Thursday"},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(node_a_events) == 2
|
||||
assert len(node_b_events) == 2
|
||||
assert node_a_events[0].data["content"] == "Dear Jo"
|
||||
assert node_a_events[1].data["content"] == "hn, Thank you for"
|
||||
assert node_b_events[0].data["content"] == "RE: Meeting conf"
|
||||
assert node_b_events[1].data["content"] == "irmed for Thursday"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_node_combined_with_filter_stream(self, bus):
|
||||
"""filter_node and filter_stream work together."""
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.TOOL_CALL_STARTED],
|
||||
handler=handler,
|
||||
filter_stream="webhook",
|
||||
filter_node="search_node",
|
||||
)
|
||||
|
||||
# Matching both filters
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TOOL_CALL_STARTED,
|
||||
stream_id="webhook",
|
||||
node_id="search_node",
|
||||
)
|
||||
)
|
||||
# Wrong stream
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TOOL_CALL_STARTED,
|
||||
stream_id="api",
|
||||
node_id="search_node",
|
||||
)
|
||||
)
|
||||
# Wrong node
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TOOL_CALL_STARTED,
|
||||
stream_id="webhook",
|
||||
node_id="other_node",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].stream_id == "webhook"
|
||||
assert received[0].node_id == "search_node"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_with_node_id(self, bus):
|
||||
"""wait_for() accepts node_id parameter for filtering."""
|
||||
|
||||
async def publish_later():
|
||||
await asyncio.sleep(0.01)
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_COMPLETED,
|
||||
stream_id="stream-1",
|
||||
node_id="target_node",
|
||||
data={"iterations": 3},
|
||||
)
|
||||
)
|
||||
|
||||
task = asyncio.create_task(publish_later())
|
||||
event = await bus.wait_for(
|
||||
event_type=EventType.NODE_LOOP_COMPLETED,
|
||||
node_id="target_node",
|
||||
timeout=2.0,
|
||||
)
|
||||
await task
|
||||
|
||||
assert event is not None
|
||||
assert event.node_id == "target_node"
|
||||
assert event.data["iterations"] == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_ignores_wrong_node(self, bus):
|
||||
"""wait_for() with node_id ignores events from other nodes."""
|
||||
|
||||
async def publish_wrong_then_right():
|
||||
await asyncio.sleep(0.01)
|
||||
# Wrong node — should be ignored
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_COMPLETED,
|
||||
stream_id="stream-1",
|
||||
node_id="wrong_node",
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
# Right node
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.NODE_LOOP_COMPLETED,
|
||||
stream_id="stream-1",
|
||||
node_id="target_node",
|
||||
data={"iterations": 5},
|
||||
)
|
||||
)
|
||||
|
||||
task = asyncio.create_task(publish_wrong_then_right())
|
||||
event = await bus.wait_for(
|
||||
event_type=EventType.NODE_LOOP_COMPLETED,
|
||||
node_id="target_node",
|
||||
timeout=2.0,
|
||||
)
|
||||
await task
|
||||
|
||||
assert event is not None
|
||||
assert event.node_id == "target_node"
|
||||
assert event.data["iterations"] == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WP-2: Convenience publisher methods
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWP2ConveniencePublishers:
|
||||
"""EventBus convenience methods for new WP-2 event types."""
|
||||
|
||||
@pytest.fixture
|
||||
def bus(self):
|
||||
return EventBus()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_node_loop_started(self, bus):
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(event_types=[EventType.NODE_LOOP_STARTED], handler=handler)
|
||||
await bus.emit_node_loop_started(
|
||||
stream_id="s1",
|
||||
node_id="n1",
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].node_id == "n1"
|
||||
assert received[0].data["max_iterations"] == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_node_loop_iteration(self, bus):
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(event_types=[EventType.NODE_LOOP_ITERATION], handler=handler)
|
||||
await bus.emit_node_loop_iteration(
|
||||
stream_id="s1",
|
||||
node_id="n1",
|
||||
iteration=3,
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["iteration"] == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_node_loop_completed(self, bus):
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(event_types=[EventType.NODE_LOOP_COMPLETED], handler=handler)
|
||||
await bus.emit_node_loop_completed(
|
||||
stream_id="s1",
|
||||
node_id="n1",
|
||||
iterations=5,
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["iterations"] == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_llm_text_delta(self, bus):
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(event_types=[EventType.LLM_TEXT_DELTA], handler=handler)
|
||||
await bus.emit_llm_text_delta(
|
||||
stream_id="s1",
|
||||
node_id="n1",
|
||||
content="hello",
|
||||
snapshot="hello world",
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["content"] == "hello"
|
||||
assert received[0].data["snapshot"] == "hello world"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_tool_call_started(self, bus):
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(event_types=[EventType.TOOL_CALL_STARTED], handler=handler)
|
||||
await bus.emit_tool_call_started(
|
||||
stream_id="s1",
|
||||
node_id="n1",
|
||||
tool_use_id="call_1",
|
||||
tool_name="web_search",
|
||||
tool_input={"query": "test"},
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["tool_name"] == "web_search"
|
||||
assert received[0].data["tool_input"] == {"query": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_tool_call_completed(self, bus):
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(event_types=[EventType.TOOL_CALL_COMPLETED], handler=handler)
|
||||
await bus.emit_tool_call_completed(
|
||||
stream_id="s1",
|
||||
node_id="n1",
|
||||
tool_use_id="call_1",
|
||||
tool_name="web_search",
|
||||
result="3 results found",
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["result"] == "3 results found"
|
||||
assert received[0].data["is_error"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_client_output_delta(self, bus):
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(event_types=[EventType.CLIENT_OUTPUT_DELTA], handler=handler)
|
||||
await bus.emit_client_output_delta(
|
||||
stream_id="s1",
|
||||
node_id="n1",
|
||||
content="chunk",
|
||||
snapshot="full chunk",
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["content"] == "chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_node_stalled(self, bus):
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(event_types=[EventType.NODE_STALLED], handler=handler)
|
||||
await bus.emit_node_stalled(
|
||||
stream_id="s1",
|
||||
node_id="n1",
|
||||
reason="no progress after 10 iterations",
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].data["reason"] == "no progress after 10 iterations"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convenience_publishers_set_node_id(self, bus):
|
||||
"""All WP-2 convenience publishers set node_id on the emitted event."""
|
||||
received = []
|
||||
|
||||
async def handler(event):
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.LLM_TEXT_DELTA, EventType.TOOL_CALL_STARTED],
|
||||
handler=handler,
|
||||
filter_node="my_node",
|
||||
)
|
||||
|
||||
await bus.emit_llm_text_delta(
|
||||
stream_id="s1",
|
||||
node_id="my_node",
|
||||
content="hi",
|
||||
snapshot="hi",
|
||||
)
|
||||
await bus.emit_tool_call_started(
|
||||
stream_id="s1",
|
||||
node_id="my_node",
|
||||
tool_use_id="c1",
|
||||
tool_name="calc",
|
||||
)
|
||||
# Wrong node — should not be received
|
||||
await bus.emit_llm_text_delta(
|
||||
stream_id="s1",
|
||||
node_id="other_node",
|
||||
content="bye",
|
||||
snapshot="bye",
|
||||
)
|
||||
|
||||
assert len(received) == 2
|
||||
assert all(e.node_id == "my_node" for e in received)
|
||||
@@ -0,0 +1,577 @@
|
||||
"""
|
||||
Tests for feedback/callback edges and max_node_visits in GraphExecutor.
|
||||
|
||||
Covers:
|
||||
- NodeSpec.max_node_visits default value
|
||||
- Visit limit enforcement (skip on exceed)
|
||||
- Multiple visits allowed when max_node_visits > 1
|
||||
- Unlimited visits with max_node_visits=0
|
||||
- Conditional feedback edges (backward traversal)
|
||||
- Conditional edge NOT firing (forward path taken)
|
||||
- node_visit_counts populated in ExecutionResult
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock node implementations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SuccessNode(NodeProtocol):
|
||||
"""Always succeeds with configurable output."""
|
||||
|
||||
def __init__(self, output: dict | None = None):
|
||||
self._output = output or {"result": "ok"}
|
||||
self.execute_count = 0
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
self.execute_count += 1
|
||||
return NodeResult(success=True, output=self._output, tokens_used=10, latency_ms=5)
|
||||
|
||||
|
||||
class StatefulNode(NodeProtocol):
|
||||
"""Returns different outputs on successive executions."""
|
||||
|
||||
def __init__(self, outputs: list[dict]):
|
||||
self._outputs = outputs
|
||||
self.execute_count = 0
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
output = self._outputs[min(self.execute_count, len(self._outputs) - 1)]
|
||||
self.execute_count += 1
|
||||
return NodeResult(success=True, output=output, tokens_used=10, latency_ms=5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runtime():
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
rt = MagicMock(spec=Runtime)
|
||||
rt.start_run = MagicMock(return_value="run_id")
|
||||
rt.decide = MagicMock(return_value="decision_id")
|
||||
rt.record_outcome = MagicMock()
|
||||
rt.end_run = MagicMock()
|
||||
rt.report_problem = MagicMock()
|
||||
rt.set_node = MagicMock()
|
||||
return rt
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def goal():
|
||||
return Goal(id="g1", name="Test", description="Feedback edge tests")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. NodeSpec default
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_max_node_visits_default():
|
||||
"""NodeSpec.max_node_visits should default to 1."""
|
||||
spec = NodeSpec(id="n", name="N", description="test", node_type="function", output_keys=["out"])
|
||||
assert spec.max_node_visits == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Visit limit skips node
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visit_limit_skips_node(runtime, goal):
|
||||
"""A→B→A cycle with A.max_visits=1: second visit to A should be skipped.
|
||||
|
||||
Neither node is terminal — max_steps is the guard. After A is skipped,
|
||||
the skip-redirect loop (A skip→B→A skip→B...) burns through max_steps.
|
||||
"""
|
||||
node_a = NodeSpec(
|
||||
id="a",
|
||||
name="A",
|
||||
description="entry with visit limit",
|
||||
node_type="function",
|
||||
output_keys=["a_out"],
|
||||
max_node_visits=1,
|
||||
)
|
||||
node_b = NodeSpec(
|
||||
id="b",
|
||||
name="B",
|
||||
description="middle node",
|
||||
node_type="function",
|
||||
output_keys=["b_out"],
|
||||
max_node_visits=0, # unlimited — let max_steps guard
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="cycle_graph",
|
||||
goal_id="g1",
|
||||
name="Cycle Graph",
|
||||
entry_node="a",
|
||||
nodes=[node_a, node_b],
|
||||
edges=[
|
||||
EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="b_to_a", source="b", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
terminal_nodes=[], # No terminal — max_steps is the guard
|
||||
max_steps=10,
|
||||
)
|
||||
|
||||
a_impl = SuccessNode({"a_out": "from_a"})
|
||||
b_impl = SuccessNode({"b_out": "from_b"})
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
executor.register_node("a", a_impl)
|
||||
executor.register_node("b", b_impl)
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
# A should only execute once (all subsequent visits are skipped)
|
||||
assert a_impl.execute_count == 1
|
||||
# Path should contain "a" exactly once (skipped visits aren't appended)
|
||||
assert result.path.count("a") == 1
|
||||
# Visit count tracks ALL visits (including skipped ones)
|
||||
assert result.node_visit_counts["a"] >= 2
|
||||
# B executes multiple times (no visit limit)
|
||||
assert b_impl.execute_count >= 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Visit limit allows multiple
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visit_limit_allows_multiple(runtime, goal):
|
||||
"""A→B→A cycle with A.max_visits=2: A executes twice before skip."""
|
||||
node_a = NodeSpec(
|
||||
id="a",
|
||||
name="A",
|
||||
description="entry allows two visits",
|
||||
node_type="function",
|
||||
output_keys=["a_out"],
|
||||
max_node_visits=2,
|
||||
)
|
||||
node_b = NodeSpec(
|
||||
id="b",
|
||||
name="B",
|
||||
description="middle node",
|
||||
node_type="function",
|
||||
output_keys=["b_out"],
|
||||
max_node_visits=0, # unlimited
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="cycle_graph",
|
||||
goal_id="g1",
|
||||
name="Cycle Graph",
|
||||
entry_node="a",
|
||||
nodes=[node_a, node_b],
|
||||
edges=[
|
||||
EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="b_to_a", source="b", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
terminal_nodes=[],
|
||||
max_steps=10,
|
||||
)
|
||||
|
||||
a_impl = SuccessNode({"a_out": "from_a"})
|
||||
b_impl = SuccessNode({"b_out": "from_b"})
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
executor.register_node("a", a_impl)
|
||||
executor.register_node("b", b_impl)
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
# A should execute exactly twice
|
||||
assert a_impl.execute_count == 2
|
||||
# Path should contain "a" exactly twice
|
||||
assert result.path.count("a") == 2
|
||||
# Visit count includes skipped visits too
|
||||
assert result.node_visit_counts["a"] >= 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Visit limit zero = unlimited
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visit_limit_zero_unlimited(runtime, goal):
|
||||
"""max_node_visits=0 means unlimited; max_steps is the only guard."""
|
||||
node_a = NodeSpec(
|
||||
id="a",
|
||||
name="A",
|
||||
description="unlimited visits",
|
||||
node_type="function",
|
||||
output_keys=["a_out"],
|
||||
max_node_visits=0,
|
||||
)
|
||||
node_b = NodeSpec(
|
||||
id="b",
|
||||
name="B",
|
||||
description="middle node",
|
||||
node_type="function",
|
||||
output_keys=["b_out"],
|
||||
max_node_visits=0,
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="cycle_graph",
|
||||
goal_id="g1",
|
||||
name="Cycle Graph",
|
||||
entry_node="a",
|
||||
nodes=[node_a, node_b],
|
||||
edges=[
|
||||
EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
EdgeSpec(id="b_to_a", source="b", target="a", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
terminal_nodes=[],
|
||||
max_steps=6, # A,B,A,B,A,B
|
||||
)
|
||||
|
||||
a_impl = SuccessNode({"a_out": "from_a"})
|
||||
b_impl = SuccessNode({"b_out": "from_b"})
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
executor.register_node("a", a_impl)
|
||||
executor.register_node("b", b_impl)
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
# With max_steps=6: A,B,A,B,A,B → each executes 3 times
|
||||
assert a_impl.execute_count == 3
|
||||
assert b_impl.execute_count == 3
|
||||
assert result.steps_executed == 6
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Conditional feedback edge fires
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conditional_feedback_edge(runtime, goal):
|
||||
"""Writer→Director backward edge fires when needs_revision==True in output.
|
||||
|
||||
Edge conditions evaluate `output` (current node result) and `memory`
|
||||
(accumulated shared state). The writer's output hasn't been written to
|
||||
memory yet when edges are evaluated, so we use `output.get(...)`.
|
||||
"""
|
||||
director = NodeSpec(
|
||||
id="director",
|
||||
name="Director",
|
||||
description="plans work",
|
||||
node_type="function",
|
||||
output_keys=["plan"],
|
||||
max_node_visits=2,
|
||||
)
|
||||
writer = NodeSpec(
|
||||
id="writer",
|
||||
name="Writer",
|
||||
description="writes draft",
|
||||
node_type="function",
|
||||
output_keys=["draft", "needs_revision"],
|
||||
max_node_visits=2,
|
||||
)
|
||||
output_node = NodeSpec(
|
||||
id="output",
|
||||
name="Output",
|
||||
description="final output",
|
||||
node_type="function",
|
||||
output_keys=["final"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="feedback_graph",
|
||||
goal_id="g1",
|
||||
name="Feedback Graph",
|
||||
entry_node="director",
|
||||
nodes=[director, writer, output_node],
|
||||
edges=[
|
||||
EdgeSpec(
|
||||
id="director_to_writer",
|
||||
source="director",
|
||||
target="writer",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
),
|
||||
# Forward path: writer → output (when NOT needs_revision)
|
||||
EdgeSpec(
|
||||
id="writer_to_output",
|
||||
source="writer",
|
||||
target="output",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="output.get('needs_revision') != True",
|
||||
priority=0,
|
||||
),
|
||||
# Feedback path: writer → director (when needs_revision)
|
||||
EdgeSpec(
|
||||
id="writer_feedback",
|
||||
source="writer",
|
||||
target="director",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="output.get('needs_revision') == True",
|
||||
priority=-1,
|
||||
),
|
||||
],
|
||||
terminal_nodes=["output"],
|
||||
max_steps=10,
|
||||
)
|
||||
|
||||
director_impl = SuccessNode({"plan": "research AI"})
|
||||
# Writer: first call sets needs_revision=True, second sets False
|
||||
writer_impl = StatefulNode(
|
||||
[
|
||||
{"draft": "draft_v1", "needs_revision": True},
|
||||
{"draft": "draft_v2", "needs_revision": False},
|
||||
]
|
||||
)
|
||||
output_impl = SuccessNode({"final": "done"})
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
executor.register_node("director", director_impl)
|
||||
executor.register_node("writer", writer_impl)
|
||||
executor.register_node("output", output_impl)
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
assert result.success
|
||||
# Director executed twice (initial + feedback)
|
||||
assert director_impl.execute_count == 2
|
||||
# Writer executed twice (first draft rejected, second accepted)
|
||||
assert writer_impl.execute_count == 2
|
||||
# Output executed once
|
||||
assert output_impl.execute_count == 1
|
||||
# Full path: director → writer → director → writer → output
|
||||
assert result.path == ["director", "writer", "director", "writer", "output"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Conditional feedback edge does NOT fire
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conditional_feedback_false(runtime, goal):
|
||||
"""Writer→Director backward edge does NOT fire when needs_revision is False."""
|
||||
director = NodeSpec(
|
||||
id="director",
|
||||
name="Director",
|
||||
description="plans work",
|
||||
node_type="function",
|
||||
output_keys=["plan"],
|
||||
max_node_visits=2,
|
||||
)
|
||||
writer = NodeSpec(
|
||||
id="writer",
|
||||
name="Writer",
|
||||
description="writes draft",
|
||||
node_type="function",
|
||||
output_keys=["draft", "needs_revision"],
|
||||
)
|
||||
output_node = NodeSpec(
|
||||
id="output",
|
||||
name="Output",
|
||||
description="final output",
|
||||
node_type="function",
|
||||
output_keys=["final"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="feedback_graph",
|
||||
goal_id="g1",
|
||||
name="Feedback Graph",
|
||||
entry_node="director",
|
||||
nodes=[director, writer, output_node],
|
||||
edges=[
|
||||
EdgeSpec(
|
||||
id="director_to_writer",
|
||||
source="director",
|
||||
target="writer",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="writer_to_output",
|
||||
source="writer",
|
||||
target="output",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="output.get('needs_revision') != True",
|
||||
priority=0,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="writer_feedback",
|
||||
source="writer",
|
||||
target="director",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="output.get('needs_revision') == True",
|
||||
priority=-1,
|
||||
),
|
||||
],
|
||||
terminal_nodes=["output"],
|
||||
max_steps=10,
|
||||
)
|
||||
|
||||
director_impl = SuccessNode({"plan": "research AI"})
|
||||
# Writer always outputs good draft (no revision needed)
|
||||
writer_impl = SuccessNode({"draft": "perfect_draft", "needs_revision": False})
|
||||
output_impl = SuccessNode({"final": "done"})
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
executor.register_node("director", director_impl)
|
||||
executor.register_node("writer", writer_impl)
|
||||
executor.register_node("output", output_impl)
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
assert result.success
|
||||
# Director only executed once (no feedback loop)
|
||||
assert director_impl.execute_count == 1
|
||||
# Writer only executed once
|
||||
assert writer_impl.execute_count == 1
|
||||
# Output executed
|
||||
assert output_impl.execute_count == 1
|
||||
# Straight-through path
|
||||
assert result.path == ["director", "writer", "output"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Visit counts in ExecutionResult
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visit_counts_in_result(runtime, goal):
|
||||
"""ExecutionResult.node_visit_counts is populated with actual visit counts."""
|
||||
node_a = NodeSpec(
|
||||
id="a",
|
||||
name="A",
|
||||
description="entry",
|
||||
node_type="function",
|
||||
output_keys=["a_out"],
|
||||
)
|
||||
node_b = NodeSpec(
|
||||
id="b",
|
||||
name="B",
|
||||
description="terminal",
|
||||
node_type="function",
|
||||
input_keys=["a_out"],
|
||||
output_keys=["b_out"],
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="linear_graph",
|
||||
goal_id="g1",
|
||||
name="Linear Graph",
|
||||
entry_node="a",
|
||||
nodes=[node_a, node_b],
|
||||
edges=[
|
||||
EdgeSpec(id="a_to_b", source="a", target="b", condition=EdgeCondition.ON_SUCCESS),
|
||||
],
|
||||
terminal_nodes=["b"],
|
||||
)
|
||||
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
executor.register_node("a", SuccessNode({"a_out": "x"}))
|
||||
executor.register_node("b", SuccessNode({"b_out": "y"}))
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
assert result.success
|
||||
assert result.node_visit_counts == {"a": 1, "b": 1}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Conditional priority prevents fan-out
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conditional_priority_prevents_fanout(runtime, goal):
|
||||
"""When multiple CONDITIONAL edges match, only highest-priority fires.
|
||||
|
||||
Simulates: writer produces output where both forward and feedback
|
||||
conditions could match. The higher-priority forward edge should win;
|
||||
the executor must NOT treat this as fan-out.
|
||||
"""
|
||||
writer = NodeSpec(
|
||||
id="writer",
|
||||
name="Writer",
|
||||
description="produces output",
|
||||
node_type="function",
|
||||
output_keys=["draft", "needs_revision"],
|
||||
)
|
||||
output_node = NodeSpec(
|
||||
id="output",
|
||||
name="Output",
|
||||
description="forward target",
|
||||
node_type="function",
|
||||
output_keys=["final"],
|
||||
)
|
||||
director = NodeSpec(
|
||||
id="director",
|
||||
name="Director",
|
||||
description="feedback target",
|
||||
node_type="function",
|
||||
output_keys=["plan"],
|
||||
max_node_visits=2,
|
||||
)
|
||||
|
||||
graph = GraphSpec(
|
||||
id="priority_graph",
|
||||
goal_id="g1",
|
||||
name="Priority Graph",
|
||||
entry_node="writer",
|
||||
nodes=[writer, output_node, director],
|
||||
edges=[
|
||||
# Forward: higher priority (1)
|
||||
EdgeSpec(
|
||||
id="writer_to_output",
|
||||
source="writer",
|
||||
target="output",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="output.get('draft') is not None",
|
||||
priority=1,
|
||||
),
|
||||
# Feedback: lower priority (-1)
|
||||
EdgeSpec(
|
||||
id="writer_to_director",
|
||||
source="writer",
|
||||
target="director",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="output.get('needs_revision') == True",
|
||||
priority=-1,
|
||||
),
|
||||
],
|
||||
terminal_nodes=["output"],
|
||||
max_steps=10,
|
||||
)
|
||||
|
||||
# Writer sets BOTH output keys — both conditions are true
|
||||
writer_impl = SuccessNode({"draft": "my draft", "needs_revision": True})
|
||||
output_impl = SuccessNode({"final": "done"})
|
||||
director_impl = SuccessNode({"plan": "plan"})
|
||||
|
||||
executor = GraphExecutor(runtime=runtime, enable_parallel_execution=True)
|
||||
executor.register_node("writer", writer_impl)
|
||||
executor.register_node("output", output_impl)
|
||||
executor.register_node("director", director_impl)
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
assert result.success
|
||||
# Forward edge (priority 1) wins — output executes, director does NOT
|
||||
assert output_impl.execute_count == 1
|
||||
assert director_impl.execute_count == 0
|
||||
assert result.path == ["writer", "output"]
|
||||
@@ -209,6 +209,62 @@ class TestLiteLLMProviderToolUse:
|
||||
assert result.output_tokens == 25 # 15 + 10
|
||||
assert mock_completion.call_count == 2
|
||||
|
||||
@patch("litellm.completion")
|
||||
def test_complete_with_tools_invalid_json_arguments_are_handled(self, mock_completion):
|
||||
"""Test that invalid JSON tool arguments do not execute the tool."""
|
||||
# Mock response with invalid JSON arguments
|
||||
tool_call_response = MagicMock()
|
||||
tool_call_response.choices = [MagicMock()]
|
||||
tool_call_response.choices[0].message.content = None
|
||||
tool_call_response.choices[0].message.tool_calls = [MagicMock()]
|
||||
tool_call_response.choices[0].message.tool_calls[0].id = "call_123"
|
||||
tool_call_response.choices[0].message.tool_calls[0].function.name = "test_tool"
|
||||
tool_call_response.choices[0].message.tool_calls[0].function.arguments = "{invalid json"
|
||||
tool_call_response.choices[0].finish_reason = "tool_calls"
|
||||
tool_call_response.model = "gpt-4o-mini"
|
||||
tool_call_response.usage.prompt_tokens = 10
|
||||
tool_call_response.usage.completion_tokens = 5
|
||||
|
||||
# Final response (LLM continues after tool error)
|
||||
final_response = MagicMock()
|
||||
final_response.choices = [MagicMock()]
|
||||
final_response.choices[0].message.content = "Handled error"
|
||||
final_response.choices[0].message.tool_calls = None
|
||||
final_response.choices[0].finish_reason = "stop"
|
||||
final_response.model = "gpt-4o-mini"
|
||||
final_response.usage.prompt_tokens = 5
|
||||
final_response.usage.completion_tokens = 5
|
||||
|
||||
mock_completion.side_effect = [tool_call_response, final_response]
|
||||
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
|
||||
tools = [
|
||||
Tool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
parameters={"properties": {}, "required": []},
|
||||
)
|
||||
]
|
||||
|
||||
called = {"value": False}
|
||||
|
||||
def tool_executor(tool_use: ToolUse) -> ToolResult:
|
||||
called["value"] = True
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id, content="should not be called", is_error=False
|
||||
)
|
||||
|
||||
result = provider.complete_with_tools(
|
||||
messages=[{"role": "user", "content": "Run tool"}],
|
||||
system="You are a test assistant.",
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
|
||||
assert called["value"] is False
|
||||
assert result.content == "Handled error"
|
||||
|
||||
|
||||
class TestToolConversion:
|
||||
"""Test tool format conversion."""
|
||||
|
||||
@@ -0,0 +1,389 @@
|
||||
"""Real-API streaming tests for LiteLLM provider.
|
||||
|
||||
Calls live LLM APIs and dumps stream events to JSON files for review.
|
||||
Results are saved to core/tests/stream_event_dumps/{provider}_{model}_{scenario}.json
|
||||
|
||||
Run with:
|
||||
cd core && python -m pytest tests/test_litellm_streaming.py -v -s -k "RealAPI"
|
||||
|
||||
Requires API keys set in environment:
|
||||
ANTHROPIC_API_KEY, OPENAI_API_KEY, GEMINI_API_KEY (or via credential store)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
from framework.llm.provider import Tool
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DUMP_DIR = Path(__file__).parent / "stream_event_dumps"
|
||||
|
||||
|
||||
def _serialize_event(index: int, event: StreamEvent) -> dict:
|
||||
"""Serialize a StreamEvent to a JSON-safe dict."""
|
||||
d = asdict(event) # type: ignore[arg-type]
|
||||
d["index"] = index
|
||||
# Move index to front for readability
|
||||
return {"index": index, **{k: v for k, v in d.items() if k != "index"}}
|
||||
|
||||
|
||||
def _dump_events(events: list[StreamEvent], filename: str) -> Path:
|
||||
"""Write stream events to a JSON file in the dump directory."""
|
||||
DUMP_DIR.mkdir(parents=True, exist_ok=True)
|
||||
filepath = DUMP_DIR / filename
|
||||
serialized = [_serialize_event(i, e) for i, e in enumerate(events)]
|
||||
filepath.write_text(json.dumps(serialized, indent=2) + "\n")
|
||||
logger.info(f"Dumped {len(events)} events to {filepath}")
|
||||
return filepath
|
||||
|
||||
|
||||
async def _collect_stream(provider: LiteLLMProvider, **kwargs) -> list[StreamEvent]:
|
||||
"""Collect all stream events from a provider.stream() call."""
|
||||
events: list[StreamEvent] = []
|
||||
async for event in provider.stream(**kwargs):
|
||||
events.append(event)
|
||||
# Log each event type as it arrives
|
||||
logger.debug(f" [{len(events) - 1}] {event.type}: {event}")
|
||||
return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test matrix: (model_id, dump_prefix, env_var_for_skip)
|
||||
# ---------------------------------------------------------------------------
|
||||
MODELS = [
|
||||
(
|
||||
"anthropic/claude-haiku-4-5-20251001",
|
||||
"anthropic_claude-haiku-4-5-20251001",
|
||||
"ANTHROPIC_API_KEY",
|
||||
),
|
||||
("gpt-4.1-nano", "gpt-4.1-nano", "OPENAI_API_KEY"),
|
||||
("gemini/gemini-2.0-flash", "gemini_gemini-2.0-flash", "GEMINI_API_KEY"),
|
||||
]
|
||||
|
||||
WEATHER_TOOL = Tool(
|
||||
name="get_weather",
|
||||
description="Get the current weather for a city.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name, e.g. 'Tokyo'",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
)
|
||||
|
||||
SEARCH_TOOL = Tool(
|
||||
name="web_search",
|
||||
description="Search the web for information.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (1-10)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
)
|
||||
|
||||
CALCULATOR_TOOL = Tool(
|
||||
name="calculator",
|
||||
description="Perform arithmetic calculations.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "Math expression to evaluate, e.g. '2 + 2'",
|
||||
}
|
||||
},
|
||||
"required": ["expression"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _has_api_key(env_var: str) -> bool:
|
||||
"""Check if an API key is available (env var or credential store)."""
|
||||
if os.environ.get(env_var):
|
||||
return True
|
||||
# Try credential store
|
||||
try:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
creds = CredentialStoreAdapter.with_env_storage()
|
||||
provider_name = env_var.replace("_API_KEY", "").lower()
|
||||
return creds.is_available(provider_name)
|
||||
except (ImportError, Exception):
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real API tests — text streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRealAPITextStreaming:
|
||||
"""Stream a simple text response from each provider and dump events."""
|
||||
|
||||
@pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_stream(self, model: str, prefix: str, env_var: str):
|
||||
"""Stream a multi-paragraph response to exercise chunked delivery."""
|
||||
if not _has_api_key(env_var):
|
||||
pytest.skip(f"{env_var} not set")
|
||||
|
||||
provider = LiteLLMProvider(model=model)
|
||||
events = await _collect_stream(
|
||||
provider,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Explain in 3 numbered paragraphs how a CPU executes an instruction. "
|
||||
"Cover fetch, decode, and execute stages. Be concise but thorough."
|
||||
),
|
||||
}
|
||||
],
|
||||
system="You are a computer science teacher. Give clear, structured explanations.",
|
||||
max_tokens=512,
|
||||
)
|
||||
|
||||
# Dump to file
|
||||
_dump_events(events, f"{prefix}_text.json")
|
||||
|
||||
# Basic structural assertions
|
||||
assert len(events) >= 4, f"Expected at least 4 events, got {len(events)}"
|
||||
|
||||
# Must have multiple text deltas for a longer response
|
||||
text_deltas = [e for e in events if isinstance(e, TextDeltaEvent)]
|
||||
assert len(text_deltas) >= 3, f"Expected 3+ TextDeltaEvents, got {len(text_deltas)}"
|
||||
|
||||
# Snapshot must accumulate monotonically
|
||||
for i in range(1, len(text_deltas)):
|
||||
assert len(text_deltas[i].snapshot) > len(text_deltas[i - 1].snapshot), (
|
||||
f"Snapshot did not grow at index {i}"
|
||||
)
|
||||
|
||||
# Must end with TextEndEvent then FinishEvent
|
||||
text_ends = [e for e in events if isinstance(e, TextEndEvent)]
|
||||
assert len(text_ends) == 1, f"Expected 1 TextEndEvent, got {len(text_ends)}"
|
||||
|
||||
finish_events = [e for e in events if isinstance(e, FinishEvent)]
|
||||
assert len(finish_events) == 1, f"Expected 1 FinishEvent, got {len(finish_events)}"
|
||||
assert finish_events[0].stop_reason in ("stop", "end_turn")
|
||||
|
||||
# TextEndEvent.full_text should match last snapshot
|
||||
assert text_ends[0].full_text == text_deltas[-1].snapshot
|
||||
|
||||
# Response should actually contain multi-paragraph content
|
||||
full_text = text_ends[0].full_text
|
||||
assert len(full_text) > 200, f"Response too short ({len(full_text)} chars)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real API tests — tool call streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRealAPIToolCallStreaming:
|
||||
"""Stream a tool call response from each provider and dump events."""
|
||||
|
||||
@pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_stream(self, model: str, prefix: str, env_var: str):
|
||||
"""Stream a single tool call with complex arguments."""
|
||||
if not _has_api_key(env_var):
|
||||
pytest.skip(f"{env_var} not set")
|
||||
|
||||
provider = LiteLLMProvider(model=model)
|
||||
events = await _collect_stream(
|
||||
provider,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Search the web for 'Python 3.13 release notes'.",
|
||||
}
|
||||
],
|
||||
system="You have access to tools. Use the appropriate tool.",
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
|
||||
max_tokens=512,
|
||||
)
|
||||
|
||||
# Dump to file
|
||||
_dump_events(events, f"{prefix}_tool_call.json")
|
||||
|
||||
# Basic structural assertions
|
||||
assert len(events) >= 2, f"Expected at least 2 events, got {len(events)}"
|
||||
|
||||
# Must have a tool call event
|
||||
tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
|
||||
assert len(tool_calls) >= 1, "No ToolCallEvent received"
|
||||
|
||||
tc = tool_calls[0]
|
||||
assert tc.tool_name == "web_search"
|
||||
assert "query" in tc.tool_input
|
||||
assert tc.tool_use_id != ""
|
||||
|
||||
# Must end with FinishEvent
|
||||
finish_events = [e for e in events if isinstance(e, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0].stop_reason in ("tool_calls", "tool_use", "stop")
|
||||
|
||||
@pytest.mark.parametrize("model,prefix,env_var", MODELS, ids=[m[1] for m in MODELS])
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_tool_call_stream(self, model: str, prefix: str, env_var: str):
|
||||
"""Stream a response that should invoke multiple tool calls."""
|
||||
if not _has_api_key(env_var):
|
||||
pytest.skip(f"{env_var} not set")
|
||||
|
||||
provider = LiteLLMProvider(model=model)
|
||||
events = await _collect_stream(
|
||||
provider,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"I need three things done in parallel: "
|
||||
"1) Get the weather in London, "
|
||||
"2) Get the weather in New York, "
|
||||
"3) Calculate 1337 * 42. "
|
||||
"Use the tools for all three."
|
||||
),
|
||||
}
|
||||
],
|
||||
system=(
|
||||
"You have access to tools. When the user asks for multiple things, "
|
||||
"call all the needed tools. Always use tools, never guess results."
|
||||
),
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
|
||||
max_tokens=512,
|
||||
)
|
||||
|
||||
# Dump to file
|
||||
_dump_events(events, f"{prefix}_multi_tool.json")
|
||||
|
||||
# Must have multiple tool call events
|
||||
tool_calls = [e for e in events if isinstance(e, ToolCallEvent)]
|
||||
assert len(tool_calls) >= 2, (
|
||||
f"Expected 2+ ToolCallEvents for parallel requests, got {len(tool_calls)}"
|
||||
)
|
||||
|
||||
# Verify tool names used
|
||||
tool_names = {tc.tool_name for tc in tool_calls}
|
||||
assert "get_weather" in tool_names, "Expected get_weather tool call"
|
||||
|
||||
# All tool calls should have non-empty IDs
|
||||
for tc in tool_calls:
|
||||
assert tc.tool_use_id != "", f"Empty tool_use_id on {tc.tool_name}"
|
||||
assert tc.tool_input, f"Empty tool_input on {tc.tool_name}"
|
||||
|
||||
# Must end with FinishEvent
|
||||
finish_events = [e for e in events if isinstance(e, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience runner for manual invocation
|
||||
# ---------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
"""Run all streaming tests and dump results. Usage: python tests/test_litellm_streaming.py"""
|
||||
|
||||
ALL_TOOLS = [WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL]
|
||||
|
||||
async def _run_all():
|
||||
for model, prefix, env_var in MODELS:
|
||||
if not _has_api_key(env_var):
|
||||
print(f"SKIP {prefix}: {env_var} not set")
|
||||
continue
|
||||
|
||||
provider = LiteLLMProvider(model=model)
|
||||
|
||||
# Text streaming (multi-paragraph)
|
||||
print(f"\n--- {prefix} text ---")
|
||||
events = await _collect_stream(
|
||||
provider,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Explain in 3 numbered paragraphs how a CPU executes an instruction. "
|
||||
"Cover fetch, decode, and execute stages. Be concise but thorough."
|
||||
),
|
||||
}
|
||||
],
|
||||
system="You are a computer science teacher. Give clear, structured explanations.",
|
||||
max_tokens=512,
|
||||
)
|
||||
path = _dump_events(events, f"{prefix}_text.json")
|
||||
print(f" {len(events)} events -> {path}")
|
||||
for i, e in enumerate(events):
|
||||
print(f" [{i}] {e.type}: {e}")
|
||||
|
||||
# Tool call streaming
|
||||
print(f"\n--- {prefix} tool_call ---")
|
||||
events = await _collect_stream(
|
||||
provider,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Search the web for 'Python 3.13 release notes'.",
|
||||
}
|
||||
],
|
||||
system="You have access to tools. Use the appropriate tool.",
|
||||
tools=ALL_TOOLS,
|
||||
max_tokens=512,
|
||||
)
|
||||
path = _dump_events(events, f"{prefix}_tool_call.json")
|
||||
print(f" {len(events)} events -> {path}")
|
||||
for i, e in enumerate(events):
|
||||
print(f" [{i}] {e.type}: {e}")
|
||||
|
||||
# Multi-tool call streaming
|
||||
print(f"\n--- {prefix} multi_tool ---")
|
||||
events = await _collect_stream(
|
||||
provider,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"I need three things done in parallel: "
|
||||
"1) Get the weather in London, "
|
||||
"2) Get the weather in New York, "
|
||||
"3) Calculate 1337 * 42. "
|
||||
"Use the tools for all three."
|
||||
),
|
||||
}
|
||||
],
|
||||
system=(
|
||||
"You have access to tools. When the user asks for multiple things, "
|
||||
"call all the needed tools. Always use tools, never guess results."
|
||||
),
|
||||
tools=ALL_TOOLS,
|
||||
max_tokens=512,
|
||||
)
|
||||
path = _dump_events(events, f"{prefix}_multi_tool.json")
|
||||
print(f" {len(events)} events -> {path}")
|
||||
for i, e in enumerate(events):
|
||||
print(f" [{i}] {e.type}: {e}")
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
asyncio.run(_run_all())
|
||||
@@ -0,0 +1,932 @@
|
||||
"""Tests for NodeConversation, Message, ConversationStore, and FileConversationStore."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.conversation import Message, NodeConversation
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockConversationStore:
|
||||
"""In-memory dict-based store for testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._parts: dict[int, dict] = {}
|
||||
self._meta: dict | None = None
|
||||
self._cursor: dict | None = None
|
||||
|
||||
async def write_part(self, seq: int, data: dict[str, Any]) -> None:
|
||||
self._parts[seq] = data
|
||||
|
||||
async def read_parts(self) -> list[dict[str, Any]]:
|
||||
return [self._parts[k] for k in sorted(self._parts)]
|
||||
|
||||
async def write_meta(self, data: dict[str, Any]) -> None:
|
||||
self._meta = data
|
||||
|
||||
async def read_meta(self) -> dict[str, Any] | None:
|
||||
return self._meta
|
||||
|
||||
async def write_cursor(self, data: dict[str, Any]) -> None:
|
||||
self._cursor = data
|
||||
|
||||
async def read_cursor(self) -> dict[str, Any] | None:
|
||||
return self._cursor
|
||||
|
||||
async def delete_parts_before(self, seq: int) -> None:
|
||||
self._parts = {k: v for k, v in self._parts.items() if k >= seq}
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
async def destroy(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
SAMPLE_TOOL_CALLS = [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "get_weather", "arguments": '{"city":"SF"}'},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Message serialization
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestMessage:
|
||||
def test_user_and_assistant_to_llm_dict(self):
|
||||
"""User and assistant (no tools) produce simple role+content dicts."""
|
||||
assert Message(seq=0, role="user", content="hi").to_llm_dict() == {
|
||||
"role": "user",
|
||||
"content": "hi",
|
||||
}
|
||||
assert Message(seq=0, role="assistant", content="hello").to_llm_dict() == {
|
||||
"role": "assistant",
|
||||
"content": "hello",
|
||||
}
|
||||
|
||||
def test_assistant_to_llm_dict_with_tools(self):
|
||||
m = Message(seq=0, role="assistant", content="", tool_calls=SAMPLE_TOOL_CALLS)
|
||||
d = m.to_llm_dict()
|
||||
assert d["role"] == "assistant"
|
||||
assert d["tool_calls"] == SAMPLE_TOOL_CALLS
|
||||
|
||||
def test_tool_to_llm_dict(self):
|
||||
m = Message(seq=0, role="tool", content="sunny", tool_use_id="call_1")
|
||||
d = m.to_llm_dict()
|
||||
assert d == {"role": "tool", "tool_call_id": "call_1", "content": "sunny"}
|
||||
|
||||
def test_tool_error_to_llm_dict(self):
|
||||
m = Message(seq=0, role="tool", content="not found", tool_use_id="call_1", is_error=True)
|
||||
d = m.to_llm_dict()
|
||||
assert d["content"] == "ERROR: not found"
|
||||
assert d["tool_call_id"] == "call_1"
|
||||
|
||||
def test_storage_roundtrip(self):
|
||||
m = Message(seq=5, role="assistant", content="ok", tool_calls=SAMPLE_TOOL_CALLS)
|
||||
restored = Message.from_storage_dict(m.to_storage_dict())
|
||||
assert restored.seq == m.seq
|
||||
assert restored.role == m.role
|
||||
assert restored.content == m.content
|
||||
assert restored.tool_calls == m.tool_calls
|
||||
|
||||
def test_storage_dict_edge_cases(self):
|
||||
"""is_error is preserved; None/False fields are omitted."""
|
||||
m = Message(seq=1, role="tool", content="fail", tool_use_id="c1", is_error=True)
|
||||
d = m.to_storage_dict()
|
||||
assert d["is_error"] is True
|
||||
assert Message.from_storage_dict(d).is_error is True
|
||||
|
||||
d2 = Message(seq=0, role="user", content="hi").to_storage_dict()
|
||||
assert "tool_use_id" not in d2
|
||||
assert "tool_calls" not in d2
|
||||
assert "is_error" not in d2
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# NodeConversation (in-memory)
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestNodeConversation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_turn_build_and_export(self):
|
||||
conv = NodeConversation(system_prompt="You are helpful.")
|
||||
await conv.add_user_message("hello")
|
||||
await conv.add_assistant_message("hi there")
|
||||
await conv.add_user_message("weather?")
|
||||
await conv.add_assistant_message("", tool_calls=SAMPLE_TOOL_CALLS)
|
||||
await conv.add_tool_result("call_1", "sunny")
|
||||
await conv.add_assistant_message("It's sunny!")
|
||||
|
||||
assert conv.turn_count == 2
|
||||
assert conv.message_count == 6
|
||||
llm = conv.to_llm_messages()
|
||||
assert len(llm) == 6
|
||||
assert llm[0]["role"] == "user"
|
||||
assert llm[3]["tool_calls"] == SAMPLE_TOOL_CALLS
|
||||
|
||||
summary = conv.export_summary()
|
||||
assert "turns: 2" in summary
|
||||
assert "messages: 6" in summary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_prompt_excluded_from_messages(self):
|
||||
conv = NodeConversation(system_prompt="secret")
|
||||
await conv.add_user_message("hi")
|
||||
llm = conv.to_llm_messages()
|
||||
assert len(llm) == 1
|
||||
assert all("secret" not in str(m) for m in llm)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_turn_and_seq_counting(self):
|
||||
"""turn_count tracks user messages; next_seq increments on every add."""
|
||||
conv = NodeConversation()
|
||||
assert conv.turn_count == 0
|
||||
assert conv.next_seq == 0
|
||||
await conv.add_user_message("a")
|
||||
assert conv.turn_count == 1
|
||||
assert conv.next_seq == 1
|
||||
await conv.add_assistant_message("b")
|
||||
assert conv.turn_count == 1
|
||||
assert conv.next_seq == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_estimation(self):
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("a" * 400)
|
||||
assert conv.estimate_tokens() == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_token_count_overrides_estimate(self):
|
||||
"""When actual API token count is provided, estimate_tokens uses it."""
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("a" * 400)
|
||||
assert conv.estimate_tokens() == 100 # chars/4 fallback
|
||||
|
||||
conv.update_token_count(500)
|
||||
assert conv.estimate_tokens() == 500 # actual API value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_resets_token_count(self):
|
||||
"""After compaction, actual token count is cleared (recalibrates on next LLM call)."""
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("a" * 400)
|
||||
conv.update_token_count(500)
|
||||
assert conv.estimate_tokens() == 500
|
||||
|
||||
await conv.compact("summary", keep_recent=0)
|
||||
# Falls back to chars/4 for the summary message
|
||||
assert conv.estimate_tokens() == len("summary") // 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_resets_token_count(self):
|
||||
"""clear() also resets the actual token count."""
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("hello")
|
||||
conv.update_token_count(1000)
|
||||
assert conv.estimate_tokens() == 1000
|
||||
|
||||
await conv.clear()
|
||||
assert conv.estimate_tokens() == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_ratio(self):
|
||||
"""usage_ratio returns estimate / max_history_tokens."""
|
||||
conv = NodeConversation(max_history_tokens=1000)
|
||||
await conv.add_user_message("a" * 400)
|
||||
assert conv.usage_ratio() == pytest.approx(0.1) # 100/1000
|
||||
|
||||
conv.update_token_count(800)
|
||||
assert conv.usage_ratio() == pytest.approx(0.8) # 800/1000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_ratio_zero_budget(self):
|
||||
"""usage_ratio returns 0 when max_history_tokens is 0 (unlimited)."""
|
||||
conv = NodeConversation(max_history_tokens=0)
|
||||
await conv.add_user_message("a" * 400)
|
||||
assert conv.usage_ratio() == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_needs_compaction_with_actual_tokens(self):
|
||||
"""needs_compaction uses actual API token count when available."""
|
||||
conv = NodeConversation(max_history_tokens=1000, compaction_threshold=0.8)
|
||||
await conv.add_user_message("a" * 100) # chars/4 = 25, well under 800
|
||||
|
||||
assert conv.needs_compaction() is False
|
||||
|
||||
# Simulate API reporting much higher actual token usage
|
||||
conv.update_token_count(850)
|
||||
assert conv.needs_compaction() is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_needs_compaction(self):
|
||||
conv = NodeConversation(max_history_tokens=100, compaction_threshold=0.8)
|
||||
await conv.add_user_message("x" * 320)
|
||||
assert conv.needs_compaction() is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_replaces_with_summary(self):
|
||||
"""keep_recent=0 replaces all messages; empty conversation is a no-op."""
|
||||
conv = NodeConversation()
|
||||
await conv.compact("summary")
|
||||
assert conv.turn_count == 0
|
||||
|
||||
conv2 = NodeConversation()
|
||||
await conv2.add_user_message("one")
|
||||
await conv2.add_assistant_message("two")
|
||||
seq_before = conv2.next_seq
|
||||
|
||||
await conv2.compact("summary of conversation", keep_recent=0)
|
||||
|
||||
assert conv2.turn_count == 1
|
||||
assert conv2.message_count == 1
|
||||
assert conv2.messages[0].content == "summary of conversation"
|
||||
assert conv2.messages[0].role == "user"
|
||||
assert conv2.messages[0].seq == seq_before
|
||||
assert conv2.next_seq == seq_before + 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_keep_recent_default(self):
|
||||
"""Default keep_recent=2 keeps last 2 messages."""
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("m1")
|
||||
await conv.add_assistant_message("m2")
|
||||
await conv.add_user_message("m3")
|
||||
await conv.add_assistant_message("m4")
|
||||
await conv.add_user_message("m5")
|
||||
await conv.add_assistant_message("m6")
|
||||
|
||||
await conv.compact("summary of early conversation")
|
||||
|
||||
assert conv.message_count == 3
|
||||
assert conv.messages[0].content == "summary of early conversation"
|
||||
assert conv.messages[0].role == "user"
|
||||
assert conv.messages[1].content == "m5"
|
||||
assert conv.messages[2].content == "m6"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_keep_recent_clamped(self):
|
||||
"""keep_recent larger than len-1 gets clamped."""
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("a")
|
||||
await conv.add_assistant_message("b")
|
||||
|
||||
await conv.compact("summary", keep_recent=5)
|
||||
|
||||
assert conv.message_count == 2
|
||||
assert conv.messages[0].content == "summary"
|
||||
assert conv.messages[1].content == "b"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_preserves_output_keys(self):
|
||||
"""PRESERVED VALUES block appears in summary when output_keys match."""
|
||||
conv = NodeConversation(output_keys=["score", "status"])
|
||||
await conv.add_user_message("process this")
|
||||
await conv.add_assistant_message("score: 87")
|
||||
await conv.add_assistant_message("status = complete")
|
||||
await conv.add_user_message("next question")
|
||||
|
||||
await conv.compact("conversation summary", keep_recent=1)
|
||||
|
||||
summary_content = conv.messages[0].content
|
||||
assert "PRESERVED VALUES" in summary_content
|
||||
assert "score: 87" in summary_content
|
||||
assert "status: complete" in summary_content
|
||||
assert "CONVERSATION SUMMARY:" in summary_content
|
||||
assert "conversation summary" in summary_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_seq_arithmetic_with_keep_recent(self):
|
||||
"""Summary seq = recent[0].seq - 1 when keeping recent messages."""
|
||||
conv = NodeConversation()
|
||||
await conv.add_user_message("m1") # seq=0
|
||||
await conv.add_assistant_message("m2") # seq=1
|
||||
await conv.add_user_message("m3") # seq=2
|
||||
await conv.add_assistant_message("m4") # seq=3
|
||||
|
||||
await conv.compact("summary", keep_recent=2)
|
||||
|
||||
assert conv.messages[0].seq == 1 # summary
|
||||
assert conv.messages[1].seq == 2 # m3
|
||||
assert conv.messages[2].seq == 3 # m4
|
||||
assert conv.next_seq == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear(self):
|
||||
"""Clear removes messages, keeps system prompt, preserves next_seq."""
|
||||
conv = NodeConversation(system_prompt="keep me")
|
||||
await conv.add_user_message("a")
|
||||
await conv.add_user_message("b")
|
||||
seq_before = conv.next_seq
|
||||
await conv.clear()
|
||||
assert conv.turn_count == 0
|
||||
assert conv.system_prompt == "keep me"
|
||||
assert conv.next_seq == seq_before
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_summary(self):
|
||||
conv = NodeConversation(system_prompt="Be helpful")
|
||||
await conv.add_user_message("q1")
|
||||
await conv.add_assistant_message("a1")
|
||||
s = conv.export_summary()
|
||||
assert "[STATS]" in s
|
||||
assert "turns: 1" in s
|
||||
assert "messages: 2" in s
|
||||
assert "[CONFIG]" in s
|
||||
assert "Be helpful" in s
|
||||
assert "[RECENT_MESSAGES]" in s
|
||||
assert "[user]" in s
|
||||
assert "[assistant]" in s
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_summary_output_keys(self):
|
||||
"""output_keys appear in CONFIG when set, absent when None."""
|
||||
conv = NodeConversation(
|
||||
system_prompt="test",
|
||||
output_keys=["confirmed_meetings", "lead_score"],
|
||||
)
|
||||
await conv.add_user_message("hi")
|
||||
assert "output_keys: confirmed_meetings, lead_score" in conv.export_summary()
|
||||
|
||||
conv2 = NodeConversation(system_prompt="test")
|
||||
await conv2.add_user_message("hi")
|
||||
assert "output_keys" not in conv2.export_summary()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Output-key extraction
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestExtractProtectedValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_colon_format(self):
|
||||
conv = NodeConversation(output_keys=["score"])
|
||||
await conv.add_assistant_message("The score: 87")
|
||||
assert conv._extract_protected_values(conv.messages) == {"score": "87"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_json_format(self):
|
||||
conv = NodeConversation(output_keys=["meetings"])
|
||||
await conv.add_assistant_message('{"meetings": ["standup", "retro"]}')
|
||||
assert conv._extract_protected_values(conv.messages) == {"meetings": '["standup", "retro"]'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_equals_format(self):
|
||||
conv = NodeConversation(output_keys=["status"])
|
||||
await conv.add_assistant_message("status = done")
|
||||
assert conv._extract_protected_values(conv.messages) == {"status": "done"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_most_recent_wins(self):
|
||||
conv = NodeConversation(output_keys=["score"])
|
||||
await conv.add_assistant_message("score: 50")
|
||||
await conv.add_assistant_message("score: 99")
|
||||
assert conv._extract_protected_values(conv.messages) == {"score": "99"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_embedded_json(self):
|
||||
conv = NodeConversation(output_keys=["lead_score"])
|
||||
await conv.add_assistant_message(
|
||||
'Based on my analysis, here are the results: {"lead_score": 87, "status": "hot"}'
|
||||
)
|
||||
assert conv._extract_protected_values(conv.messages) == {"lead_score": "87"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_no_match_cases(self):
|
||||
"""No extraction: user messages, no output_keys, key not found."""
|
||||
conv = NodeConversation(output_keys=["score"])
|
||||
await conv.add_user_message("score: 42")
|
||||
assert conv._extract_protected_values(conv.messages) == {}
|
||||
|
||||
conv2 = NodeConversation(output_keys=None)
|
||||
await conv2.add_assistant_message("score: 42")
|
||||
assert conv2._extract_protected_values(conv2.messages) == {}
|
||||
|
||||
conv3 = NodeConversation(output_keys=["missing_key"])
|
||||
await conv3.add_assistant_message("nothing relevant here")
|
||||
assert conv3._extract_protected_values(conv3.messages) == {}
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Persistence (MockConversationStore)
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestPersistence:
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_through_each_add(self):
|
||||
store = MockConversationStore()
|
||||
conv = NodeConversation(store=store)
|
||||
await conv.add_user_message("a")
|
||||
await conv.add_assistant_message("b")
|
||||
parts = await store.read_parts()
|
||||
assert len(parts) == 2
|
||||
assert parts[0]["content"] == "a"
|
||||
assert parts[1]["content"] == "b"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_and_cursor_persistence(self):
|
||||
"""Meta is lazily written on first add; cursor updated on each add."""
|
||||
store = MockConversationStore()
|
||||
conv = NodeConversation(system_prompt="sys", store=store)
|
||||
assert store._meta is None
|
||||
await conv.add_user_message("trigger")
|
||||
assert store._meta is not None
|
||||
assert store._meta["system_prompt"] == "sys"
|
||||
assert store._cursor == {"next_seq": 1}
|
||||
await conv.add_user_message("b")
|
||||
assert store._cursor == {"next_seq": 2}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_from_store(self):
|
||||
"""Restore reconstructs conversation; empty store returns None."""
|
||||
store = MockConversationStore()
|
||||
assert await NodeConversation.restore(store) is None
|
||||
|
||||
conv = NodeConversation(system_prompt="hello", max_history_tokens=500, store=store)
|
||||
await conv.add_user_message("u1")
|
||||
await conv.add_assistant_message("a1")
|
||||
|
||||
restored = await NodeConversation.restore(store)
|
||||
assert restored is not None
|
||||
assert restored.system_prompt == "hello"
|
||||
assert restored.turn_count == 1
|
||||
assert restored.message_count == 2
|
||||
assert restored.next_seq == 2
|
||||
assert restored.messages[0].content == "u1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_preserves_tool_messages(self):
|
||||
store = MockConversationStore()
|
||||
conv = NodeConversation(store=store)
|
||||
await conv.add_assistant_message("", tool_calls=SAMPLE_TOOL_CALLS)
|
||||
await conv.add_tool_result("call_1", "result", is_error=True)
|
||||
|
||||
restored = await NodeConversation.restore(store)
|
||||
assert restored is not None
|
||||
msgs = restored.messages
|
||||
assert msgs[0].tool_calls == SAMPLE_TOOL_CALLS
|
||||
assert msgs[1].tool_use_id == "call_1"
|
||||
assert msgs[1].is_error is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_deletes_old_parts(self):
|
||||
store = MockConversationStore()
|
||||
conv = NodeConversation(store=store)
|
||||
await conv.add_user_message("a")
|
||||
await conv.add_user_message("b")
|
||||
assert len(store._parts) == 2
|
||||
|
||||
await conv.compact("summary", keep_recent=0)
|
||||
assert len(store._parts) == 1
|
||||
remaining = list(store._parts.values())
|
||||
assert remaining[0]["content"] == "summary"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_then_restore(self):
|
||||
"""Compact with keep_recent persists correctly and restores."""
|
||||
store = MockConversationStore()
|
||||
conv = NodeConversation(system_prompt="sp", store=store)
|
||||
await conv.add_user_message("m1")
|
||||
await conv.add_assistant_message("m2")
|
||||
await conv.add_user_message("m3")
|
||||
await conv.add_assistant_message("m4")
|
||||
|
||||
await conv.compact("early summary", keep_recent=2)
|
||||
|
||||
restored = await NodeConversation.restore(store)
|
||||
assert restored is not None
|
||||
assert restored.message_count == 3
|
||||
assert restored.messages[0].content == "early summary"
|
||||
assert restored.messages[1].content == "m3"
|
||||
assert restored.messages[2].content == "m4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_deletes_store_parts(self):
|
||||
store = MockConversationStore()
|
||||
conv = NodeConversation(store=store)
|
||||
await conv.add_user_message("a")
|
||||
await conv.add_user_message("b")
|
||||
await conv.clear()
|
||||
assert len(store._parts) == 0
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# FileConversationStore
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestFileConversationStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_and_cursor_crud(self, tmp_path):
|
||||
"""Write/read meta and cursor; empty reads return None."""
|
||||
store = FileConversationStore(tmp_path / "conv")
|
||||
assert await store.read_meta() is None
|
||||
await store.write_meta({"system_prompt": "hi"})
|
||||
assert await store.read_meta() == {"system_prompt": "hi"}
|
||||
|
||||
await store.write_cursor({"next_seq": 5})
|
||||
assert await store.read_cursor() == {"next_seq": 5}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_and_read_parts_in_order(self, tmp_path):
|
||||
store = FileConversationStore(tmp_path / "conv")
|
||||
await store.write_part(2, {"seq": 2, "content": "second"})
|
||||
await store.write_part(0, {"seq": 0, "content": "first"})
|
||||
await store.write_part(1, {"seq": 1, "content": "middle"})
|
||||
parts = await store.read_parts()
|
||||
assert [p["seq"] for p in parts] == [0, 1, 2]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_parts_before(self, tmp_path):
|
||||
store = FileConversationStore(tmp_path / "conv")
|
||||
for i in range(5):
|
||||
await store.write_part(i, {"seq": i})
|
||||
await store.delete_parts_before(3)
|
||||
parts = await store.read_parts()
|
||||
assert [p["seq"] for p in parts] == [3, 4]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_idempotent_write_part(self, tmp_path):
|
||||
store = FileConversationStore(tmp_path / "conv")
|
||||
await store.write_part(0, {"seq": 0, "v": 1})
|
||||
await store.write_part(0, {"seq": 0, "v": 2})
|
||||
parts = await store.read_parts()
|
||||
assert len(parts) == 1
|
||||
assert parts[0]["v"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_integration_with_node_conversation(self, tmp_path):
|
||||
"""Full round-trip: create -> add messages -> restore from file store."""
|
||||
store = FileConversationStore(tmp_path / "conv")
|
||||
conv = NodeConversation(system_prompt="test", store=store)
|
||||
await conv.add_user_message("u1")
|
||||
await conv.add_assistant_message("a1", tool_calls=SAMPLE_TOOL_CALLS)
|
||||
await conv.add_tool_result("call_1", "r1", is_error=True)
|
||||
|
||||
restored = await NodeConversation.restore(store)
|
||||
assert restored is not None
|
||||
assert restored.system_prompt == "test"
|
||||
assert restored.turn_count == 1
|
||||
assert restored.message_count == 3
|
||||
assert restored.next_seq == 3
|
||||
msgs = restored.messages
|
||||
assert msgs[0].content == "u1"
|
||||
assert msgs[1].tool_calls == SAMPLE_TOOL_CALLS
|
||||
assert msgs[2].is_error is True
|
||||
|
||||
llm = restored.to_llm_messages()
|
||||
assert llm[2]["content"] == "ERROR: r1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_corrupt_part_skipped_on_read(self, tmp_path):
|
||||
"""A corrupt JSON part file is skipped, not fatal to restore."""
|
||||
store = FileConversationStore(tmp_path / "conv")
|
||||
await store.write_part(0, {"seq": 0, "content": "ok"})
|
||||
await store.write_part(1, {"seq": 1, "content": "good"})
|
||||
|
||||
# Simulate crash mid-write: corrupt part 0
|
||||
corrupt_path = tmp_path / "conv" / "parts" / "0000000000.json"
|
||||
corrupt_path.write_text("{truncated", encoding="utf-8")
|
||||
|
||||
parts = await store.read_parts()
|
||||
assert len(parts) == 1
|
||||
assert parts[0]["seq"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_structure(self, tmp_path):
|
||||
"""Verify meta.json, cursor.json, and parts/*.json files exist after writes."""
|
||||
store = FileConversationStore(tmp_path / "conv")
|
||||
await store.write_meta({"system_prompt": "hi"})
|
||||
await store.write_cursor({"next_seq": 2})
|
||||
await store.write_part(0, {"seq": 0, "content": "first"})
|
||||
await store.write_part(1, {"seq": 1, "content": "second"})
|
||||
|
||||
base = tmp_path / "conv"
|
||||
assert (base / "meta.json").exists()
|
||||
assert (base / "cursor.json").exists()
|
||||
assert (base / "parts" / "0000000000.json").exists()
|
||||
assert (base / "parts" / "0000000001.json").exists()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Integration tests — real FileConversationStore, no mocks
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestConversationIntegration:
|
||||
"""End-to-end tests using real FileConversationStore on disk.
|
||||
|
||||
Every test creates a fresh directory, writes real JSON files,
|
||||
and restores from a *new* store instance (simulating process restart).
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_turn_agent_conversation(self, tmp_path):
|
||||
"""Simulate a realistic agent conversation with multiple turns,
|
||||
tool calls, and tool results — then restore from disk."""
|
||||
base = tmp_path / "agent_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(
|
||||
system_prompt="You are a helpful travel agent.",
|
||||
max_history_tokens=16000,
|
||||
store=store,
|
||||
)
|
||||
|
||||
# Turn 1: user asks, assistant responds with tool call
|
||||
await conv.add_user_message("Find me flights from NYC to London next Friday.")
|
||||
await conv.add_assistant_message(
|
||||
"Let me search for flights.",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_flight_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_flights",
|
||||
"arguments": '{"origin":"JFK","destination":"LHR","date":"2025-06-13"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
await conv.add_tool_result(
|
||||
"call_flight_1",
|
||||
'{"flights":[{"airline":"BA","price":450,"departure":"08:00"},{"airline":"AA","price":520,"departure":"14:30"}]}',
|
||||
)
|
||||
|
||||
# Turn 2: assistant presents results, user picks one
|
||||
await conv.add_assistant_message(
|
||||
"I found 2 flights:\n"
|
||||
"1. British Airways at $450, departing 08:00\n"
|
||||
"2. American Airlines at $520, departing 14:30\n"
|
||||
"Which one would you like?"
|
||||
)
|
||||
await conv.add_user_message("Book the British Airways one.")
|
||||
await conv.add_assistant_message(
|
||||
"Booking the BA flight now.",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_book_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "book_flight",
|
||||
"arguments": '{"flight_id":"BA-JFK-LHR-0800","passenger":"user"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
await conv.add_tool_result(
|
||||
"call_book_1",
|
||||
'{"confirmation":"BA-12345","status":"confirmed"}',
|
||||
)
|
||||
await conv.add_assistant_message("Your flight is booked! Confirmation: BA-12345.")
|
||||
|
||||
# Verify in-memory state
|
||||
assert conv.turn_count == 2
|
||||
assert conv.message_count == 8
|
||||
assert conv.next_seq == 8
|
||||
|
||||
# --- Simulate process restart: new store, same path ---
|
||||
store2 = FileConversationStore(base)
|
||||
restored = await NodeConversation.restore(store2)
|
||||
|
||||
assert restored is not None
|
||||
assert restored.system_prompt == "You are a helpful travel agent."
|
||||
assert restored.turn_count == 2
|
||||
assert restored.message_count == 8
|
||||
assert restored.next_seq == 8
|
||||
|
||||
# Verify message content integrity
|
||||
msgs = restored.messages
|
||||
assert msgs[0].role == "user"
|
||||
assert "NYC to London" in msgs[0].content
|
||||
assert msgs[1].role == "assistant"
|
||||
assert msgs[1].tool_calls[0]["id"] == "call_flight_1"
|
||||
assert msgs[2].role == "tool"
|
||||
assert msgs[2].tool_use_id == "call_flight_1"
|
||||
assert "BA" in msgs[2].content
|
||||
assert msgs[7].content == "Your flight is booked! Confirmation: BA-12345."
|
||||
|
||||
# Verify LLM-format output
|
||||
llm_msgs = restored.to_llm_messages()
|
||||
assert llm_msgs[0] == {"role": "user", "content": msgs[0].content}
|
||||
assert llm_msgs[2]["role"] == "tool"
|
||||
assert llm_msgs[2]["tool_call_id"] == "call_flight_1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_and_restore_preserves_continuity(self, tmp_path):
|
||||
"""Build up a long conversation, compact it, continue adding
|
||||
messages, then restore — verifying seq continuity and content."""
|
||||
base = tmp_path / "compact_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(
|
||||
system_prompt="research assistant",
|
||||
store=store,
|
||||
)
|
||||
|
||||
# Build 10 messages (5 turns)
|
||||
for i in range(5):
|
||||
await conv.add_user_message(f"question {i}")
|
||||
await conv.add_assistant_message(f"answer {i}")
|
||||
|
||||
assert conv.message_count == 10
|
||||
assert conv.next_seq == 10
|
||||
|
||||
# Compact: keep last 2 messages (question 4, answer 4)
|
||||
await conv.compact("Summary of questions 0-3 and their answers.", keep_recent=2)
|
||||
|
||||
assert conv.message_count == 3 # summary + 2 recent
|
||||
assert conv.messages[0].content == "Summary of questions 0-3 and their answers."
|
||||
assert conv.messages[1].content == "question 4"
|
||||
assert conv.messages[2].content == "answer 4"
|
||||
|
||||
# Continue the conversation post-compaction
|
||||
await conv.add_user_message("question 5")
|
||||
await conv.add_assistant_message("answer 5")
|
||||
assert conv.next_seq == 12
|
||||
|
||||
# Verify disk: old part files (seq 0-7) should be deleted
|
||||
parts_dir = base / "parts"
|
||||
part_files = sorted(parts_dir.glob("*.json"))
|
||||
part_seqs = [int(f.stem) for f in part_files]
|
||||
# Should have: summary (seq 7), question 4 (seq 8), answer 4 (seq 9),
|
||||
# question 5 (seq 10), answer 5 (seq 11)
|
||||
assert all(s >= 7 for s in part_seqs), f"Stale parts found: {part_seqs}"
|
||||
|
||||
# Restore from fresh store
|
||||
store2 = FileConversationStore(base)
|
||||
restored = await NodeConversation.restore(store2)
|
||||
|
||||
assert restored is not None
|
||||
assert restored.next_seq == 12
|
||||
assert restored.message_count == 5
|
||||
assert "Summary of questions 0-3" in restored.messages[0].content
|
||||
assert restored.messages[-1].content == "answer 5"
|
||||
|
||||
# Verify seq monotonicity across all restored messages
|
||||
seqs = [m.seq for m in restored.messages]
|
||||
assert seqs == sorted(seqs), f"Seqs not monotonic: {seqs}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_key_preservation_through_compact_and_restore(self, tmp_path):
|
||||
"""Output keys in compacted messages survive disk persistence."""
|
||||
base = tmp_path / "output_key_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(
|
||||
system_prompt="classifier",
|
||||
output_keys=["classification", "confidence"],
|
||||
store=store,
|
||||
)
|
||||
|
||||
await conv.add_user_message("Classify this email: 'You won a prize!'")
|
||||
await conv.add_assistant_message('{"classification": "spam", "confidence": "0.97"}')
|
||||
await conv.add_user_message("What about: 'Meeting at 3pm'")
|
||||
await conv.add_assistant_message('{"classification": "ham", "confidence": "0.99"}')
|
||||
await conv.add_user_message("And: 'Buy cheap meds now'")
|
||||
await conv.add_assistant_message('{"classification": "spam", "confidence": "0.95"}')
|
||||
|
||||
# Compact keeping only the last 2 messages
|
||||
await conv.compact("Classified 3 emails.", keep_recent=2)
|
||||
|
||||
# The summary should contain preserved output keys from discarded messages
|
||||
summary_content = conv.messages[0].content
|
||||
assert "PRESERVED VALUES" in summary_content
|
||||
# Most recent values from discarded messages (msgs 0-3) are "ham"/"0.99"
|
||||
assert "ham" in summary_content or "spam" in summary_content
|
||||
|
||||
# Restore and verify the preserved values survived
|
||||
store2 = FileConversationStore(base)
|
||||
restored = await NodeConversation.restore(store2)
|
||||
assert restored is not None
|
||||
assert "PRESERVED VALUES" in restored.messages[0].content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_error_roundtrip(self, tmp_path):
|
||||
"""Tool errors persist and restore with ERROR: prefix in LLM output."""
|
||||
base = tmp_path / "error_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(store=store)
|
||||
|
||||
await conv.add_user_message("Calculate 1/0")
|
||||
await conv.add_assistant_message(
|
||||
"Let me calculate that.",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_calc",
|
||||
"type": "function",
|
||||
"function": {"name": "calculator", "arguments": '{"expr":"1/0"}'},
|
||||
}
|
||||
],
|
||||
)
|
||||
await conv.add_tool_result(
|
||||
"call_calc", "ZeroDivisionError: division by zero", is_error=True
|
||||
)
|
||||
await conv.add_assistant_message("The calculation failed: division by zero is undefined.")
|
||||
|
||||
# Restore
|
||||
store2 = FileConversationStore(base)
|
||||
restored = await NodeConversation.restore(store2)
|
||||
assert restored is not None
|
||||
|
||||
tool_msg = restored.messages[2]
|
||||
assert tool_msg.role == "tool"
|
||||
assert tool_msg.is_error is True
|
||||
assert tool_msg.tool_use_id == "call_calc"
|
||||
|
||||
llm_dict = tool_msg.to_llm_dict()
|
||||
assert llm_dict["content"].startswith("ERROR: ")
|
||||
assert "ZeroDivisionError" in llm_dict["content"]
|
||||
assert llm_dict["tool_call_id"] == "call_calc"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_conversations_isolated(self, tmp_path):
|
||||
"""Two conversations in separate directories don't interfere."""
|
||||
store_a = FileConversationStore(tmp_path / "conv_a")
|
||||
store_b = FileConversationStore(tmp_path / "conv_b")
|
||||
|
||||
conv_a = NodeConversation(system_prompt="Agent A", store=store_a)
|
||||
conv_b = NodeConversation(system_prompt="Agent B", store=store_b)
|
||||
|
||||
await conv_a.add_user_message("Hello from A")
|
||||
await conv_b.add_user_message("Hello from B")
|
||||
await conv_a.add_assistant_message("Response A")
|
||||
await conv_b.add_assistant_message("Response B")
|
||||
await conv_b.add_user_message("Follow-up B")
|
||||
|
||||
# Restore independently
|
||||
restored_a = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_a"))
|
||||
restored_b = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_b"))
|
||||
|
||||
assert restored_a.system_prompt == "Agent A"
|
||||
assert restored_b.system_prompt == "Agent B"
|
||||
assert restored_a.message_count == 2
|
||||
assert restored_b.message_count == 3
|
||||
assert restored_a.messages[0].content == "Hello from A"
|
||||
assert restored_b.messages[2].content == "Follow-up B"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_removes_all_files(self, tmp_path):
|
||||
"""destroy() wipes the entire conversation directory."""
|
||||
base = tmp_path / "doomed_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(system_prompt="temp", store=store)
|
||||
await conv.add_user_message("ephemeral")
|
||||
await conv.add_assistant_message("gone soon")
|
||||
|
||||
assert base.exists()
|
||||
assert (base / "meta.json").exists()
|
||||
assert (base / "parts").exists()
|
||||
|
||||
await store.destroy()
|
||||
|
||||
assert not base.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_empty_store_returns_none(self, tmp_path):
|
||||
"""Restoring from a path that was never written to returns None."""
|
||||
store = FileConversationStore(tmp_path / "empty")
|
||||
result = await NodeConversation.restore(store)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_then_continue_then_restore(self, tmp_path):
|
||||
"""clear() removes messages but preserves seq counter for new messages."""
|
||||
base = tmp_path / "clear_conv"
|
||||
store = FileConversationStore(base)
|
||||
conv = NodeConversation(system_prompt="s", store=store)
|
||||
|
||||
await conv.add_user_message("old msg 0")
|
||||
await conv.add_assistant_message("old msg 1")
|
||||
assert conv.next_seq == 2
|
||||
|
||||
await conv.clear()
|
||||
assert conv.message_count == 0
|
||||
assert conv.next_seq == 2 # seq counter preserved
|
||||
|
||||
# Continue with new messages — seqs should start at 2
|
||||
await conv.add_user_message("new msg")
|
||||
await conv.add_assistant_message("new response")
|
||||
assert conv.next_seq == 4
|
||||
assert conv.messages[0].seq == 2
|
||||
assert conv.messages[1].seq == 3
|
||||
|
||||
# Restore
|
||||
store2 = FileConversationStore(base)
|
||||
restored = await NodeConversation.restore(store2)
|
||||
assert restored is not None
|
||||
assert restored.message_count == 2
|
||||
assert restored.next_seq == 4
|
||||
assert restored.messages[0].content == "new msg"
|
||||
assert restored.messages[0].seq == 2
|
||||
@@ -0,0 +1,318 @@
|
||||
"""Tests for stream event dataclasses.
|
||||
|
||||
Validates construction, defaults, immutability, serialization, and the
|
||||
StreamEvent discriminated union type.
|
||||
"""
|
||||
|
||||
from dataclasses import FrozenInstanceError, asdict, fields
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
ReasoningDeltaEvent,
|
||||
ReasoningStartEvent,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
|
||||
# All concrete event classes in the union
|
||||
ALL_EVENT_CLASSES = [
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
ReasoningStartEvent,
|
||||
ReasoningDeltaEvent,
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Construction & defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventDefaults:
|
||||
"""Each event class should be constructible with zero arguments."""
|
||||
|
||||
@pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
|
||||
def test_default_construction(self, cls):
|
||||
event = cls()
|
||||
assert event.type != ""
|
||||
|
||||
def test_text_delta_defaults(self):
|
||||
e = TextDeltaEvent()
|
||||
assert e.type == "text_delta"
|
||||
assert e.content == ""
|
||||
assert e.snapshot == ""
|
||||
|
||||
def test_text_end_defaults(self):
|
||||
e = TextEndEvent()
|
||||
assert e.type == "text_end"
|
||||
assert e.full_text == ""
|
||||
|
||||
def test_tool_call_defaults(self):
|
||||
e = ToolCallEvent()
|
||||
assert e.type == "tool_call"
|
||||
assert e.tool_use_id == ""
|
||||
assert e.tool_name == ""
|
||||
assert e.tool_input == {}
|
||||
|
||||
def test_tool_result_defaults(self):
|
||||
e = ToolResultEvent()
|
||||
assert e.type == "tool_result"
|
||||
assert e.tool_use_id == ""
|
||||
assert e.content == ""
|
||||
assert e.is_error is False
|
||||
|
||||
def test_reasoning_start_defaults(self):
|
||||
e = ReasoningStartEvent()
|
||||
assert e.type == "reasoning_start"
|
||||
|
||||
def test_reasoning_delta_defaults(self):
|
||||
e = ReasoningDeltaEvent()
|
||||
assert e.type == "reasoning_delta"
|
||||
assert e.content == ""
|
||||
|
||||
def test_finish_defaults(self):
|
||||
e = FinishEvent()
|
||||
assert e.type == "finish"
|
||||
assert e.stop_reason == ""
|
||||
assert e.input_tokens == 0
|
||||
assert e.output_tokens == 0
|
||||
assert e.model == ""
|
||||
|
||||
def test_stream_error_defaults(self):
|
||||
e = StreamErrorEvent()
|
||||
assert e.type == "error"
|
||||
assert e.error == ""
|
||||
assert e.recoverable is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Construction with values
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventConstruction:
|
||||
"""Events should store provided field values correctly."""
|
||||
|
||||
def test_text_delta_with_values(self):
|
||||
e = TextDeltaEvent(content="hello", snapshot="hello world")
|
||||
assert e.content == "hello"
|
||||
assert e.snapshot == "hello world"
|
||||
|
||||
def test_text_end_with_values(self):
|
||||
e = TextEndEvent(full_text="the complete response")
|
||||
assert e.full_text == "the complete response"
|
||||
|
||||
def test_tool_call_with_values(self):
|
||||
e = ToolCallEvent(
|
||||
tool_use_id="call_abc123",
|
||||
tool_name="web_search",
|
||||
tool_input={"query": "python", "num_results": 5},
|
||||
)
|
||||
assert e.tool_use_id == "call_abc123"
|
||||
assert e.tool_name == "web_search"
|
||||
assert e.tool_input == {"query": "python", "num_results": 5}
|
||||
|
||||
def test_tool_result_with_values(self):
|
||||
e = ToolResultEvent(
|
||||
tool_use_id="call_abc123",
|
||||
content="search results here",
|
||||
is_error=False,
|
||||
)
|
||||
assert e.tool_use_id == "call_abc123"
|
||||
assert e.content == "search results here"
|
||||
assert e.is_error is False
|
||||
|
||||
def test_tool_result_error(self):
|
||||
e = ToolResultEvent(
|
||||
tool_use_id="call_fail",
|
||||
content="timeout",
|
||||
is_error=True,
|
||||
)
|
||||
assert e.is_error is True
|
||||
|
||||
def test_reasoning_delta_with_content(self):
|
||||
e = ReasoningDeltaEvent(content="Let me think about this...")
|
||||
assert e.content == "Let me think about this..."
|
||||
|
||||
def test_finish_with_values(self):
|
||||
e = FinishEvent(
|
||||
stop_reason="end_turn",
|
||||
input_tokens=150,
|
||||
output_tokens=300,
|
||||
model="claude-haiku-4-5",
|
||||
)
|
||||
assert e.stop_reason == "end_turn"
|
||||
assert e.input_tokens == 150
|
||||
assert e.output_tokens == 300
|
||||
assert e.model == "claude-haiku-4-5"
|
||||
|
||||
def test_stream_error_with_values(self):
|
||||
e = StreamErrorEvent(error="rate limit exceeded", recoverable=True)
|
||||
assert e.error == "rate limit exceeded"
|
||||
assert e.recoverable is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frozen immutability
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventImmutability:
|
||||
"""All events are frozen dataclasses — fields cannot be reassigned."""
|
||||
|
||||
@pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
|
||||
def test_frozen(self, cls):
|
||||
event = cls()
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
event.type = "modified"
|
||||
|
||||
def test_text_delta_frozen_content(self):
|
||||
e = TextDeltaEvent(content="hello")
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
e.content = "modified"
|
||||
|
||||
def test_tool_call_frozen_input(self):
|
||||
e = ToolCallEvent(tool_input={"key": "value"})
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
e.tool_input = {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type literal values
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTypeLiterals:
|
||||
"""Each event's `type` field should match its Literal annotation."""
|
||||
|
||||
EXPECTED_TYPES = {
|
||||
TextDeltaEvent: "text_delta",
|
||||
TextEndEvent: "text_end",
|
||||
ToolCallEvent: "tool_call",
|
||||
ToolResultEvent: "tool_result",
|
||||
ReasoningStartEvent: "reasoning_start",
|
||||
ReasoningDeltaEvent: "reasoning_delta",
|
||||
FinishEvent: "finish",
|
||||
StreamErrorEvent: "error",
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cls,expected_type",
|
||||
EXPECTED_TYPES.items(),
|
||||
ids=lambda x: x.__name__ if isinstance(x, type) else x,
|
||||
)
|
||||
def test_type_value(self, cls, expected_type):
|
||||
assert cls().type == expected_type
|
||||
|
||||
def test_all_types_unique(self):
|
||||
types = [cls().type for cls in ALL_EVENT_CLASSES]
|
||||
assert len(types) == len(set(types)), f"Duplicate type values: {types}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization via dataclasses.asdict
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventSerialization:
|
||||
"""Events should round-trip through asdict for JSON serialization."""
|
||||
|
||||
def test_text_delta_asdict(self):
|
||||
e = TextDeltaEvent(content="chunk", snapshot="full chunk")
|
||||
d = asdict(e)
|
||||
assert d == {"type": "text_delta", "content": "chunk", "snapshot": "full chunk"}
|
||||
|
||||
def test_tool_call_asdict(self):
|
||||
e = ToolCallEvent(
|
||||
tool_use_id="id_1",
|
||||
tool_name="calc",
|
||||
tool_input={"expression": "2+2"},
|
||||
)
|
||||
d = asdict(e)
|
||||
assert d["tool_name"] == "calc"
|
||||
assert d["tool_input"] == {"expression": "2+2"}
|
||||
|
||||
def test_finish_asdict(self):
|
||||
e = FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=20, model="gpt-4")
|
||||
d = asdict(e)
|
||||
assert d == {
|
||||
"type": "finish",
|
||||
"stop_reason": "stop",
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
"model": "gpt-4",
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
|
||||
def test_asdict_contains_type(self, cls):
|
||||
d = asdict(cls())
|
||||
assert "type" in d
|
||||
|
||||
@pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
|
||||
def test_asdict_keys_match_fields(self, cls):
|
||||
event = cls()
|
||||
d = asdict(event)
|
||||
field_names = {f.name for f in fields(cls)}
|
||||
assert set(d.keys()) == field_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StreamEvent union type
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestStreamEventUnion:
|
||||
"""The StreamEvent union should include all event classes."""
|
||||
|
||||
def test_union_contains_all_classes(self):
|
||||
# StreamEvent is a UnionType (PEP 604 syntax: X | Y | Z)
|
||||
union_args = StreamEvent.__args__ # type: ignore[attr-defined]
|
||||
for cls in ALL_EVENT_CLASSES:
|
||||
assert cls in union_args, f"{cls.__name__} not in StreamEvent union"
|
||||
|
||||
def test_union_has_exactly_expected_members(self):
|
||||
union_args = set(StreamEvent.__args__) # type: ignore[attr-defined]
|
||||
expected = set(ALL_EVENT_CLASSES)
|
||||
assert union_args == expected
|
||||
|
||||
@pytest.mark.parametrize("cls", ALL_EVENT_CLASSES, ids=lambda c: c.__name__)
|
||||
def test_isinstance_check(self, cls):
|
||||
"""Each event instance should be an instance of its class (basic sanity)."""
|
||||
event = cls()
|
||||
assert isinstance(event, cls)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Equality & hashing (frozen dataclasses support both)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventEquality:
|
||||
"""Frozen dataclasses support equality and hashing."""
|
||||
|
||||
def test_equal_events(self):
|
||||
a = TextDeltaEvent(content="hi", snapshot="hi")
|
||||
b = TextDeltaEvent(content="hi", snapshot="hi")
|
||||
assert a == b
|
||||
|
||||
def test_unequal_events(self):
|
||||
a = TextDeltaEvent(content="hi")
|
||||
b = TextDeltaEvent(content="bye")
|
||||
assert a != b
|
||||
|
||||
def test_different_types_not_equal(self):
|
||||
a = TextDeltaEvent(content="hi")
|
||||
b = ReasoningDeltaEvent(content="hi")
|
||||
assert a != b
|
||||
|
||||
def test_hashable(self):
|
||||
e = FinishEvent(stop_reason="stop", model="gpt-4")
|
||||
s = {e} # should be hashable since frozen
|
||||
assert e in s
|
||||
|
||||
def test_equal_events_same_hash(self):
|
||||
a = FinishEvent(stop_reason="stop", model="gpt-4")
|
||||
b = FinishEvent(stop_reason="stop", model="gpt-4")
|
||||
assert hash(a) == hash(b)
|
||||
|
||||
def test_events_with_dict_not_hashable(self):
|
||||
"""Events containing dict fields (e.g. tool_input) are not hashable."""
|
||||
e = ToolCallEvent(tool_use_id="x", tool_name="y", tool_input={"key": "val"})
|
||||
with pytest.raises(TypeError, match="unhashable type"):
|
||||
hash(e)
|
||||
@@ -362,7 +362,6 @@ class AgentRequest(BaseModel):
|
||||
raise ValueError('max_tokens too high')
|
||||
return v
|
||||
```
|
||||
|
||||
### Output Sanitization
|
||||
> **Note:** The following snippet is illustrative and shows a simplified example
|
||||
> of output sanitization logic. Actual implementations may differ.
|
||||
|
||||
@@ -1757,7 +1757,7 @@ tools/src/aden_tools/credentials/
|
||||
|
||||
### Manual Testing
|
||||
|
||||
- [ ] Create encrypted credential store
|
||||
- [ ] Create local encrypted store
|
||||
- [ ] Save and load multi-key credentials
|
||||
- [ ] Verify template resolution in tool headers
|
||||
- [ ] Test OAuth2 token refresh
|
||||
|
||||
+1
-3
@@ -14,7 +14,6 @@
|
||||
|
||||
[](https://github.com/adenhq/hive/blob/main/LICENSE)
|
||||
[](https://www.ycombinator.com/companies/aden)
|
||||
[](https://hub.docker.com/u/adenhq)
|
||||
[](https://discord.com/invite/MXE49hrKDk)
|
||||
[](https://x.com/aden_hq)
|
||||
[](https://www.linkedin.com/company/teamaden/)
|
||||
@@ -42,7 +41,7 @@ Visita [adenhq.com](https://adenhq.com) para documentación completa, ejemplos y
|
||||
## ¿Qué es Aden?
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden es una plataforma para construir, desplegar, operar y adaptar agentes de IA:
|
||||
@@ -66,7 +65,6 @@ Aden es una plataforma para construir, desplegar, operar y adaptar agentes de IA
|
||||
### Prerrequisitos
|
||||
|
||||
- [Python 3.11+](https://www.python.org/downloads/) - Para desarrollo de agentes
|
||||
- [Docker](https://docs.docker.com/get-docker/) (v20.10+) - Opcional, para herramientas en contenedores
|
||||
|
||||
### Instalación
|
||||
|
||||
|
||||
+1
-1
@@ -44,7 +44,7 @@
|
||||
# Aden क्या है?
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden एक ऐसा प्लेटफ़ॉर्म है जो AI एजेंट्स को बनाने, डिप्लॉय करने, ऑपरेट करने और अनुकूलित करने के लिए उपयोग होता है:
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@
|
||||
## Adenとは
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Adenは、AIエージェントの構築、デプロイ、運用、適応のためのプラットフォームです:
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@
|
||||
## Aden이란 무엇인가
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden은 AI 에이전트를 구축, 배포, 운영, 적응시키기 위한 플랫폼입니다:
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@ Visite [adenhq.com](https://adenhq.com) para documentação completa, exemplos e
|
||||
## O que é Aden
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden é uma plataforma para construir, implantar, operar e adaptar agentes de IA:
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@
|
||||
## Что такое Aden
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden — это платформа для создания, развёртывания, эксплуатации и адаптации ИИ-агентов:
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@
|
||||
## 什么是 Aden
|
||||
|
||||
<p align="center">
|
||||
<img width="100%" alt="Aden Architecture" src="docs/assets/aden-architecture-diagram.jpg" />
|
||||
<img width="100%" alt="Aden Architecture" src="../assets/aden-architecture-diagram.jpg" />
|
||||
</p>
|
||||
|
||||
Aden 是一个用于构建、部署、运营和适应 AI 智能体的平台:
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
# Examples
|
||||
|
||||
This directory contains two types of examples to help you build agents with the Hive framework.
|
||||
|
||||
## Recipes vs Templates
|
||||
|
||||
### [recipes/](recipes/) — "How to make it"
|
||||
|
||||
A recipe is a **prompt-only** description of an agent. It tells you the goal, the nodes, the prompts, the edge routing logic, and what tools to wire in — but it's not runnable code. You read the recipe, then build the agent yourself.
|
||||
|
||||
Use recipes when you want to:
|
||||
- Understand a pattern before committing to an implementation
|
||||
- Adapt an idea to your own codebase or tooling
|
||||
- Learn how to think about agent design (goals, nodes, edges, prompts)
|
||||
|
||||
### [templates/](templates/) — "Ready to eat"
|
||||
|
||||
A template is a **working agent scaffold** that follows the standard Hive export structure. Copy the folder, rename it, swap in your own prompts and tools, and run it.
|
||||
|
||||
Use templates when you want to:
|
||||
- Get a new agent running quickly
|
||||
- Start from a known-good structure instead of from scratch
|
||||
- See how all the pieces (goal, nodes, edges, config, CLI) fit together in real code
|
||||
|
||||
## How to use a template
|
||||
|
||||
```bash
|
||||
# 1. Copy the template
|
||||
cp -r examples/templates/marketing_agent exports/my_agent
|
||||
|
||||
# 2. Edit the goal, nodes, and edges in agent.py and nodes/__init__.py
|
||||
|
||||
# 3. Run it
|
||||
PYTHONPATH=core python -m exports.my_agent --help
|
||||
```
|
||||
|
||||
## How to use a recipe
|
||||
|
||||
1. Read the recipe markdown file
|
||||
2. Use the patterns described to build your own agent — either manually or with the builder agent (`/agent-workflow`)
|
||||
3. Refer to the [core README](../core/README.md) for framework API details
|
||||
@@ -0,0 +1,27 @@
|
||||
# Recipes
|
||||
|
||||
A recipe describes an agent's design — the goal, nodes, prompts, edge logic, and tools — without providing runnable code. Think of it as a blueprint: it tells you *how* to build the agent, but you do the building.
|
||||
|
||||
## What's in a recipe
|
||||
|
||||
Each recipe is a markdown file (or folder with a markdown file) containing:
|
||||
|
||||
- **Goal**: What the agent accomplishes, including success criteria and constraints
|
||||
- **Nodes**: Each step in the workflow, with the system prompt, node type, and input/output keys
|
||||
- **Edges**: How nodes connect, including conditions and routing logic
|
||||
- **Tools**: What external tools or MCP servers the agent needs
|
||||
- **Usage notes**: Tips, gotchas, and suggested variations
|
||||
|
||||
## How to use a recipe
|
||||
|
||||
1. Read through the recipe to understand the design
|
||||
2. Create a new agent using the standard export structure (see [templates/](../templates/) for a scaffold)
|
||||
3. Translate the recipe's goal, nodes, and edges into code
|
||||
4. Wire in the tools described
|
||||
5. Test and iterate
|
||||
|
||||
## Available recipes
|
||||
|
||||
| Recipe | Description |
|
||||
|--------|-------------|
|
||||
| [marketing_agent](marketing_agent/) | Multi-channel marketing content generator with audience analysis and A/B copy variants |
|
||||
@@ -0,0 +1,156 @@
|
||||
# Recipe: Marketing Content Agent
|
||||
|
||||
A multi-channel marketing content generator. Given a product description and target audience, this agent analyzes the audience, generates tailored copy for multiple channels, and produces A/B variants.
|
||||
|
||||
## Goal
|
||||
|
||||
```
|
||||
Name: Marketing Content Generator
|
||||
Description: Generate targeted marketing content across multiple channels
|
||||
for a given product and audience.
|
||||
|
||||
Success criteria:
|
||||
- Audience analysis is produced with demographics and pain points
|
||||
- At least 2 channel-specific content pieces are generated
|
||||
- A/B variants are provided for each piece
|
||||
- All content aligns with the specified brand voice
|
||||
|
||||
Constraints:
|
||||
- (hard) No competitor brand names in generated content
|
||||
- (soft) Content should be under 280 characters for social media channels
|
||||
```
|
||||
|
||||
## Input / Output
|
||||
|
||||
**Input:**
|
||||
- `product_description` (str) — What the product is and does
|
||||
- `target_audience` (str) — Who the content is for
|
||||
- `brand_voice` (str) — Tone and style guidelines (e.g., "professional but approachable")
|
||||
- `channels` (list[str]) — Target channels, e.g. `["email", "twitter", "linkedin"]`
|
||||
|
||||
**Output:**
|
||||
- `audience_analysis` (dict) — Demographics, pain points, motivations
|
||||
- `content` (list[dict]) — Per-channel content with A/B variants
|
||||
|
||||
## Workflow
|
||||
|
||||
```
|
||||
[analyze_audience] → [generate_content] → [review_and_refine]
|
||||
|
|
||||
(conditional)
|
||||
|
|
||||
needs_revision == True → [generate_content]
|
||||
needs_revision == False → (done)
|
||||
```
|
||||
|
||||
## Nodes
|
||||
|
||||
### 1. analyze_audience
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| Type | `llm_generate` |
|
||||
| Input keys | `product_description`, `target_audience` |
|
||||
| Output keys | `audience_analysis` |
|
||||
| Tools | None |
|
||||
|
||||
**System prompt:**
|
||||
```
|
||||
You are a marketing strategist. Analyze the target audience for a product.
|
||||
|
||||
Product: {product_description}
|
||||
Target audience: {target_audience}
|
||||
|
||||
Produce a structured analysis in JSON:
|
||||
{{
|
||||
"audience_analysis": {{
|
||||
"demographics": "...",
|
||||
"pain_points": ["..."],
|
||||
"motivations": ["..."],
|
||||
"preferred_channels": ["..."],
|
||||
"messaging_angle": "..."
|
||||
}}
|
||||
}}
|
||||
```
|
||||
|
||||
### 2. generate_content
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| Type | `llm_generate` |
|
||||
| Input keys | `product_description`, `audience_analysis`, `brand_voice`, `channels` |
|
||||
| Output keys | `content` |
|
||||
| Tools | None |
|
||||
|
||||
**System prompt:**
|
||||
```
|
||||
You are a marketing copywriter. Generate content for each channel.
|
||||
|
||||
Product: {product_description}
|
||||
Audience analysis: {audience_analysis}
|
||||
Brand voice: {brand_voice}
|
||||
Channels: {channels}
|
||||
|
||||
For each channel, produce two variants (A and B).
|
||||
|
||||
Output as JSON:
|
||||
{{
|
||||
"content": [
|
||||
{{
|
||||
"channel": "twitter",
|
||||
"variant_a": "...",
|
||||
"variant_b": "..."
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
### 3. review_and_refine
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| Type | `llm_generate` |
|
||||
| Input keys | `content`, `brand_voice` |
|
||||
| Output keys | `content`, `needs_revision` |
|
||||
| Tools | None |
|
||||
|
||||
**System prompt:**
|
||||
```
|
||||
You are a senior marketing editor. Review the following content for brand
|
||||
voice alignment, clarity, and channel appropriateness.
|
||||
|
||||
Content: {content}
|
||||
Brand voice: {brand_voice}
|
||||
|
||||
If any piece needs revision, fix it and set needs_revision to true.
|
||||
If everything looks good, return the content unchanged with needs_revision false.
|
||||
|
||||
Output as JSON:
|
||||
{{
|
||||
"content": [...],
|
||||
"needs_revision": false
|
||||
}}
|
||||
```
|
||||
|
||||
## Edges
|
||||
|
||||
| Source | Target | Condition | Priority |
|
||||
|--------|--------|-----------|----------|
|
||||
| analyze_audience | generate_content | `on_success` | 0 |
|
||||
| generate_content | review_and_refine | `on_success` | 0 |
|
||||
| review_and_refine | generate_content | `conditional: needs_revision == True` | 10 |
|
||||
|
||||
The `review_and_refine → generate_content` loop has higher priority so it's checked first. If `needs_revision` is false, execution ends at `review_and_refine` (terminal node).
|
||||
|
||||
## Tools
|
||||
|
||||
This recipe uses no external tools — all nodes are `llm_generate`. To extend it, consider adding:
|
||||
- A web search tool for competitive analysis in the `analyze_audience` node
|
||||
- A URL shortener tool for social media content
|
||||
- An image generation tool for visual content variants
|
||||
|
||||
## Variations
|
||||
|
||||
- **Single-channel mode**: Remove the `channels` input and hardcode to one channel for simpler output
|
||||
- **With approval gate**: Add a `human_input` node between `review_and_refine` and the terminal to require human sign-off
|
||||
- **With analytics**: Add a `function` node that logs generated content to a tracking system
|
||||
@@ -0,0 +1,38 @@
|
||||
# Templates
|
||||
|
||||
A template is a working agent scaffold that follows the standard Hive export structure. Copy it, rename it, customize the goal/nodes/edges, and run it.
|
||||
|
||||
## What's in a template
|
||||
|
||||
Each template is a complete agent package:
|
||||
|
||||
```
|
||||
template_name/
|
||||
├── __init__.py # Package exports
|
||||
├── __main__.py # CLI entry point
|
||||
├── agent.py # Goal, edges, graph spec, agent class
|
||||
├── config.py # Runtime configuration
|
||||
├── nodes/
|
||||
│ └── __init__.py # Node definitions (NodeSpec instances)
|
||||
└── README.md # What this template demonstrates
|
||||
```
|
||||
|
||||
## How to use a template
|
||||
|
||||
```bash
|
||||
# 1. Copy to your exports directory
|
||||
cp -r examples/templates/marketing_agent exports/my_marketing_agent
|
||||
|
||||
# 2. Update the module references in __main__.py and __init__.py
|
||||
|
||||
# 3. Customize goal, nodes, edges, and prompts
|
||||
|
||||
# 4. Run it
|
||||
PYTHONPATH=core python -m exports.my_marketing_agent --input '{"product_description": "..."}'
|
||||
```
|
||||
|
||||
## Available templates
|
||||
|
||||
| Template | Description |
|
||||
|----------|-------------|
|
||||
| [marketing_agent](marketing_agent/) | Multi-channel marketing content generator with audience analysis, content generation, and editorial review nodes |
|
||||
@@ -0,0 +1,57 @@
|
||||
# Template: Marketing Content Agent
|
||||
|
||||
A multi-channel marketing content generator. Given a product and audience, this agent analyzes the audience, generates tailored copy for multiple channels with A/B variants, and reviews the output for quality.
|
||||
|
||||
## Workflow
|
||||
|
||||
```
|
||||
[analyze-audience] → [generate-content] → [review-and-refine]
|
||||
|
|
||||
(conditional)
|
||||
|
|
||||
needs_revision == True → [generate-content]
|
||||
needs_revision == False → (done)
|
||||
```
|
||||
|
||||
## Nodes
|
||||
|
||||
| Node | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `analyze-audience` | `llm_generate` | Produces structured audience analysis |
|
||||
| `generate-content` | `llm_generate` | Creates per-channel copy with A/B variants |
|
||||
| `review-and-refine` | `llm_generate` | Reviews and optionally revises content |
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# From the repo root
|
||||
PYTHONPATH=core python -m examples.templates.marketing_agent
|
||||
|
||||
# With custom input
|
||||
PYTHONPATH=core python -m examples.templates.marketing_agent --input '{
|
||||
"product_description": "A fitness tracking app",
|
||||
"target_audience": "Health-conscious millennials",
|
||||
"brand_voice": "Energetic and motivational",
|
||||
"channels": ["instagram", "email"]
|
||||
}'
|
||||
```
|
||||
|
||||
## Customization ideas
|
||||
|
||||
- Add a `function` node to call an analytics API and inform audience analysis with real data
|
||||
- Add a `human_input` pause node before final output for editorial approval
|
||||
- Swap `llm_generate` nodes to `llm_tool_use` and add web search tools for competitive research
|
||||
- Add an image generation tool to produce visual assets alongside copy
|
||||
|
||||
## File structure
|
||||
|
||||
```
|
||||
marketing_agent/
|
||||
├── __init__.py # Package exports
|
||||
├── __main__.py # CLI entry point
|
||||
├── agent.py # Goal, edges, graph spec, MarketingAgent class
|
||||
├── config.py # RuntimeConfig and AgentMetadata
|
||||
├── nodes/
|
||||
│ └── __init__.py # NodeSpec definitions
|
||||
└── README.md # This file
|
||||
```
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Marketing Content Agent — template example."""
|
||||
|
||||
from .agent import MarketingAgent, goal, edges, nodes
|
||||
from .config import default_config
|
||||
|
||||
__all__ = ["MarketingAgent", "goal", "edges", "nodes", "default_config"]
|
||||
@@ -0,0 +1,31 @@
|
||||
"""CLI entry point for Marketing Content Agent."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
def main():
|
||||
from .agent import MarketingAgent
|
||||
from .config import default_config
|
||||
|
||||
# Simple CLI — replace with Click for production use
|
||||
input_data = {
|
||||
"product_description": "An AI-powered project management tool for remote teams",
|
||||
"target_audience": "Engineering managers at mid-size tech companies",
|
||||
"brand_voice": "Professional but approachable, concise, data-driven",
|
||||
"channels": ["email", "twitter", "linkedin"],
|
||||
}
|
||||
|
||||
# Accept JSON input from command line
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "--input":
|
||||
input_data = json.loads(sys.argv[2])
|
||||
|
||||
agent = MarketingAgent(config=default_config)
|
||||
result = asyncio.run(agent.run(input_data))
|
||||
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,161 @@
|
||||
"""Marketing Content Agent — goal, edges, graph spec, and agent class."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from framework.graph import EdgeCondition, EdgeSpec, Goal, SuccessCriterion, Constraint
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
|
||||
from .config import default_config, RuntimeConfig
|
||||
from .nodes import all_nodes
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Goal
|
||||
# ---------------------------------------------------------------------------
|
||||
goal = Goal(
|
||||
id="marketing-content",
|
||||
name="Marketing Content Generator",
|
||||
description=(
|
||||
"Generate targeted marketing content across multiple channels "
|
||||
"for a given product and audience."
|
||||
),
|
||||
success_criteria=[
|
||||
SuccessCriterion(
|
||||
id="audience-analyzed",
|
||||
description="Audience analysis is produced with demographics and pain points",
|
||||
metric="output_contains",
|
||||
target="audience_analysis",
|
||||
),
|
||||
SuccessCriterion(
|
||||
id="content-generated",
|
||||
description="At least 2 channel-specific content pieces are generated",
|
||||
metric="custom",
|
||||
target="len(content) >= 2",
|
||||
),
|
||||
SuccessCriterion(
|
||||
id="variants-provided",
|
||||
description="A/B variants are provided for each content piece",
|
||||
metric="custom",
|
||||
target="all variants present",
|
||||
),
|
||||
],
|
||||
constraints=[
|
||||
Constraint(
|
||||
id="no-competitor-names",
|
||||
description="No competitor brand names in generated content",
|
||||
constraint_type="hard",
|
||||
category="safety",
|
||||
),
|
||||
Constraint(
|
||||
id="social-length",
|
||||
description="Social media content should be under 280 characters",
|
||||
constraint_type="soft",
|
||||
category="quality",
|
||||
),
|
||||
],
|
||||
input_schema={
|
||||
"product_description": {"type": "string"},
|
||||
"target_audience": {"type": "string"},
|
||||
"brand_voice": {"type": "string"},
|
||||
"channels": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
output_schema={
|
||||
"audience_analysis": {"type": "object"},
|
||||
"content": {"type": "array"},
|
||||
},
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edges
|
||||
# ---------------------------------------------------------------------------
|
||||
edges = [
|
||||
EdgeSpec(
|
||||
id="analyze-to-generate",
|
||||
source="analyze-audience",
|
||||
target="generate-content",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
description="After audience analysis, generate content",
|
||||
),
|
||||
EdgeSpec(
|
||||
id="generate-to-review",
|
||||
source="generate-content",
|
||||
target="review-and-refine",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
description="After content generation, review and refine",
|
||||
),
|
||||
EdgeSpec(
|
||||
id="review-to-regenerate",
|
||||
source="review-and-refine",
|
||||
target="generate-content",
|
||||
condition=EdgeCondition.CONDITIONAL,
|
||||
condition_expr="needs_revision == True",
|
||||
priority=10,
|
||||
description="If revision needed, loop back to content generation",
|
||||
),
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graph structure
|
||||
# ---------------------------------------------------------------------------
|
||||
entry_node = "analyze-audience"
|
||||
entry_points = {"start": "analyze-audience"}
|
||||
terminal_nodes = ["review-and-refine"]
|
||||
pause_nodes = []
|
||||
nodes = all_nodes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent class
|
||||
# ---------------------------------------------------------------------------
|
||||
class MarketingAgent:
|
||||
"""Multi-channel marketing content generator agent."""
|
||||
|
||||
def __init__(self, config: RuntimeConfig | None = None):
|
||||
self.config = config or default_config
|
||||
self.goal = goal
|
||||
self.nodes = nodes
|
||||
self.edges = edges
|
||||
self.entry_node = entry_node
|
||||
self.terminal_nodes = terminal_nodes
|
||||
self.executor = None
|
||||
|
||||
def _build_graph(self) -> GraphSpec:
|
||||
return GraphSpec(
|
||||
id="marketing-content-graph",
|
||||
goal_id=self.goal.id,
|
||||
entry_node=self.entry_node,
|
||||
entry_points=entry_points,
|
||||
terminal_nodes=self.terminal_nodes,
|
||||
pause_nodes=pause_nodes,
|
||||
nodes=self.nodes,
|
||||
edges=self.edges,
|
||||
default_model=self.config.model,
|
||||
max_tokens=self.config.max_tokens,
|
||||
description="Marketing content generation workflow",
|
||||
)
|
||||
|
||||
def _create_executor(self):
|
||||
runtime = Runtime(storage_path=Path(self.config.storage_path).expanduser())
|
||||
llm = AnthropicProvider(model=self.config.model)
|
||||
self.executor = GraphExecutor(runtime=runtime, llm=llm)
|
||||
return self.executor
|
||||
|
||||
async def run(self, context: dict, mock_mode: bool = False) -> dict:
|
||||
graph = self._build_graph()
|
||||
executor = self._create_executor()
|
||||
result = await executor.execute(
|
||||
graph=graph,
|
||||
goal=self.goal,
|
||||
input_data=context,
|
||||
)
|
||||
return {
|
||||
"success": result.success,
|
||||
"output": result.output,
|
||||
"steps": result.steps_executed,
|
||||
"path": result.path,
|
||||
}
|
||||
|
||||
|
||||
default_agent = MarketingAgent()
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Runtime configuration for Marketing Content Agent."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeConfig:
|
||||
model: str = "claude-haiku-4-5-20251001"
|
||||
max_tokens: int = 2048
|
||||
storage_path: str = "~/.hive/storage"
|
||||
mock_mode: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentMetadata:
|
||||
name: str = "marketing_agent"
|
||||
version: str = "0.1.0"
|
||||
description: str = "Multi-channel marketing content generator"
|
||||
author: str = ""
|
||||
tags: list[str] = field(
|
||||
default_factory=lambda: ["marketing", "content", "template"]
|
||||
)
|
||||
|
||||
|
||||
default_config = RuntimeConfig()
|
||||
metadata = AgentMetadata()
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Node definitions for Marketing Content Agent."""
|
||||
|
||||
from framework.graph import NodeSpec
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node 1: Analyze the target audience
|
||||
# ---------------------------------------------------------------------------
|
||||
analyze_audience_node = NodeSpec(
|
||||
id="analyze-audience",
|
||||
name="Analyze Audience",
|
||||
description="Produce a structured audience analysis from the product and target audience description.",
|
||||
node_type="llm_generate",
|
||||
input_keys=["product_description", "target_audience"],
|
||||
output_keys=["audience_analysis"],
|
||||
system_prompt="""\
|
||||
You are a marketing strategist. Analyze the target audience for a product.
|
||||
|
||||
Product: {product_description}
|
||||
Target audience: {target_audience}
|
||||
|
||||
Produce a structured analysis as raw JSON (no markdown):
|
||||
{{
|
||||
"audience_analysis": {{
|
||||
"demographics": "...",
|
||||
"pain_points": ["..."],
|
||||
"motivations": ["..."],
|
||||
"preferred_channels": ["..."],
|
||||
"messaging_angle": "..."
|
||||
}}
|
||||
}}
|
||||
""",
|
||||
tools=[],
|
||||
max_retries=2,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node 2: Generate channel-specific content with A/B variants
|
||||
# ---------------------------------------------------------------------------
|
||||
generate_content_node = NodeSpec(
|
||||
id="generate-content",
|
||||
name="Generate Content",
|
||||
description="Create marketing copy for each requested channel with two variants per channel.",
|
||||
node_type="llm_generate",
|
||||
input_keys=["product_description", "audience_analysis", "brand_voice", "channels"],
|
||||
output_keys=["content"],
|
||||
system_prompt="""\
|
||||
You are a marketing copywriter. Generate content for each channel.
|
||||
|
||||
Product: {product_description}
|
||||
Audience analysis: {audience_analysis}
|
||||
Brand voice: {brand_voice}
|
||||
Channels: {channels}
|
||||
|
||||
For each channel, produce two variants (A and B).
|
||||
|
||||
Output as raw JSON (no markdown):
|
||||
{{
|
||||
"content": [
|
||||
{{
|
||||
"channel": "twitter",
|
||||
"variant_a": "...",
|
||||
"variant_b": "..."
|
||||
}}
|
||||
]
|
||||
}}
|
||||
""",
|
||||
tools=[],
|
||||
max_retries=2,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node 3: Review and refine content
|
||||
# ---------------------------------------------------------------------------
|
||||
review_and_refine_node = NodeSpec(
|
||||
id="review-and-refine",
|
||||
name="Review and Refine",
|
||||
description="Review generated content for brand voice alignment and channel fit. Revise if needed.",
|
||||
node_type="llm_generate",
|
||||
input_keys=["content", "brand_voice"],
|
||||
output_keys=["content", "needs_revision"],
|
||||
system_prompt="""\
|
||||
You are a senior marketing editor. Review the following content for brand
|
||||
voice alignment, clarity, and channel appropriateness.
|
||||
|
||||
Content: {content}
|
||||
Brand voice: {brand_voice}
|
||||
|
||||
If any piece needs revision, fix it and set needs_revision to true.
|
||||
If everything looks good, return the content unchanged with needs_revision false.
|
||||
|
||||
Output as raw JSON (no markdown):
|
||||
{{
|
||||
"content": [...],
|
||||
"needs_revision": false
|
||||
}}
|
||||
""",
|
||||
tools=[],
|
||||
max_retries=2,
|
||||
)
|
||||
|
||||
# All nodes for easy import
|
||||
all_nodes = [
|
||||
analyze_audience_node,
|
||||
generate_content_node,
|
||||
review_and_refine_node,
|
||||
]
|
||||
@@ -0,0 +1,2 @@
|
||||
[tool.uv.workspace]
|
||||
members = ["core", "tools"]
|
||||
+149
-153
@@ -11,6 +11,14 @@
|
||||
|
||||
set -e
|
||||
|
||||
# Detect Bash version for compatibility
|
||||
BASH_MAJOR_VERSION="${BASH_VERSINFO[0]}"
|
||||
USE_ASSOC_ARRAYS=false
|
||||
if [ "$BASH_MAJOR_VERSION" -ge 4 ]; then
|
||||
USE_ASSOC_ARRAYS=true
|
||||
fi
|
||||
echo "[debug] Bash version: ${BASH_VERSION}"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
@@ -52,7 +60,7 @@ prompt_choice() {
|
||||
echo -e "${BOLD}$prompt${NC}"
|
||||
for opt in "${options[@]}"; do
|
||||
echo -e " ${CYAN}$i)${NC} $opt"
|
||||
((i++))
|
||||
i=$((i + 1))
|
||||
done
|
||||
echo ""
|
||||
|
||||
@@ -60,7 +68,8 @@ prompt_choice() {
|
||||
while true; do
|
||||
read -r -p "Enter choice (1-${#options[@]}): " choice
|
||||
if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le "${#options[@]}" ]; then
|
||||
return $((choice - 1))
|
||||
PROMPT_CHOICE=$((choice - 1))
|
||||
return 0
|
||||
fi
|
||||
echo -e "${RED}Invalid choice. Please enter 1-${#options[@]}${NC}"
|
||||
done
|
||||
@@ -174,63 +183,26 @@ echo ""
|
||||
echo -e "${DIM}This may take a minute...${NC}"
|
||||
echo ""
|
||||
|
||||
# Upgrade pip, setuptools, and wheel
|
||||
echo -n " Upgrading pip... "
|
||||
$PYTHON_CMD -m pip install --upgrade pip setuptools wheel > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Install framework package from core/
|
||||
echo -n " Installing framework... "
|
||||
cd "$SCRIPT_DIR/core"
|
||||
# Install all workspace packages (core + tools) from workspace root
|
||||
echo -n " Installing workspace packages... "
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
if [ -f "pyproject.toml" ]; then
|
||||
uv sync > /dev/null 2>&1
|
||||
if [ $? -eq 0 ]; then
|
||||
echo -e "${GREEN} ✓ framework package installed${NC}"
|
||||
if uv sync > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ workspace packages installed${NC}"
|
||||
else
|
||||
echo -e "${YELLOW} ⚠ framework installation had issues (may be OK)${NC}"
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}failed (no pyproject.toml)${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install aden_tools package from tools/
|
||||
echo -n " Installing tools... "
|
||||
cd "$SCRIPT_DIR/tools"
|
||||
|
||||
if [ -f "pyproject.toml" ]; then
|
||||
uv sync > /dev/null 2>&1
|
||||
if [ $? -eq 0 ]; then
|
||||
echo -e "${GREEN} ✓ aden_tools package installed${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ aden_tools installation failed${NC}"
|
||||
echo -e "${RED} ✗ workspace installation failed${NC}"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}failed${NC}"
|
||||
echo -e "${RED}failed (no root pyproject.toml)${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install MCP dependencies
|
||||
echo -n " Installing MCP... "
|
||||
$PYTHON_CMD -m pip install mcp fastmcp > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Fix openai version compatibility
|
||||
echo -n " Checking openai... "
|
||||
$PYTHON_CMD -m pip install "openai>=1.0.0" > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Install click for CLI
|
||||
echo -n " Installing CLI tools... "
|
||||
$PYTHON_CMD -m pip install click > /dev/null 2>&1
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
|
||||
# Install Playwright browser
|
||||
echo -n " Installing Playwright browser... "
|
||||
if $PYTHON_CMD -c "import playwright" > /dev/null 2>&1; then
|
||||
if $PYTHON_CMD -m playwright install chromium > /dev/null 2>&1; then
|
||||
if uv run python -c "import playwright" > /dev/null 2>&1; then
|
||||
if uv run python -m playwright install chromium > /dev/null 2>&1; then
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
else
|
||||
echo -e "${YELLOW}⏭${NC}"
|
||||
@@ -249,33 +221,6 @@ echo ""
|
||||
# ============================================================
|
||||
|
||||
echo -e "${YELLOW}⬢${NC} ${BLUE}${BOLD}Step 3: Configuring LLM provider...${NC}"
|
||||
# Install MCP dependencies (in tools venv)
|
||||
echo " Installing MCP dependencies..."
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
uv pip install --python "$TOOLS_PYTHON" mcp fastmcp > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ MCP dependencies installed${NC}"
|
||||
|
||||
# Fix openai version compatibility (in tools venv)
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
OPENAI_VERSION=$($TOOLS_PYTHON -c "import openai; print(openai.__version__)" 2>/dev/null || echo "not_installed")
|
||||
if [ "$OPENAI_VERSION" = "not_installed" ]; then
|
||||
echo " Installing openai package..."
|
||||
uv pip install --python "$TOOLS_PYTHON" "openai>=1.0.0" > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ openai installed${NC}"
|
||||
elif [[ "$OPENAI_VERSION" =~ ^0\. ]]; then
|
||||
echo " Upgrading openai to 1.x+ for litellm compatibility..."
|
||||
uv pip install --python "$TOOLS_PYTHON" --upgrade "openai>=1.0.0" > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ openai upgraded${NC}"
|
||||
else
|
||||
echo -e "${GREEN} ✓ openai $OPENAI_VERSION is compatible${NC}"
|
||||
fi
|
||||
|
||||
# Install click for CLI (in tools venv)
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
uv pip install --python "$TOOLS_PYTHON" click > /dev/null 2>&1
|
||||
echo -e "${GREEN} ✓ click installed${NC}"
|
||||
|
||||
cd "$SCRIPT_DIR"
|
||||
echo ""
|
||||
|
||||
# ============================================================
|
||||
@@ -287,42 +232,28 @@ echo ""
|
||||
|
||||
IMPORT_ERRORS=0
|
||||
|
||||
# Test imports using their respective venvs
|
||||
CORE_PYTHON="$SCRIPT_DIR/core/.venv/bin/python"
|
||||
TOOLS_PYTHON="$SCRIPT_DIR/tools/.venv/bin/python"
|
||||
|
||||
# Test framework import (from core venv)
|
||||
if [ -f "$CORE_PYTHON" ] && $CORE_PYTHON -c "import framework" > /dev/null 2>&1; then
|
||||
# Test imports using workspace venv via uv run
|
||||
if uv run python -c "import framework" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ framework imports OK${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ framework import failed${NC}"
|
||||
IMPORT_ERRORS=$((IMPORT_ERRORS + 1))
|
||||
fi
|
||||
|
||||
# Test aden_tools import (from tools venv)
|
||||
if [ -f "$TOOLS_PYTHON" ] && $TOOLS_PYTHON -c "import aden_tools" > /dev/null 2>&1; then
|
||||
if uv run python -c "import aden_tools" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ aden_tools imports OK${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ aden_tools import failed${NC}"
|
||||
IMPORT_ERRORS=$((IMPORT_ERRORS + 1))
|
||||
fi
|
||||
|
||||
# Test litellm import (from core venv)
|
||||
if [ -f "$CORE_PYTHON" ] && $CORE_PYTHON -c "import litellm" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ litellm imports OK (core)${NC}"
|
||||
if uv run python -c "import litellm" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ litellm imports OK${NC}"
|
||||
else
|
||||
echo -e "${YELLOW} ⚠ litellm import issues in core (may be OK)${NC}"
|
||||
echo -e "${YELLOW} ⚠ litellm import issues (may be OK)${NC}"
|
||||
fi
|
||||
|
||||
# Test litellm import (from tools venv)
|
||||
if [ -f "$TOOLS_PYTHON" ] && $TOOLS_PYTHON -c "import litellm" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ litellm imports OK (tools)${NC}"
|
||||
else
|
||||
echo -e "${YELLOW} ⚠ litellm import issues in tools (may be OK)${NC}"
|
||||
fi
|
||||
|
||||
# Test MCP server module (from core venv)
|
||||
if [ -f "$CORE_PYTHON" ] && $CORE_PYTHON -c "from framework.mcp import agent_builder_server" > /dev/null 2>&1; then
|
||||
if uv run python -c "from framework.mcp import agent_builder_server" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN} ✓ MCP server module OK${NC}"
|
||||
else
|
||||
echo -e "${RED} ✗ MCP server module failed${NC}"
|
||||
@@ -344,53 +275,105 @@ echo ""
|
||||
echo -e "${BLUE}Step 4: Verifying Claude Code skills...${NC}"
|
||||
echo ""
|
||||
|
||||
# Provider data as parallel indexed arrays (Bash 3.2 compatible — no declare -A)
|
||||
PROVIDER_ENV_VARS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GOOGLE_API_KEY GROQ_API_KEY CEREBRAS_API_KEY MISTRAL_API_KEY TOGETHER_API_KEY DEEPSEEK_API_KEY)
|
||||
PROVIDER_DISPLAY_NAMES=("Anthropic (Claude)" "OpenAI (GPT)" "Google Gemini" "Google AI" "Groq" "Cerebras" "Mistral" "Together AI" "DeepSeek")
|
||||
PROVIDER_ID_LIST=(anthropic openai gemini google groq cerebras mistral together deepseek)
|
||||
# Provider configuration - use associative arrays (Bash 4+) or indexed arrays (Bash 3.2)
|
||||
if [ "$USE_ASSOC_ARRAYS" = true ]; then
|
||||
# Bash 4+ - use associative arrays (cleaner and more efficient)
|
||||
declare -A PROVIDER_NAMES=(
|
||||
["ANTHROPIC_API_KEY"]="Anthropic (Claude)"
|
||||
["OPENAI_API_KEY"]="OpenAI (GPT)"
|
||||
["GEMINI_API_KEY"]="Google Gemini"
|
||||
["GOOGLE_API_KEY"]="Google AI"
|
||||
["GROQ_API_KEY"]="Groq"
|
||||
["CEREBRAS_API_KEY"]="Cerebras"
|
||||
["MISTRAL_API_KEY"]="Mistral"
|
||||
["TOGETHER_API_KEY"]="Together AI"
|
||||
["DEEPSEEK_API_KEY"]="DeepSeek"
|
||||
)
|
||||
|
||||
# Default models by provider id (parallel arrays)
|
||||
MODEL_PROVIDER_IDS=(anthropic openai gemini groq cerebras mistral together_ai deepseek)
|
||||
MODEL_DEFAULTS=("claude-sonnet-4-5-20250929" "gpt-4o" "gemini-3.0-flash-preview" "moonshotai/kimi-k2-instruct-0905" "zai-glm-4.7" "mistral-large-latest" "meta-llama/Llama-3.3-70B-Instruct-Turbo" "deepseek-chat")
|
||||
declare -A PROVIDER_IDS=(
|
||||
["ANTHROPIC_API_KEY"]="anthropic"
|
||||
["OPENAI_API_KEY"]="openai"
|
||||
["GEMINI_API_KEY"]="gemini"
|
||||
["GOOGLE_API_KEY"]="google"
|
||||
["GROQ_API_KEY"]="groq"
|
||||
["CEREBRAS_API_KEY"]="cerebras"
|
||||
["MISTRAL_API_KEY"]="mistral"
|
||||
["TOGETHER_API_KEY"]="together"
|
||||
["DEEPSEEK_API_KEY"]="deepseek"
|
||||
)
|
||||
|
||||
# Helper: get provider display name for an env var
|
||||
get_provider_name() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_DISPLAY_NAMES[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
declare -A DEFAULT_MODELS=(
|
||||
["anthropic"]="claude-sonnet-4-5-20250929"
|
||||
["openai"]="gpt-4o"
|
||||
["gemini"]="gemini-3.0-flash-preview"
|
||||
["groq"]="moonshotai/kimi-k2-instruct-0905"
|
||||
["cerebras"]="zai-glm-4.7"
|
||||
["mistral"]="mistral-large-latest"
|
||||
["together_ai"]="meta-llama/Llama-3.3-70B-Instruct-Turbo"
|
||||
["deepseek"]="deepseek-chat"
|
||||
)
|
||||
|
||||
# Helper: get provider id for an env var
|
||||
get_provider_id() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_ID_LIST[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
# Helper functions for Bash 4+
|
||||
get_provider_name() {
|
||||
echo "${PROVIDER_NAMES[$1]}"
|
||||
}
|
||||
|
||||
# Helper: get default model for a provider id
|
||||
get_default_model() {
|
||||
local provider_id="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#MODEL_PROVIDER_IDS[@]} ]; do
|
||||
if [ "${MODEL_PROVIDER_IDS[$i]}" = "$provider_id" ]; then
|
||||
echo "${MODEL_DEFAULTS[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
get_provider_id() {
|
||||
echo "${PROVIDER_IDS[$1]}"
|
||||
}
|
||||
|
||||
get_default_model() {
|
||||
echo "${DEFAULT_MODELS[$1]}"
|
||||
}
|
||||
else
|
||||
# Bash 3.2 - use parallel indexed arrays
|
||||
PROVIDER_ENV_VARS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GOOGLE_API_KEY GROQ_API_KEY CEREBRAS_API_KEY MISTRAL_API_KEY TOGETHER_API_KEY DEEPSEEK_API_KEY)
|
||||
PROVIDER_DISPLAY_NAMES=("Anthropic (Claude)" "OpenAI (GPT)" "Google Gemini" "Google AI" "Groq" "Cerebras" "Mistral" "Together AI" "DeepSeek")
|
||||
PROVIDER_ID_LIST=(anthropic openai gemini google groq cerebras mistral together deepseek)
|
||||
|
||||
# Default models by provider id (parallel arrays)
|
||||
MODEL_PROVIDER_IDS=(anthropic openai gemini groq cerebras mistral together_ai deepseek)
|
||||
MODEL_DEFAULTS=("claude-sonnet-4-5-20250929" "gpt-4o" "gemini-3.0-flash-preview" "moonshotai/kimi-k2-instruct-0905" "zai-glm-4.7" "mistral-large-latest" "meta-llama/Llama-3.3-70B-Instruct-Turbo" "deepseek-chat")
|
||||
|
||||
# Helper: get provider display name for an env var
|
||||
get_provider_name() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_DISPLAY_NAMES[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
|
||||
# Helper: get provider id for an env var
|
||||
get_provider_id() {
|
||||
local env_var="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#PROVIDER_ENV_VARS[@]} ]; do
|
||||
if [ "${PROVIDER_ENV_VARS[$i]}" = "$env_var" ]; then
|
||||
echo "${PROVIDER_ID_LIST[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
|
||||
# Helper: get default model for a provider id
|
||||
get_default_model() {
|
||||
local provider_id="$1"
|
||||
local i=0
|
||||
while [ $i -lt ${#MODEL_PROVIDER_IDS[@]} ]; do
|
||||
if [ "${MODEL_PROVIDER_IDS[$i]}" = "$provider_id" ]; then
|
||||
echo "${MODEL_DEFAULTS[$i]}"
|
||||
return
|
||||
fi
|
||||
i=$((i + 1))
|
||||
done
|
||||
}
|
||||
fi
|
||||
|
||||
# Configuration directory
|
||||
HIVE_CONFIG_DIR="$HOME/.hive"
|
||||
@@ -413,7 +396,7 @@ config = {
|
||||
'model': '$model',
|
||||
'api_key_env_var': '$env_var'
|
||||
},
|
||||
'created_at': '$(date -Iseconds)'
|
||||
'created_at': '$(date -u +"%Y-%m-%dT%H:%M:%S+00:00")'
|
||||
}
|
||||
with open('$HIVE_CONFIG_FILE', 'w') as f:
|
||||
json.dump(config, f, indent=2)
|
||||
@@ -442,13 +425,25 @@ FOUND_ENV_VARS=() # Corresponding env var names
|
||||
SELECTED_PROVIDER_ID="" # Will hold the chosen provider ID
|
||||
SELECTED_ENV_VAR="" # Will hold the chosen env var
|
||||
|
||||
for env_var in "${PROVIDER_ENV_VARS[@]}"; do
|
||||
value="${!env_var}"
|
||||
if [ -n "$value" ]; then
|
||||
FOUND_PROVIDERS+=("$(get_provider_name "$env_var")")
|
||||
FOUND_ENV_VARS+=("$env_var")
|
||||
fi
|
||||
done
|
||||
if [ "$USE_ASSOC_ARRAYS" = true ]; then
|
||||
# Bash 4+ - iterate over associative array keys
|
||||
for env_var in "${!PROVIDER_NAMES[@]}"; do
|
||||
value="${!env_var}"
|
||||
if [ -n "$value" ]; then
|
||||
FOUND_PROVIDERS+=("$(get_provider_name "$env_var")")
|
||||
FOUND_ENV_VARS+=("$env_var")
|
||||
fi
|
||||
done
|
||||
else
|
||||
# Bash 3.2 - iterate over indexed array
|
||||
for env_var in "${PROVIDER_ENV_VARS[@]}"; do
|
||||
value="${!env_var}"
|
||||
if [ -n "$value" ]; then
|
||||
FOUND_PROVIDERS+=("$(get_provider_name "$env_var")")
|
||||
FOUND_ENV_VARS+=("$env_var")
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${#FOUND_PROVIDERS[@]} -gt 0 ]; then
|
||||
echo "Found API keys:"
|
||||
@@ -476,7 +471,7 @@ if [ ${#FOUND_PROVIDERS[@]} -gt 0 ]; then
|
||||
i=1
|
||||
for provider in "${FOUND_PROVIDERS[@]}"; do
|
||||
echo -e " ${CYAN}$i)${NC} $provider"
|
||||
((i++))
|
||||
i=$((i + 1))
|
||||
done
|
||||
echo ""
|
||||
|
||||
@@ -507,7 +502,7 @@ if [ -z "$SELECTED_PROVIDER_ID" ]; then
|
||||
"Groq - Fast, free tier" \
|
||||
"Cerebras - Fast, free tier" \
|
||||
"Skip for now"
|
||||
choice=$?
|
||||
choice=$PROMPT_CHOICE
|
||||
|
||||
case $choice in
|
||||
0)
|
||||
@@ -542,7 +537,8 @@ if [ -z "$SELECTED_PROVIDER_ID" ]; then
|
||||
;;
|
||||
5)
|
||||
echo ""
|
||||
echo -e "${YELLOW}Skipped.${NC} Add your API key later:"
|
||||
echo -e "${YELLOW}Skipped.${NC} An LLM API key is required to test and use worker agents."
|
||||
echo -e "Add your API key later by running:"
|
||||
echo ""
|
||||
echo -e " ${CYAN}echo 'ANTHROPIC_API_KEY=your-key' >> .env${NC}"
|
||||
echo ""
|
||||
@@ -595,7 +591,7 @@ ERRORS=0
|
||||
|
||||
# Test imports
|
||||
echo -n " ⬡ framework... "
|
||||
if $PYTHON_CMD -c "import framework" > /dev/null 2>&1; then
|
||||
if uv run python -c "import framework" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
else
|
||||
echo -e "${RED}failed${NC}"
|
||||
@@ -603,7 +599,7 @@ else
|
||||
fi
|
||||
|
||||
echo -n " ⬡ aden_tools... "
|
||||
if $PYTHON_CMD -c "import aden_tools" > /dev/null 2>&1; then
|
||||
if uv run python -c "import aden_tools" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
else
|
||||
echo -e "${RED}failed${NC}"
|
||||
@@ -611,7 +607,7 @@ else
|
||||
fi
|
||||
|
||||
echo -n " ⬡ litellm... "
|
||||
if $PYTHON_CMD -c "import litellm" > /dev/null 2>&1; then
|
||||
if uv run python -c "import litellm" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
else
|
||||
echo -e "${YELLOW}--${NC}"
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
<#
|
||||
|
||||
setup-python.ps1 - Python Environment Setup for Aden Agent Framework
|
||||
|
||||
This script sets up the Python environment with all required packages
|
||||
for building and running goal-driven agents.
|
||||
#>
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
# Colors for output
|
||||
$RED = "Red"
|
||||
$GREEN = "Green"
|
||||
$YELLOW = "Yellow"
|
||||
$BLUE = "Cyan"
|
||||
|
||||
# Get the directory where this script is located
|
||||
$SCRIPT_DIR = Split-Path -Parent $MyInvocation.MyCommand.Path
|
||||
$PROJECT_ROOT = Split-Path -Parent $SCRIPT_DIR
|
||||
|
||||
Write-Host ""
|
||||
Write-Host "=================================================="
|
||||
Write-Host " Aden Agent Framework - Python Setup"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
# Check for Python
|
||||
$pythonCmd = $null
|
||||
if (Get-Command python -ErrorAction SilentlyContinue) {
|
||||
$pythonCmd = "python"
|
||||
}
|
||||
|
||||
if (-not $pythonCmd) {
|
||||
Write-Host "Error: Python is not installed." -ForegroundColor $RED
|
||||
Write-Host "Please install Python 3.11+ from https://python.org"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Check Python version
|
||||
$versionInfo = & $pythonCmd -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')"
|
||||
$major = & $pythonCmd -c "import sys; print(sys.version_info.major)"
|
||||
$minor = & $pythonCmd -c "import sys; print(sys.version_info.minor)"
|
||||
|
||||
Write-Host "Detected Python: $versionInfo" -ForegroundColor $BLUE
|
||||
|
||||
if ($major -lt 3 -or ($major -eq 3 -and $minor -lt 11)) {
|
||||
Write-Host "Error: Python 3.11+ is required (found $versionInfo)" -ForegroundColor $RED
|
||||
Write-Host "Please upgrade your Python installation"
|
||||
exit 1
|
||||
}
|
||||
|
||||
if ($minor -lt 11) {
|
||||
Write-Host "Warning: Python 3.11+ is recommended for best compatibility" -ForegroundColor $YELLOW
|
||||
Write-Host "You have Python $versionInfo which may work but is not officially supported" -ForegroundColor $YELLOW
|
||||
Write-Host ""
|
||||
}
|
||||
|
||||
Write-Host "[OK] Python version check passed" -ForegroundColor $GREEN
|
||||
Write-Host ""
|
||||
|
||||
# Create and activate virtual environment
|
||||
Write-Host "=================================================="
|
||||
Write-Host "Setting up Python Virtual Environment"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
$VENV_PATH = Join-Path $PROJECT_ROOT ".venv"
|
||||
$VENV_PYTHON = Join-Path $VENV_PATH "Scripts\python.exe"
|
||||
$VENV_ACTIVATE = Join-Path $VENV_PATH "Scripts\Activate.ps1"
|
||||
|
||||
if (-not (Test-Path $VENV_PYTHON)) {
|
||||
Write-Host "Creating virtual environment at .venv..."
|
||||
& $pythonCmd -m venv $VENV_PATH
|
||||
Write-Host "[OK] Virtual environment created" -ForegroundColor $GREEN
|
||||
}
|
||||
else {
|
||||
Write-Host "[OK] Virtual environment already exists" -ForegroundColor $GREEN
|
||||
}
|
||||
|
||||
# Activate venv
|
||||
Write-Host "Activating virtual environment..."
|
||||
& $VENV_ACTIVATE
|
||||
Write-Host "[OK] Virtual environment activated" -ForegroundColor $GREEN
|
||||
|
||||
# From here on, always use venv python
|
||||
$pythonCmd = $VENV_PYTHON
|
||||
|
||||
Write-Host ""
|
||||
|
||||
# Check for pip
|
||||
try {
|
||||
& $pythonCmd -m pip --version | Out-Null
|
||||
}
|
||||
catch {
|
||||
Write-Host "Error: pip is not installed" -ForegroundColor $RED
|
||||
Write-Host "Please install pip for Python $versionInfo"
|
||||
exit 1
|
||||
}
|
||||
|
||||
Write-Host "[OK] pip detected" -ForegroundColor $GREEN
|
||||
Write-Host ""
|
||||
|
||||
# Upgrade pip, setuptools, and wheel
|
||||
Write-Host "Upgrading pip, setuptools, and wheel..."
|
||||
& $pythonCmd -m pip install --upgrade pip setuptools wheel
|
||||
Write-Host "[OK] Core packages upgraded" -ForegroundColor $GREEN
|
||||
Write-Host ""
|
||||
|
||||
# Install core framework package
|
||||
Write-Host "=================================================="
|
||||
Write-Host "Installing Core Framework Package"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
Set-Location "$PROJECT_ROOT\core"
|
||||
|
||||
if (Test-Path "pyproject.toml") {
|
||||
Write-Host "Installing framework from core/ (editable mode)..."
|
||||
& $pythonCmd -m pip install -e . | Out-Null
|
||||
Write-Host "[OK] Framework package installed" -ForegroundColor $GREEN
|
||||
}
|
||||
else {
|
||||
Write-Host "[WARN] No pyproject.toml found in core/, skipping framework installation" -ForegroundColor $YELLOW
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
|
||||
# Install tools package
|
||||
Write-Host "=================================================="
|
||||
Write-Host "Installing Tools Package (aden_tools)"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
Set-Location "$PROJECT_ROOT\tools"
|
||||
|
||||
if (Test-Path "pyproject.toml") {
|
||||
Write-Host "Installing aden_tools from tools/ (editable mode)..."
|
||||
& $pythonCmd -m pip install -e . | Out-Null
|
||||
Write-Host "[OK] Tools package installed" -ForegroundColor $GREEN
|
||||
}
|
||||
else {
|
||||
Write-Host "Error: No pyproject.toml found in tools/" -ForegroundColor $RED
|
||||
exit 1
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
|
||||
# Fix openai version compatibility with litellm
|
||||
Write-Host "=================================================="
|
||||
Write-Host "Fixing Package Compatibility"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
try {
|
||||
$openaiVersion = & $pythonCmd -c "import openai; print(openai.__version__)"
|
||||
}
|
||||
catch {
|
||||
$openaiVersion = "not_installed"
|
||||
}
|
||||
|
||||
if ($openaiVersion -eq "not_installed") {
|
||||
Write-Host "Installing openai package..."
|
||||
& $pythonCmd -m pip install "openai>=1.0.0" | Out-Null
|
||||
Write-Host "[OK] openai package installed" -ForegroundColor $GREEN
|
||||
}
|
||||
elseif ($openaiVersion.StartsWith("0.")) {
|
||||
Write-Host "Found old openai version: $openaiVersion" -ForegroundColor $YELLOW
|
||||
Write-Host "Upgrading to openai 1.x+ for litellm compatibility..."
|
||||
& $pythonCmd -m pip install --upgrade "openai>=1.0.0" | Out-Null
|
||||
$openaiVersion = & $pythonCmd -c "import openai; print(openai.__version__)"
|
||||
Write-Host "[OK] openai upgraded to $openaiVersion" -ForegroundColor $GREEN
|
||||
}
|
||||
else {
|
||||
Write-Host "[OK] openai $openaiVersion is compatible" -ForegroundColor $GREEN
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
|
||||
# Verify installations
|
||||
Write-Host "=================================================="
|
||||
Write-Host "Verifying Installation"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
|
||||
Set-Location $PROJECT_ROOT
|
||||
|
||||
# Test framework import
|
||||
& $pythonCmd -c "import framework" 2>$null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Host "[OK] framework package imports successfully" -ForegroundColor Green
|
||||
}
|
||||
else {
|
||||
Write-Host "[FAIL] framework package import failed" -ForegroundColor Red
|
||||
}
|
||||
|
||||
# Test aden_tools import
|
||||
& $pythonCmd -c "import aden_tools" 2>$null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Host "[OK] aden_tools package imports successfully" -ForegroundColor Green
|
||||
}
|
||||
else {
|
||||
Write-Host "[FAIL] aden_tools package import failed" -ForegroundColor Red
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Test litellm
|
||||
& $pythonCmd -c "import litellm" 2>$null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Host "[OK] litellm package imports successfully" -ForegroundColor $GREEN
|
||||
}
|
||||
else {
|
||||
Write-Host "[WARN] litellm import had issues (may be OK if not using LLM features)" -ForegroundColor $YELLOW
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
|
||||
# Print agent commands
|
||||
Write-Host "=================================================="
|
||||
Write-Host " Setup Complete!"
|
||||
Write-Host "=================================================="
|
||||
Write-Host ""
|
||||
Write-Host "Python packages installed:"
|
||||
Write-Host " - framework (core agent runtime)"
|
||||
Write-Host " - aden_tools (tools and MCP servers)"
|
||||
Write-Host " - All dependencies and compatibility fixes applied"
|
||||
Write-Host ""
|
||||
Write-Host "To run agents on Windows (PowerShell):"
|
||||
Write-Host ""
|
||||
Write-Host "1. From the project root, set PYTHONPATH:"
|
||||
Write-Host " `$env:PYTHONPATH=`"core;exports`""
|
||||
Write-Host ""
|
||||
Write-Host "2. Run an agent command:"
|
||||
Write-Host " python -m agent_name validate"
|
||||
Write-Host " python -m agent_name info"
|
||||
Write-Host " python -m agent_name run --input '{...}'"
|
||||
Write-Host ""
|
||||
Write-Host "Example (support_ticket_agent):"
|
||||
Write-Host " python -m support_ticket_agent validate"
|
||||
Write-Host " python -m support_ticket_agent info"
|
||||
Write-Host " python -m support_ticket_agent run --input '{""ticket_content"":""..."",""customer_id"":""..."",""ticket_id"":""...""}'"
|
||||
Write-Host ""
|
||||
Write-Host "Notes:"
|
||||
Write-Host " - Ensure the virtual environment is activated (.venv)"
|
||||
Write-Host " - PYTHONPATH must be set in each new PowerShell session"
|
||||
Write-Host ""
|
||||
Write-Host "Documentation:"
|
||||
Write-Host " $PROJECT_ROOT\README.md"
|
||||
Write-Host ""
|
||||
Write-Host "Agent Examples:"
|
||||
Write-Host " $PROJECT_ROOT\exports\"
|
||||
Write-Host ""
|
||||
+1
-12
@@ -68,18 +68,7 @@ from starlette.responses import PlainTextResponse # noqa: E402
|
||||
from aden_tools.credentials import CredentialError, CredentialStoreAdapter # noqa: E402
|
||||
from aden_tools.tools import register_all_tools # noqa: E402
|
||||
|
||||
# Create credential store with access to both env vars AND encrypted store
|
||||
# This allows using Aden-synced credentials from ~/.hive/credentials
|
||||
try:
|
||||
from framework.credentials import CredentialStore
|
||||
|
||||
store = CredentialStore.with_encrypted_storage() # ~/.hive/credentials
|
||||
credentials = CredentialStoreAdapter(store)
|
||||
logger.info("Using CredentialStoreAdapter with encrypted storage")
|
||||
except Exception as e:
|
||||
# Fall back to env-only adapter if encrypted storage fails
|
||||
credentials = CredentialStoreAdapter.with_env_storage()
|
||||
logger.warning(f"Falling back to env-only CredentialStoreAdapter: {e}")
|
||||
credentials = CredentialStoreAdapter.default()
|
||||
|
||||
# Tier 1: Validate startup-required credentials (if any)
|
||||
try:
|
||||
|
||||
@@ -30,6 +30,7 @@ dependencies = [
|
||||
"playwright-stealth>=1.0.5",
|
||||
"litellm>=1.81.0",
|
||||
"resend>=2.0.0",
|
||||
"framework",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -54,6 +55,9 @@ all = [
|
||||
"duckdb>=1.0.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
framework = { workspace = true }
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user